mirror of
https://github.com/tritant/ComfyUI_SuperScaler.git
synced 2025-12-16 08:37:42 +01:00
Add mask in
This commit is contained in:
@@ -41,6 +41,7 @@ class SuperScaler:
|
|||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
# --- SECTION 1: LATENT REFINEMENT (PASS 1) ---
|
# --- SECTION 1: LATENT REFINEMENT (PASS 1) ---
|
||||||
|
"mask_in": ("MASK",),
|
||||||
"enable_latent_pass": ("BOOLEAN", {"default": False}),
|
"enable_latent_pass": ("BOOLEAN", {"default": False}),
|
||||||
"model_pass_1": ("MODEL",),
|
"model_pass_1": ("MODEL",),
|
||||||
"vae_pass_1": ("VAE",),
|
"vae_pass_1": ("VAE",),
|
||||||
@@ -93,6 +94,7 @@ class SuperScaler:
|
|||||||
"grain_size": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 16.0, "step": 0.1}),
|
"grain_size": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 16.0, "step": 0.1}),
|
||||||
"saturation_mix": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01}),
|
"saturation_mix": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
"adaptive_grain": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 2.0, "step": 0.01}),
|
"adaptive_grain": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 2.0, "step": 0.01}),
|
||||||
|
"mask_blend_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||||
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -388,6 +390,14 @@ class SuperScaler:
|
|||||||
|
|
||||||
# --- FONCTION PRINCIPALE ---
|
# --- FONCTION PRINCIPALE ---
|
||||||
def process(self, image_in, **kwargs):
|
def process(self, image_in, **kwargs):
|
||||||
|
|
||||||
|
# --- NOUVEAU : Sauvegarde de l'image originale ---
|
||||||
|
original_clean = image_in.clone()
|
||||||
|
|
||||||
|
# Récupérer les nouveaux arguments
|
||||||
|
mask_in = kwargs.get("mask_in", None)
|
||||||
|
mask_blend_weight = kwargs.get("mask_blend_weight", 1.0)
|
||||||
|
|
||||||
global_seed = kwargs.get("seed", 0)
|
global_seed = kwargs.get("seed", 0)
|
||||||
# Récupérer tous les arguments
|
# Récupérer tous les arguments
|
||||||
enable_latent_pass = kwargs.get("enable_latent_pass", False)
|
enable_latent_pass = kwargs.get("enable_latent_pass", False)
|
||||||
@@ -512,6 +522,59 @@ class SuperScaler:
|
|||||||
adaptive_grain=kwargs.get("adaptive_grain", 0.30)
|
adaptive_grain=kwargs.get("adaptive_grain", 0.30)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# --- NOUVEAU : PASS 5 (FINAL MASKED BLEND) ---
|
||||||
|
# current_image est notre image finale "traitée" (sur CPU)
|
||||||
|
# On vérifie si un masque est fourni ET si le poids est > 0
|
||||||
|
if mask_in is not None and mask_blend_weight > 0.0:
|
||||||
|
print(f"[SuperScaler] PASS 5: Mélange final avec masque (Poids: {mask_blend_weight})")
|
||||||
|
|
||||||
|
device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
# 1. Préparer l'image finale traitée (sur GPU)
|
||||||
|
final_processed = current_image.to(device)
|
||||||
|
B, H, W, C = final_processed.shape
|
||||||
|
|
||||||
|
# 2. Préparer l'image originale "propre" (upscalée simplement)
|
||||||
|
original_clean_nchw = original_clean.to(device).permute(0, 3, 1, 2)
|
||||||
|
original_clean_upscaled_nchw = F.interpolate(
|
||||||
|
original_clean_nchw,
|
||||||
|
size=(H, W),
|
||||||
|
mode="bicubic",
|
||||||
|
antialias=True
|
||||||
|
)
|
||||||
|
original_clean_upscaled = original_clean_upscaled_nchw.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
# 3. Préparer le masque (redimensionné à la taille finale)
|
||||||
|
# Le masque est (B, H, W), on doit le passer en (B, 1, H, W) pour interpolate
|
||||||
|
mask_nchw = mask_in.to(device).reshape(B, 1, mask_in.shape[1], mask_in.shape[2])
|
||||||
|
mask_resized_nchw = F.interpolate(
|
||||||
|
mask_nchw,
|
||||||
|
size=(H, W),
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
# Re-permuter en (B, H, W, 1) pour le blending
|
||||||
|
mask_final = mask_resized_nchw.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
# 4. Logique de mélange
|
||||||
|
# Définir ce qu'est la "zone protégée" (le ciel)
|
||||||
|
# Le poids contrôle à quel point la zone protégée est "propre"
|
||||||
|
protected_image = (original_clean_upscaled * mask_blend_weight) + \
|
||||||
|
(final_processed * (1.0 - mask_blend_weight))
|
||||||
|
|
||||||
|
# Logique INVERSÉE comme demandé :
|
||||||
|
# Masque NOIR (0.0) -> protected_image
|
||||||
|
# Masque BLANC (1.0) -> final_processed
|
||||||
|
current_image = (final_processed * (1.0 - mask_final)) + (protected_image * mask_final)
|
||||||
|
|
||||||
|
# Nettoyage
|
||||||
|
del final_processed, original_clean_nchw, original_clean_upscaled, mask_nchw, mask_resized_nchw, mask_final
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# Retourner l'image finale depuis le GPU vers le CPU
|
||||||
|
return (current_image.cpu(),)
|
||||||
|
|
||||||
return (current_image,)
|
return (current_image,)
|
||||||
|
|
||||||
# --- Enregistrement du Node ---
|
# --- Enregistrement du Node ---
|
||||||
|
|||||||
Reference in New Issue
Block a user