Add mask in

This commit is contained in:
tritant
2025-11-05 16:51:01 +01:00
committed by GitHub
parent 6b9aaa4920
commit 6c6668d354

View File

@@ -41,6 +41,7 @@ class SuperScaler:
},
"optional": {
# --- SECTION 1: LATENT REFINEMENT (PASS 1) ---
"mask_in": ("MASK",),
"enable_latent_pass": ("BOOLEAN", {"default": False}),
"model_pass_1": ("MODEL",),
"vae_pass_1": ("VAE",),
@@ -93,6 +94,7 @@ class SuperScaler:
"grain_size": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 16.0, "step": 0.1}),
"saturation_mix": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01}),
"adaptive_grain": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 2.0, "step": 0.01}),
"mask_blend_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
}
}
@@ -388,6 +390,14 @@ class SuperScaler:
# --- FONCTION PRINCIPALE ---
def process(self, image_in, **kwargs):
# --- NOUVEAU : Sauvegarde de l'image originale ---
original_clean = image_in.clone()
# Récupérer les nouveaux arguments
mask_in = kwargs.get("mask_in", None)
mask_blend_weight = kwargs.get("mask_blend_weight", 1.0)
global_seed = kwargs.get("seed", 0)
# Récupérer tous les arguments
enable_latent_pass = kwargs.get("enable_latent_pass", False)
@@ -512,6 +522,59 @@ class SuperScaler:
adaptive_grain=kwargs.get("adaptive_grain", 0.30)
)
# --- NOUVEAU : PASS 5 (FINAL MASKED BLEND) ---
# current_image est notre image finale "traitée" (sur CPU)
# On vérifie si un masque est fourni ET si le poids est > 0
if mask_in is not None and mask_blend_weight > 0.0:
print(f"[SuperScaler] PASS 5: Mélange final avec masque (Poids: {mask_blend_weight})")
device = comfy.model_management.get_torch_device()
# 1. Préparer l'image finale traitée (sur GPU)
final_processed = current_image.to(device)
B, H, W, C = final_processed.shape
# 2. Préparer l'image originale "propre" (upscalée simplement)
original_clean_nchw = original_clean.to(device).permute(0, 3, 1, 2)
original_clean_upscaled_nchw = F.interpolate(
original_clean_nchw,
size=(H, W),
mode="bicubic",
antialias=True
)
original_clean_upscaled = original_clean_upscaled_nchw.permute(0, 2, 3, 1)
# 3. Préparer le masque (redimensionné à la taille finale)
# Le masque est (B, H, W), on doit le passer en (B, 1, H, W) pour interpolate
mask_nchw = mask_in.to(device).reshape(B, 1, mask_in.shape[1], mask_in.shape[2])
mask_resized_nchw = F.interpolate(
mask_nchw,
size=(H, W),
mode="bilinear",
align_corners=False
)
# Re-permuter en (B, H, W, 1) pour le blending
mask_final = mask_resized_nchw.permute(0, 2, 3, 1)
# 4. Logique de mélange
# Définir ce qu'est la "zone protégée" (le ciel)
# Le poids contrôle à quel point la zone protégée est "propre"
protected_image = (original_clean_upscaled * mask_blend_weight) + \
(final_processed * (1.0 - mask_blend_weight))
# Logique INVERSÉE comme demandé :
# Masque NOIR (0.0) -> protected_image
# Masque BLANC (1.0) -> final_processed
current_image = (final_processed * (1.0 - mask_final)) + (protected_image * mask_final)
# Nettoyage
del final_processed, original_clean_nchw, original_clean_upscaled, mask_nchw, mask_resized_nchw, mask_final
gc.collect()
torch.cuda.empty_cache()
# Retourner l'image finale depuis le GPU vers le CPU
return (current_image.cpu(),)
return (current_image,)
# --- Enregistrement du Node ---