diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py index 40d802d..18fdfce 100644 --- a/inpainter/base_inpainter.py +++ b/inpainter/base_inpainter.py @@ -7,6 +7,8 @@ import yaml import cv2 import importlib import numpy as np +from tqdm import tqdm + from inpainter.util.tensor_util import resize_frames, resize_masks @@ -66,15 +68,15 @@ class BaseInpainter: if ratio == 1: size = None else: - size = (int(W*ratio), int(H*ratio)) + size = [int(W*ratio), int(H*ratio)] if size[0] % 2 > 0: size[0] += 1 if size[1] % 2 > 0: size[1] += 1 masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1 - binary_masks = resize_masks(masks, size) - frames = resize_frames(frames, size) # T, H, W, 3 + binary_masks = resize_masks(masks, tuple(size)) + frames = resize_frames(frames, tuple(size)) # T, H, W, 3 # frames and binary_masks are numpy arrays h, w = frames.shape[1:3] @@ -87,7 +89,7 @@ class BaseInpainter: imgs, masks = imgs.to(self.device), masks.to(self.device) comp_frames = [None] * video_length - for f in range(0, video_length, self.neighbor_stride): + for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'): neighbor_ids = [ i for i in range(max(0, f - self.neighbor_stride), min(video_length, f + self.neighbor_stride + 1)) diff --git a/inpainter/model/modules/tfocal_transformer_hq.py b/inpainter/model/modules/tfocal_transformer_hq.py index 1a24dfa..efabefb 100644 --- a/inpainter/model/modules/tfocal_transformer_hq.py +++ b/inpainter/model/modules/tfocal_transformer_hq.py @@ -128,8 +128,10 @@ def window_partition(x, window_size): windows: (B*num_windows, T*window_size*window_size, C) """ B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view( -1, T * window_size[0] * window_size[1], C) return windows diff --git a/requirements.txt b/requirements.txt index d7acc2b..410219f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ pycocotools matplotlib pyyaml av -openmim \ No newline at end of file +openmim +tqdm \ No newline at end of file diff --git a/track_anything.py b/track_anything.py index 3786b6a..5275252 100644 --- a/track_anything.py +++ b/track_anything.py @@ -1,4 +1,6 @@ -import PIL +import PIL +from tqdm import tqdm + from tools.interact_tools import SamControler from tracker.base_tracker import BaseTracker from inpainter.base_inpainter import BaseInpainter @@ -42,7 +44,7 @@ class TrackingAnything(): masks = [] logits = [] painted_images = [] - for i in range(len(images)): + for i in tqdm(range(len(images)), desc="Tracking image"): if i ==0: mask, logit, painted_image = self.xmem.track(images[i], template_mask) masks.append(mask) @@ -54,7 +56,6 @@ class TrackingAnything(): masks.append(mask) logits.append(logit) painted_images.append(painted_image) - print("tracking image {}".format(i)) return masks, logits, painted_images