mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
fix base_tracker
This commit is contained in:
@@ -28,6 +28,8 @@
|
|||||||
<!-- ![avengers]() -->
|
<!-- ![avengers]() -->
|
||||||
|
|
||||||
## :rocket: Updates
|
## :rocket: Updates
|
||||||
|
- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, a versatile image processing tool that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
|
||||||
|
|
||||||
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
|
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
|
||||||
|
|
||||||
## Demo
|
## Demo
|
||||||
@@ -56,13 +58,6 @@ cd Track-Anything
|
|||||||
# Install dependencies:
|
# Install dependencies:
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
# Install dependencies for inpainting:
|
|
||||||
pip install -U openmim
|
|
||||||
mim install mmcv
|
|
||||||
|
|
||||||
# Install dependencies for editing
|
|
||||||
pip install madgrad
|
|
||||||
|
|
||||||
# Run the Track-Anything gradio demo.
|
# Run the Track-Anything gradio demo.
|
||||||
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
|
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -15,3 +15,5 @@ onnx
|
|||||||
metaseg==0.6.1
|
metaseg==0.6.1
|
||||||
pyyaml
|
pyyaml
|
||||||
av
|
av
|
||||||
|
mmcv-full
|
||||||
|
mmengine
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ import yaml
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from model.network import XMem
|
from model.network import XMem
|
||||||
from inference.inference_core import InferenceCore
|
from inference.inference_core import InferenceCore
|
||||||
from util.mask_mapper import MaskMapper
|
from tracker.util.mask_mapper import MaskMapper
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from util.range_transform import im_normalization
|
from tracker.util.range_transform import im_normalization
|
||||||
import sys
|
|
||||||
sys.path.insert(0, sys.path[0]+"/../")
|
|
||||||
from tools.painter import mask_painter
|
from tools.painter import mask_painter
|
||||||
from tools.base_segmenter import BaseSegmenter
|
from tools.base_segmenter import BaseSegmenter
|
||||||
from torchvision.transforms import Resize
|
from torchvision.transforms import Resize
|
||||||
|
import progressbar
|
||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
class BaseTracker:
|
||||||
@@ -101,6 +101,8 @@ class BaseTracker:
|
|||||||
continue
|
continue
|
||||||
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
|
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||||
|
|
||||||
|
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
|
||||||
|
|
||||||
return final_mask, final_mask, painted_image
|
return final_mask, final_mask, painted_image
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@@ -126,50 +128,65 @@ class BaseTracker:
|
|||||||
self.mapper.clear_labels()
|
self.mapper.clear_labels()
|
||||||
|
|
||||||
|
|
||||||
|
## how to use:
|
||||||
|
## 1/3) prepare device and xmem_checkpoint
|
||||||
|
# device = 'cuda:2'
|
||||||
|
# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||||
|
## 2/3) initialise Base Tracker
|
||||||
|
# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None)
|
||||||
|
## 3/3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# video frames (multiple objects)
|
# video frames (take videos from DAVIS-2017 as examples)
|
||||||
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
||||||
video_path_list.sort()
|
video_path_list.sort()
|
||||||
# first frame
|
|
||||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
|
||||||
# load frames
|
# load frames
|
||||||
frames = []
|
frames = []
|
||||||
for video_path in video_path_list:
|
for video_path in video_path_list:
|
||||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||||
frames = np.stack(frames, 0) # N, H, W, C
|
frames = np.stack(frames, 0) # T, H, W, C
|
||||||
# load first frame annotation
|
# load first frame annotation
|
||||||
|
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
||||||
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
# ------------------------------------------------------------------------------------
|
||||||
# initalise tracker
|
# how to use
|
||||||
# ----------------------------------------------------------
|
# ------------------------------------------------------------------------------------
|
||||||
device = 'cuda:4'
|
# 1/4: set checkpoint and device
|
||||||
|
device = 'cuda:2'
|
||||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
# SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||||
model_type = 'vit_h'
|
# model_type = 'vit_h'
|
||||||
|
# ------------------------------------------------------------------------------------
|
||||||
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
# 2/4: initialise inpainter
|
||||||
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
||||||
|
# ------------------------------------------------------------------------------------
|
||||||
# # test for storage efficiency
|
# 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
|
||||||
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
# frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins
|
||||||
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
painted_frames = []
|
||||||
|
|
||||||
first_frame_annotation[first_frame_annotation==1] = 15
|
|
||||||
first_frame_annotation[first_frame_annotation==2] = 20
|
|
||||||
|
|
||||||
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
|
|
||||||
if not os.path.exists(save_path):
|
|
||||||
os.mkdir(save_path)
|
|
||||||
|
|
||||||
for ti, frame in enumerate(frames):
|
for ti, frame in enumerate(frames):
|
||||||
if ti == 0:
|
if ti == 0:
|
||||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
mask, prob, painted_frame = tracker.track(frame, first_frame_annotation)
|
||||||
|
# mask:
|
||||||
else:
|
else:
|
||||||
mask, prob, painted_image = tracker.track(frame)
|
mask, prob, painted_frame = tracker.track(frame)
|
||||||
|
painted_frames.append(painted_frame)
|
||||||
|
# ----------------------------------------------
|
||||||
|
# 3/4: clear memory in XMEM for the next video
|
||||||
|
tracker.clear_memory()
|
||||||
|
# ----------------------------------------------
|
||||||
|
# end
|
||||||
|
# ----------------------------------------------
|
||||||
|
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
|
||||||
|
# set saving path
|
||||||
|
save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
|
||||||
|
if not os.path.exists(save_path):
|
||||||
|
os.mkdir(save_path)
|
||||||
# save
|
# save
|
||||||
painted_image = Image.fromarray(painted_image)
|
for painted_frame in progressbar.progressbar(painted_frames):
|
||||||
painted_image.save(f'{save_path}/{ti:05d}.png')
|
painted_frame = Image.fromarray(painted_frame)
|
||||||
|
painted_frame.save(f'{save_path}/{ti:05d}.png')
|
||||||
|
|
||||||
# tracker.clear_memory()
|
# tracker.clear_memory()
|
||||||
# for ti, frame in enumerate(frames):
|
# for ti, frame in enumerate(frames):
|
||||||
@@ -241,6 +258,3 @@ if __name__ == '__main__':
|
|||||||
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||||
|
|
||||||
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from inference.memory_manager import MemoryManager
|
|||||||
from model.network import XMem
|
from model.network import XMem
|
||||||
from model.aggregate import aggregate
|
from model.aggregate import aggregate
|
||||||
|
|
||||||
from util.tensor_util import pad_divide_by, unpad
|
from tracker.util.tensor_util import pad_divide_by, unpad
|
||||||
|
|
||||||
|
|
||||||
class InferenceCore:
|
class InferenceCore:
|
||||||
|
|||||||
Reference in New Issue
Block a user