diff --git a/superscaler.py b/superscaler.py index f5b6906..da1b710 100644 --- a/superscaler.py +++ b/superscaler.py @@ -41,6 +41,7 @@ class SuperScaler: }, "optional": { # --- SECTION 1: LATENT REFINEMENT (PASS 1) --- + "mask_in": ("MASK",), "enable_latent_pass": ("BOOLEAN", {"default": False}), "model_pass_1": ("MODEL",), "vae_pass_1": ("VAE",), @@ -93,6 +94,7 @@ class SuperScaler: "grain_size": ("FLOAT", {"default": 1.3, "min": 1.0, "max": 16.0, "step": 0.1}), "saturation_mix": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01}), "adaptive_grain": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 2.0, "step": 0.01}), + "mask_blend_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), } } @@ -388,6 +390,14 @@ class SuperScaler: # --- FONCTION PRINCIPALE --- def process(self, image_in, **kwargs): + + # --- NOUVEAU : Sauvegarde de l'image originale --- + original_clean = image_in.clone() + + # Récupérer les nouveaux arguments + mask_in = kwargs.get("mask_in", None) + mask_blend_weight = kwargs.get("mask_blend_weight", 1.0) + global_seed = kwargs.get("seed", 0) # Récupérer tous les arguments enable_latent_pass = kwargs.get("enable_latent_pass", False) @@ -512,6 +522,59 @@ class SuperScaler: adaptive_grain=kwargs.get("adaptive_grain", 0.30) ) + # --- NOUVEAU : PASS 5 (FINAL MASKED BLEND) --- + # current_image est notre image finale "traitée" (sur CPU) + # On vérifie si un masque est fourni ET si le poids est > 0 + if mask_in is not None and mask_blend_weight > 0.0: + print(f"[SuperScaler] PASS 5: Mélange final avec masque (Poids: {mask_blend_weight})") + + device = comfy.model_management.get_torch_device() + + # 1. Préparer l'image finale traitée (sur GPU) + final_processed = current_image.to(device) + B, H, W, C = final_processed.shape + + # 2. Préparer l'image originale "propre" (upscalée simplement) + original_clean_nchw = original_clean.to(device).permute(0, 3, 1, 2) + original_clean_upscaled_nchw = F.interpolate( + original_clean_nchw, + size=(H, W), + mode="bicubic", + antialias=True + ) + original_clean_upscaled = original_clean_upscaled_nchw.permute(0, 2, 3, 1) + + # 3. Préparer le masque (redimensionné à la taille finale) + # Le masque est (B, H, W), on doit le passer en (B, 1, H, W) pour interpolate + mask_nchw = mask_in.to(device).reshape(B, 1, mask_in.shape[1], mask_in.shape[2]) + mask_resized_nchw = F.interpolate( + mask_nchw, + size=(H, W), + mode="bilinear", + align_corners=False + ) + # Re-permuter en (B, H, W, 1) pour le blending + mask_final = mask_resized_nchw.permute(0, 2, 3, 1) + + # 4. Logique de mélange + # Définir ce qu'est la "zone protégée" (le ciel) + # Le poids contrôle à quel point la zone protégée est "propre" + protected_image = (original_clean_upscaled * mask_blend_weight) + \ + (final_processed * (1.0 - mask_blend_weight)) + + # Logique INVERSÉE comme demandé : + # Masque NOIR (0.0) -> protected_image + # Masque BLANC (1.0) -> final_processed + current_image = (final_processed * (1.0 - mask_final)) + (protected_image * mask_final) + + # Nettoyage + del final_processed, original_clean_nchw, original_clean_upscaled, mask_nchw, mask_resized_nchw, mask_final + gc.collect() + torch.cuda.empty_cache() + + # Retourner l'image finale depuis le GPU vers le CPU + return (current_image.cpu(),) + return (current_image,) # --- Enregistrement du Node ---