diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..8774451 --- /dev/null +++ b/__init__.py @@ -0,0 +1,9 @@ +from .superscaler import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS + +WEB_DIRECTORY = "./js" + +print("-------------------------------------------------") +print(" Node 'Pipeline SuperScaler' chargé avec succès ") +print("-------------------------------------------------") + +__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', 'WEB_DIRECTORY'] \ No newline at end of file diff --git a/superscaler.py b/superscaler.py new file mode 100644 index 0000000..74a1800 --- /dev/null +++ b/superscaler.py @@ -0,0 +1,523 @@ +import torch +import numpy as np +import comfy.utils +import nodes +import gc +from scipy.ndimage import gaussian_filter +import torch.nn.functional as F + +# (Optionnel, mais recommandé pour le Post-FX si ton code l'utilise) +try: + import cv2 +except ImportError: + print("Avertissement: OpenCV (cv2) n'est pas installé. Le module de sharpening pourrait ne pas fonctionner s'il en dépend.") + +class SuperScaler: + """ + Le node "Pipeline SuperScaler" tout-en-un. + Combine un raffinage latent (Pass 1), deux passes d'upscale génératif tiled (Pass 2, Pass 3), + et un post-traitement (Sharpening & Grain). + """ + + # --- Conversion Tenseur <-> NumPy (pour le Post-FX) --- + def tensor_to_np(self, tensor: torch.Tensor) -> np.ndarray: + """Convertit un tenseur d'image ComfyUI (BCHW, float32) en NumPy (HWC, uint8).""" + image_np = tensor.cpu().numpy().squeeze(0) # (B, H, W, C) -> (H, W, C) + image_np = np.clip(image_np * 255.0, 0, 255).astype(np.uint8) + return image_np + + def np_to_tensor(self, image_np: np.ndarray) -> torch.Tensor: + """Convertit un NumPy (HWC, uint8) en tenseur d'image ComfyUI (BCHW, float32).""" + image_tensor = torch.from_numpy(image_np.astype(np.float32) / 255.0) + image_tensor = image_tensor.unsqueeze(0) # (H, W, C) -> (B, H, W, C) + return image_tensor + + # --- Définition du Node --- + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "image_in": ("IMAGE",), + }, + "optional": { + # --- SECTION 1: LATENT REFINEMENT (PASS 1) --- + "enable_latent_pass": ("BOOLEAN", {"default": False}), + "model_pass_1": ("MODEL",), + "vae_pass_1": ("VAE",), + "positive_pass_1": ("CONDITIONING",), + "negative_pass_1": ("CONDITIONING",), + "latent_upscale_by": ("FLOAT", {"default": 1.1, "min": 1.0, "max": 4.0, "step": 0.1}), + "latent_denoise": ("FLOAT", {"default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01}), + "latent_sampler_name": (comfy.samplers.KSampler.SAMPLERS,), + "latent_scheduler": (comfy.samplers.KSampler.SCHEDULERS,), + "latent_steps": ("INT", {"default": 4, "min": 1, "max": 1000}), + "latent_cfg": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0}), + + # --- SECTION 2: TILED GENERATIVE UPSCALE (PASS 2) --- (MODIFIÉ: _2 ajouté) + "enable_tiled_pass_2": ("BOOLEAN", {"default": True}), + "model_pass_2": ("MODEL",), + "vae_pass_2": ("VAE",), + "positive_pass_2": ("CONDITIONING",), + "negative_pass_2": ("CONDITIONING",), + "tiled_upscale_by_2": ("FLOAT", {"default": 2.0, "min": 1.0, "max": 8.0, "step": 0.1}), + "tiled_denoise_2": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.01}), + "tile_size_2": ("INT", {"default": 768, "min": 256, "max": 2048, "step": 64}), + "tile_overlap_2": ("INT", {"default": 128, "min": 32, "max": 512, "step": 32}), + "tiled_sampler_name_2": (comfy.samplers.KSampler.SAMPLERS,), + "tiled_scheduler_2": (comfy.samplers.KSampler.SCHEDULERS,), + "tiled_steps_2": ("INT", {"default": 8, "min": 1, "max": 1000}), + "tiled_cfg_2": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + + # --- SECTION 3: TILED GENERATIVE UPSCALE (PASS 3) --- (NOUVEAU) + "enable_tiled_pass_3": ("BOOLEAN", {"default": False}), + "model_pass_3": ("MODEL",), + "vae_pass_3": ("VAE",), + "positive_pass_3": ("CONDITIONING",), + "negative_pass_3": ("CONDITIONING",), + "tiled_upscale_by_3": ("FLOAT", {"default": 2.0, "min": 1.0, "max": 8.0, "step": 0.1}), + "tiled_denoise_3": ("FLOAT", {"default": 0.20, "min": 0.0, "max": 1.0, "step": 0.01}), + "tile_size_3": ("INT", {"default": 768, "min": 256, "max": 2048, "step": 64}), + "tile_overlap_3": ("INT", {"default": 128, "min": 32, "max": 512, "step": 32}), + "tiled_sampler_name_3": (comfy.samplers.KSampler.SAMPLERS,), + "tiled_scheduler_3": (comfy.samplers.KSampler.SCHEDULERS,), + "tiled_steps_3": ("INT", {"default": 8, "min": 1, "max": 1000}), + "tiled_cfg_3": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0}), + + # --- SECTION 4: POST-PROCESSING FX --- + "enable_sharpen": ("BOOLEAN", {"default": False}), + "sharpen_amount": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 5.0, "step": 0.05}), + "sharpen_radius": ("INT", {"default": 1, "min": 1, "max": 20, "step": 1}), + "enable_grain": ("BOOLEAN", {"default": False}), + "grain_type": (["poisson", "gaussian", "perlin"], {"default": "poisson"}), + "grain_intensity": ("FLOAT", {"default": 0.022, "min": 0.001, "max": 1.0, "step": 0.001}), + "grain_size": ("FLOAT", {"default": 1.5, "min": 1.0, "max": 16.0, "step": 0.1}), + "saturation_mix": ("FLOAT", {"default": 0.22, "min": 0.0, "max": 1.0, "step": 0.01}), + "adaptive_grain": ("FLOAT", {"default": 0.30, "min": 0.0, "max": 2.0, "step": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + } + } + + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("image_out",) + FUNCTION = "process" + CATEGORY = "upscaling" + + # --- MODULE 1: RAFFINAGE LATENT --- + def run_latent_refine(self, image, model, vae, positive, negative, + upscale_by, denoise, sampler, scheduler, steps, cfg, seed): + + print(f"[SuperScaler] PASS 1: Raffinage Latent (x{upscale_by})") + + # 1. Encoder l'image + latent = vae.encode(image) + + # 2. Upscaler le latent + b, c, h, w = latent.shape + new_h, new_w = int(h * upscale_by), int(w * upscale_by) + latent_upscaled = torch.nn.functional.interpolate(latent, size=(new_h, new_w), mode="bilinear", align_corners=False) + + # 3. Appeler le KSampler + refined_latent = nodes.common_ksampler( + model=model, + seed=seed, + steps=steps, + cfg=cfg, + sampler_name=sampler, + scheduler=scheduler, + positive=positive, + negative=negative, + latent={"samples": latent_upscaled}, + denoise=denoise + )[0] # [0] pour obtenir le latent de sortie + + # 4. Décoder + refined_image = vae.decode(refined_latent["samples"]) + + del latent, latent_upscaled, refined_latent + gc.collect() + torch.cuda.empty_cache() + + return refined_image + + # --- MODULE 2 & 3: UPSCALE TILED --- (MODIFIÉ: fonction générique) + def _run_tiled_pass(self, pass_name, image, model, vae, positive, negative, + upscale_by, denoise, tile_size, overlap, + sampler, scheduler, steps, cfg, seed): + + print(f"[SuperScaler] {pass_name}: Upscale Tiled (x{upscale_by})") + + device = comfy.model_management.get_torch_device() + image = image.to(device) + + # 1. Upscale "bête" (Pixel space) + b, h, w, c = image.shape + new_h, new_w = int(h * upscale_by), int(w * upscale_by) + # Permuter pour NCHW pour l'interpolation + image_nchw = image.permute(0, 3, 1, 2) + image_blurry_large = torch.nn.functional.interpolate( + image_nchw, size=(new_h, new_w), mode="bicubic", antialias=True + ) + + # Tenseur de sortie final, initialisé en noir + output_image = torch.zeros_like(image_blurry_large) + # Tenseur de "comptage" pour la moyenne des chevauchements + overlap_count = torch.zeros_like(image_blurry_large) + + # 2. Créer le masque de fondu (feather mask) + feather_mask = self._create_feather_mask(tile_size, tile_size, overlap, device=device) + + # 3. Boucle de Tiling + stride = tile_size - overlap + + # --- Début des Logs --- + y_steps = list(range(0, new_h, stride)) + x_steps = list(range(0, new_w, stride)) + total_tiles = len(y_steps) * len(x_steps) + current_tile = 0 + print(f"[SuperScaler] {pass_name}: Démarrage de l'upscale, {total_tiles} tuiles à calculer...") + # --- Fin des Logs --- + + for y in y_steps: # Utiliser les listes pré-calculées + for x in x_steps: # Utiliser les listes pré-calculées + + # --- Ajout du compteur --- + current_tile += 1 + print(f"[SuperScaler] {pass_name}: Inférence de la tuile {current_tile}/{total_tiles}...") + # --- Fin du compteur --- + + # S'assurer que la tuile ne dépasse pas les bords + y_end = min(y + tile_size, new_h) + x_end = min(x + tile_size, new_w) + y_start = max(0, y_end - tile_size) + x_start = max(0, x_end - tile_size) + + # Découper la tuile + tile = image_blurry_large[:, :, y_start:y_end, x_start:x_end] + + # 4. KSampler sur la tuile (Encode -> KSample -> Decode) + tile_latent = vae.encode(tile.permute(0, 2, 3, 1)) # NCHW -> NHWC + + tile_refined_latent = nodes.common_ksampler( + model=model, + seed=seed, # Utiliser la même seed pour la cohérence + steps=steps, + cfg=cfg, + sampler_name=sampler, + scheduler=scheduler, + positive=positive, + negative=negative, + latent={"samples": tile_latent}, + denoise=denoise + )[0] + + tile_refined = vae.decode(tile_refined_latent["samples"]) # Sortie en NHWC + tile_refined = tile_refined.permute(0, 3, 1, 2).to(device) # NHWC -> NCHW ET move to GPU + + # 5. Appliquer le masque de fondu et ré-assembler + current_mask = feather_mask + # Gérer les tuiles de bord qui sont plus petites + if tile_refined.shape[2] != tile_size or tile_refined.shape[3] != tile_size: + current_mask = self._create_feather_mask(tile_refined.shape[2], tile_refined.shape[3], overlap, device=device) + + output_image[:, :, y_start:y_end, x_start:x_end] += tile_refined * current_mask + overlap_count[:, :, y_start:y_end, x_start:x_end] += current_mask + + del tile, tile_latent, tile_refined_latent, tile_refined + + # Normaliser les zones de chevauchement + final_image = (output_image / overlap_count) + # Remplacer tous les NaN (résultats de 0/0) par 0.0 (noir) + final_image = torch.nan_to_num(final_image, nan=0.0) + + del image_blurry_large, output_image, overlap_count, feather_mask + gc.collect() + torch.cuda.empty_cache() + + # Re-permuter en NHWC (format IMAGE Comfy) + return final_image.permute(0, 2, 3, 1).cpu() + + def _create_feather_mask(self, h, w, overlap, device="cpu"): + """Crée un masque de fondu (linear blend) pour le tiling.""" + overlap_x = min(overlap, w // 2) + overlap_y = min(overlap, h // 2) + + # Créer les rampes linéaires + linear_x = torch.ones(w, device=device) + if overlap_x > 0: + linear_x[:overlap_x] = torch.linspace(0, 1, overlap_x, device=device) + linear_x[-overlap_x:] = torch.linspace(1, 0, overlap_x, device=device) + + linear_y = torch.ones(h, device=device) + if overlap_y > 0: + linear_y[:overlap_y] = torch.linspace(0, 1, overlap_y, device=device) + linear_y[-overlap_y:] = torch.linspace(1, 0, overlap_y, device=device) + + # Combiner en 2D + mask_2d = torch.outer(linear_y, linear_x) + return mask_2d.unsqueeze(0).unsqueeze(0) # (1, 1, H, W) + + + # --- MODULE 4: POST-FX --- + def run_freq_split_sharpen(self, image_tensor, amount, radius): + strength = amount + print(f"[SuperScaler] PASS 4.1: Sharpening (Amount: {amount}, Radius: {radius})") + + if strength == 0: + return (image_tensor,) # Retourne l'original si force = 0 + + # --- TON CODE INSÉRÉ ICI --- + # Note: 'image_tensor' est déjà un float 0..1 (BHWC) + image_np = image_tensor.clone().numpy()[0] # BHWC -> HWC + low_freq = gaussian_filter(image_np, sigma=(radius, radius, 0)) + high_freq = image_np - low_freq + processed_high_freq = high_freq * strength + final_image_np = low_freq + processed_high_freq + final_image_np = np.clip(final_image_np, 0, 1) + final_image_tensor = torch.from_numpy(final_image_np).unsqueeze(0) # HWC -> BHWC + # --- FIN DE TON CODE --- + + return final_image_tensor + + def run_add_grain(self, images, grain_type, grain_intensity, grain_size, saturation_mix, adaptive_grain): + print(f"[SuperScaler] PASS 4.2: Grain (Type: {grain_type}, Intensity: {grain_intensity})") + + device = comfy.model_management.get_torch_device() + images = images.to(device) # Doit être un Tenseur (B, H, W, C) + + if grain_intensity == 0.0: + return images + + if grain_type == "gaussian": + grain = self._generate_gaussian(images, grain_size) + elif grain_type == "poisson": + grain = self._generate_poisson(images, grain_size) + elif grain_type == "perlin": + grain = self._generate_perlin(images, grain_size) + else: + # Fallback au cas où + print(f"[SuperScaler] Type de grain inconnu '{grain_type}', utilisation de 'poisson'") + grain = self._generate_poisson(images, grain_size) + + # Logique de Saturation Mix + gray = grain[:, :, :, 1].unsqueeze(3).repeat(1, 1, 1, 3) + grain = saturation_mix * grain + (1.0 - saturation_mix) * gray + + # Logique Adaptive Grain + if adaptive_grain > 0.0: + luma = images.mean(dim=3, keepdim=True) + gain = (1.0 - luma).pow(2.0) * 2.5 + grain = grain * (1.0 + adaptive_grain * gain) + + # Appliquer le grain + output = images + grain * grain_intensity + output = output.clamp(0.0, 1.0) + + # Retourne le tenseur (pas un tuple) + return output.to(comfy.model_management.intermediate_device()) + + def _generate_gaussian(self, images, size): + B, H, W, C = images.shape + size = int(size) + if size <= 1: + noise = torch.randn_like(images) + else: + small_H, small_W = H // size, W // size + noise = torch.randn(B, small_H, small_W, C, device=images.device) + noise = noise.permute(0, 3, 1, 2) + noise = F.interpolate(noise, size=(H, W), mode="nearest") + noise = noise.permute(0, 2, 3, 1) + + noise[:, :, :, 0] *= 2.0 + noise[:, :, :, 2] *= 3.0 + return noise + + def _generate_poisson(self, images, size): + B, H, W, C = images.shape + size = int(size) + if size <= 1: + target_images = images + else: + small_H, small_W = H // size, W // size + target_images = F.interpolate(images.permute(0, 3, 1, 2), size=(small_H, small_W), mode="bilinear", align_corners=False) + target_images = target_images.permute(0, 2, 3, 1) + + scaled = torch.clamp(target_images * 255.0, 0, 255).round() + noise = torch.poisson(scaled) - scaled + grain = (noise / 255.0) * 16.0 + + if size > 1: + grain = grain.permute(0, 3, 1, 2) + grain = F.interpolate(grain, size=(H, W), mode="nearest") + grain = grain.permute(0, 2, 3, 1) + + grain[:, :, :, 0] *= 2.0 + grain[:, :, :, 2] *= 3.0 + return grain + + def _generate_perlin(self, images, size): + B, H, W, C = images.shape + size = int(size) + scale = max(4, 32 // size) + + perlin = self._make_fractal_noise(B, H, W, scale, images.device) + perlin = (perlin - 0.5) * 2.0 + + perlin = perlin.unsqueeze(3).repeat(1, 1, 1, 3) + return perlin + + def _make_fractal_noise(self, B, H, W, scale, device, octaves=4, persistence=0.5, lacunarity=2.0): + total_noise = torch.zeros(B, H, W, device=device) + frequency = 1.0 + amplitude = 1.0 + max_amplitude = 0.0 + + for _ in range(octaves): + current_scale = max(2, int(scale * frequency)) + coarse_noise = torch.rand(B, current_scale, current_scale, 1, device=device) + upscaled_noise = F.interpolate(coarse_noise.permute(0, 3, 1, 2), size=(H, W), mode="bilinear", align_corners=False).squeeze(1) + total_noise += upscaled_noise * amplitude + max_amplitude += amplitude + amplitude *= persistence + frequency *= lacunarity + + if max_amplitude > 0: + total_noise /= max_amplitude + + return total_noise + + + # --- FONCTION PRINCIPALE --- + def process(self, image_in, **kwargs): + global_seed = kwargs.get("seed", 0) + # Récupérer tous les arguments + enable_latent_pass = kwargs.get("enable_latent_pass", False) + model_pass_1 = kwargs.get("model_pass_1", None) + vae_pass_1 = kwargs.get("vae_pass_1", None) + positive_pass_1 = kwargs.get("positive_pass_1", None) + negative_pass_1 = kwargs.get("negative_pass_1", None) + + # (MODIFIÉ: _2 ajouté) + enable_tiled_pass_2 = kwargs.get("enable_tiled_pass_2", False) + model_pass_2 = kwargs.get("model_pass_2", None) + vae_pass_2 = kwargs.get("vae_pass_2", None) + positive_pass_2 = kwargs.get("positive_pass_2", None) + negative_pass_2 = kwargs.get("negative_pass_2", None) + + # (NOUVEAU) + enable_tiled_pass_3 = kwargs.get("enable_tiled_pass_3", False) + model_pass_3 = kwargs.get("model_pass_3", None) + vae_pass_3 = kwargs.get("vae_pass_3", None) + positive_pass_3 = kwargs.get("positive_pass_3", None) + negative_pass_3 = kwargs.get("negative_pass_3", None) + + enable_sharpen = kwargs.get("enable_sharpen", False) + enable_grain = kwargs.get("enable_grain", False) + + # Démarrer le pipeline + current_image = image_in.clone() + + # --- Exécuter PASS 1 --- + if (enable_latent_pass and + model_pass_1 is not None and + vae_pass_1 is not None and + positive_pass_1 is not None and + negative_pass_1 is not None): + + current_image = self.run_latent_refine( + image=current_image, + model=model_pass_1, + vae=vae_pass_1, + positive=positive_pass_1, + negative=negative_pass_1, + upscale_by=kwargs.get("latent_upscale_by", 1.5), + denoise=kwargs.get("latent_denoise", 0.5), + sampler=kwargs.get("latent_sampler_name", "dpmpp_2m"), + scheduler=kwargs.get("latent_scheduler", "karras"), + steps=kwargs.get("latent_steps", 20), + cfg=kwargs.get("latent_cfg", 7.0), + seed=global_seed + ) + else: + print("[SuperScaler] PASS 1: Sauté (désactivé ou inputs manquants)") + + # --- Exécuter PASS 2 --- (MODIFIÉ: _2 ajouté) + if (enable_tiled_pass_2 and + model_pass_2 is not None and + vae_pass_2 is not None and + positive_pass_2 is not None and + negative_pass_2 is not None): + + current_image = self._run_tiled_pass( + pass_name="PASS 2", + image=current_image, + model=model_pass_2, + vae=vae_pass_2, + positive=positive_pass_2, + negative=negative_pass_2, + upscale_by=kwargs.get("tiled_upscale_by_2", 2.0), + denoise=kwargs.get("tiled_denoise_2", 0.35), + tile_size=kwargs.get("tile_size_2", 768), + overlap=kwargs.get("tile_overlap_2", 128), + sampler=kwargs.get("tiled_sampler_name_2", "dpmpp_2m"), + scheduler=kwargs.get("tiled_scheduler_2", "karras"), + steps=kwargs.get("tiled_steps_2", 15), + cfg=kwargs.get("tiled_cfg_2", 7.0), + seed=global_seed + ) + else: + print("[SuperScaler] PASS 2: Sauté (désactivé ou inputs manquants)") + + # --- Exécuter PASS 3 --- (NOUVEAU) + if (enable_tiled_pass_3 and + model_pass_3 is not None and + vae_pass_3 is not None and + positive_pass_3 is not None and + negative_pass_3 is not None): + + current_image = self._run_tiled_pass( + pass_name="PASS 3", + image=current_image, + model=model_pass_3, + vae=vae_pass_3, + positive=positive_pass_3, + negative=negative_pass_3, + upscale_by=kwargs.get("tiled_upscale_by_3", 2.0), + denoise=kwargs.get("tiled_denoise_3", 0.15), + tile_size=kwargs.get("tile_size_3", 512), + overlap=kwargs.get("tile_overlap_3", 64), + sampler=kwargs.get("tiled_sampler_name_3", "dpmpp_2m"), + scheduler=kwargs.get("tiled_scheduler_3", "karras"), + steps=kwargs.get("tiled_steps_3", 12), + cfg=kwargs.get("tiled_cfg_3", 7.0), + seed=global_seed + ) + else: + print("[SuperScaler] PASS 3: Sauté (désactivé ou inputs manquants)") + + # --- Exécuter PASS 4 (POST-FX) --- + if enable_sharpen: + current_image = self.run_freq_split_sharpen( + image_tensor=current_image, + amount=kwargs.get("sharpen_amount", 0.5), + radius=kwargs.get("sharpen_radius", 2) + ) + + if enable_grain: + current_image = self.run_add_grain( + images=current_image, + grain_type=kwargs.get("grain_type", "poisson"), + grain_intensity=kwargs.get("grain_intensity", 0.022), + grain_size=kwargs.get("grain_size", 1.5), + saturation_mix=kwargs.get("saturation_mix", 0.22), + adaptive_grain=kwargs.get("adaptive_grain", 0.30) + ) + + return (current_image,) + +# --- Enregistrement du Node --- +NODE_CLASS_MAPPINGS = { + "SuperScaler_Pipeline": SuperScaler +} +NODE_DISPLAY_NAME_MAPPINGS = { + "SuperScaler_Pipeline": "Pipeline SuperScaler" +} \ No newline at end of file