Mirror of https://github.com/gaomingqi/Track-Anything.git (synced 2025-12-15 16:07:51 +01:00)
Merge branch 'master' of github.com:gaomingqi/Track-Anything
This commit is contained in:
@@ -7,6 +7,8 @@ import yaml
|
|||||||
import cv2
|
import cv2
|
||||||
import importlib
|
import importlib
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from inpainter.util.tensor_util import resize_frames, resize_masks
|
from inpainter.util.tensor_util import resize_frames, resize_masks
|
||||||
|
|
||||||
|
|
||||||
@@ -66,15 +68,15 @@ class BaseInpainter:
|
|||||||
if ratio == 1:
|
if ratio == 1:
|
||||||
size = None
|
size = None
|
||||||
else:
|
else:
|
||||||
size = (int(W*ratio), int(H*ratio))
|
size = [int(W*ratio), int(H*ratio)]
|
||||||
if size[0] % 2 > 0:
|
if size[0] % 2 > 0:
|
||||||
size[0] += 1
|
size[0] += 1
|
||||||
if size[1] % 2 > 0:
|
if size[1] % 2 > 0:
|
||||||
size[1] += 1
|
size[1] += 1
|
||||||
|
|
||||||
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
||||||
binary_masks = resize_masks(masks, size)
|
binary_masks = resize_masks(masks, tuple(size))
|
||||||
frames = resize_frames(frames, size) # T, H, W, 3
|
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
|
||||||
# frames and binary_masks are numpy arrays
|
# frames and binary_masks are numpy arrays
|
||||||
|
|
||||||
h, w = frames.shape[1:3]
|
h, w = frames.shape[1:3]
|
||||||
@@ -87,7 +89,7 @@ class BaseInpainter:
|
|||||||
imgs, masks = imgs.to(self.device), masks.to(self.device)
|
imgs, masks = imgs.to(self.device), masks.to(self.device)
|
||||||
comp_frames = [None] * video_length
|
comp_frames = [None] * video_length
|
||||||
|
|
||||||
for f in range(0, video_length, self.neighbor_stride):
|
for f in tqdm(range(0, video_length, self.neighbor_stride), desc='Inpainting image'):
|
||||||
neighbor_ids = [
|
neighbor_ids = [
|
||||||
i for i in range(max(0, f - self.neighbor_stride),
|
i for i in range(max(0, f - self.neighbor_stride),
|
||||||
min(video_length, f + self.neighbor_stride + 1))
|
min(video_length, f + self.neighbor_stride + 1))
|
||||||
|
|||||||
@@ -128,8 +128,10 @@ def window_partition(x, window_size):
|
|||||||
windows: (B*num_windows, T*window_size*window_size, C)
|
windows: (B*num_windows, T*window_size*window_size, C)
|
||||||
"""
|
"""
|
||||||
B, T, H, W, C = x.shape
|
B, T, H, W, C = x.shape
|
||||||
|
|
||||||
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
|
||||||
window_size[1], C)
|
window_size[1], C)
|
||||||
|
|
||||||
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
|
||||||
-1, T * window_size[0] * window_size[1], C)
|
-1, T * window_size[0] * window_size[1], C)
|
||||||
return windows
|
return windows
|
||||||
|
|||||||
@@ -12,4 +12,5 @@ pycocotools
|
|||||||
matplotlib
|
matplotlib
|
||||||
pyyaml
|
pyyaml
|
||||||
av
|
av
|
||||||
openmim
|
openmim
|
||||||
|
tqdm
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
import PIL
|
import PIL
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from tools.interact_tools import SamControler
|
from tools.interact_tools import SamControler
|
||||||
from tracker.base_tracker import BaseTracker
|
from tracker.base_tracker import BaseTracker
|
||||||
from inpainter.base_inpainter import BaseInpainter
|
from inpainter.base_inpainter import BaseInpainter
|
||||||
@@ -42,7 +44,7 @@ class TrackingAnything():
|
|||||||
masks = []
|
masks = []
|
||||||
logits = []
|
logits = []
|
||||||
painted_images = []
|
painted_images = []
|
||||||
for i in range(len(images)):
|
for i in tqdm(range(len(images)), desc="Tracking image"):
|
||||||
if i ==0:
|
if i ==0:
|
||||||
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
mask, logit, painted_image = self.xmem.track(images[i], template_mask)
|
||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
@@ -54,7 +56,6 @@ class TrackingAnything():
|
|||||||
masks.append(mask)
|
masks.append(mask)
|
||||||
logits.append(logit)
|
logits.append(logit)
|
||||||
painted_images.append(painted_image)
|
painted_images.append(painted_image)
|
||||||
print("tracking image {}".format(i))
|
|
||||||
return masks, logits, painted_images
|
return masks, logits, painted_images
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user