Merge branch 'master' of github.com:gaomingqi/Track-Anything

This commit is contained in:
memoryunreal
2023-04-26 05:05:08 +00:00
4 changed files with 14 additions and 8 deletions

View File

@@ -7,6 +7,8 @@ import yaml
import cv2
import importlib
import numpy as np
from tqdm import tqdm
from inpainter.util.tensor_util import resize_frames, resize_masks
@@ -66,15 +68,15 @@ class BaseInpainter:
if ratio == 1:
size = None
else:
size = (int(W*ratio), int(H*ratio))
size = [int(W*ratio), int(H*ratio)]
if size[0] % 2 > 0:
size[0] += 1
if size[1] % 2 > 0:
size[1] += 1
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
binary_masks = resize_masks(masks, size)
frames = resize_frames(frames, size) # T, H, W, 3
binary_masks = resize_masks(masks, tuple(size))
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
# frames and binary_masks are numpy arrays
h, w = frames.shape[1:3]
@@ -87,7 +89,7 @@ class BaseInpainter:
imgs, masks = imgs.to(self.device), masks.to(self.device)
comp_frames = [None] * video_length
for f in range(0, video_length, self.neighbor_stride):
for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
neighbor_ids = [
i for i in range(max(0, f - self.neighbor_stride),
min(video_length, f + self.neighbor_stride + 1))

View File

@@ -128,8 +128,10 @@ def window_partition(x, window_size):
windows: (B*num_windows, T*window_size*window_size, C)
"""
B, T, H, W, C = x.shape
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
-1, T * window_size[0] * window_size[1], C)
return windows

View File

@@ -12,4 +12,5 @@ pycocotools
matplotlib
pyyaml
av
openmim
openmim
tqdm

View File

@@ -1,4 +1,6 @@
import PIL
import PIL
from tqdm import tqdm
from tools.interact_tools import SamControler
from tracker.base_tracker import BaseTracker
from inpainter.base_inpainter import BaseInpainter
@@ -42,7 +44,7 @@ class TrackingAnything():
masks = []
logits = []
painted_images = []
for i in range(len(images)):
for i in tqdm(range(len(images)), desc="Tracking image"):
if i ==0:
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
masks.append(mask)
@@ -54,7 +56,6 @@ class TrackingAnything():
masks.append(mask)
logits.append(logit)
painted_images.append(painted_image)
print("tracking image {}".format(i))
return masks, logits, painted_images