mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
fix base_tracker
This commit is contained in:
@@ -28,6 +28,8 @@
|
||||
<!-- ![avengers]() -->
|
||||
|
||||
## :rocket: Updates
|
||||
- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, a versatile image processing tool that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
|
||||
|
||||
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
|
||||
|
||||
## Demo
|
||||
@@ -56,13 +58,6 @@ cd Track-Anything
|
||||
# Install dependencies:
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Install dependencies for inpainting:
|
||||
pip install -U openmim
|
||||
mim install mmcv
|
||||
|
||||
# Install dependencies for editing
|
||||
pip install madgrad
|
||||
|
||||
# Run the Track-Anything gradio demo.
|
||||
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
|
||||
```
|
||||
|
||||
@@ -15,3 +15,5 @@ onnx
|
||||
metaseg==0.6.1
|
||||
pyyaml
|
||||
av
|
||||
mmcv-full
|
||||
mmengine
|
||||
|
||||
@@ -9,14 +9,14 @@ import yaml
|
||||
import torch.nn.functional as F
|
||||
from model.network import XMem
|
||||
from inference.inference_core import InferenceCore
|
||||
from util.mask_mapper import MaskMapper
|
||||
from tracker.util.mask_mapper import MaskMapper
|
||||
from torchvision import transforms
|
||||
from util.range_transform import im_normalization
|
||||
import sys
|
||||
sys.path.insert(0, sys.path[0]+"/../")
|
||||
from tracker.util.range_transform import im_normalization
|
||||
|
||||
from tools.painter import mask_painter
|
||||
from tools.base_segmenter import BaseSegmenter
|
||||
from torchvision.transforms import Resize
|
||||
import progressbar
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
@@ -101,6 +101,8 @@ class BaseTracker:
|
||||
continue
|
||||
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||
|
||||
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
|
||||
|
||||
return final_mask, final_mask, painted_image
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -126,50 +128,65 @@ class BaseTracker:
|
||||
self.mapper.clear_labels()
|
||||
|
||||
|
||||
## how to use:
|
||||
## 1/3) prepare device and xmem_checkpoint
|
||||
# device = 'cuda:2'
|
||||
# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
## 2/3) initialise Base Tracker
|
||||
# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None)
|
||||
## 3/3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# video frames (multiple objects)
|
||||
# video frames (take videos from DAVIS-2017 as examples)
|
||||
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
|
||||
video_path_list.sort()
|
||||
# first frame
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
||||
# load frames
|
||||
frames = []
|
||||
for video_path in video_path_list:
|
||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||
frames = np.stack(frames, 0) # N, H, W, C
|
||||
frames = np.stack(frames, 0) # T, H, W, C
|
||||
# load first frame annotation
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
|
||||
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
|
||||
|
||||
# ----------------------------------------------------------
|
||||
# initalise tracker
|
||||
# ----------------------------------------------------------
|
||||
device = 'cuda:4'
|
||||
# ------------------------------------------------------------------------------------
|
||||
# how to use
|
||||
# ------------------------------------------------------------------------------------
|
||||
# 1/4: set checkpoint and device
|
||||
device = 'cuda:2'
|
||||
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
|
||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
model_type = 'vit_h'
|
||||
|
||||
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
||||
# SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
# model_type = 'vit_h'
|
||||
# ------------------------------------------------------------------------------------
|
||||
# 2/4: initialise inpainter
|
||||
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
||||
|
||||
# # test for storage efficiency
|
||||
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
||||
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
||||
|
||||
first_frame_annotation[first_frame_annotation==1] = 15
|
||||
first_frame_annotation[first_frame_annotation==2] = 20
|
||||
|
||||
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
|
||||
# ------------------------------------------------------------------------------------
|
||||
# 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
|
||||
# frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins
|
||||
painted_frames = []
|
||||
for ti, frame in enumerate(frames):
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
mask, prob, painted_frame = tracker.track(frame, first_frame_annotation)
|
||||
# mask:
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'{save_path}/{ti:05d}.png')
|
||||
mask, prob, painted_frame = tracker.track(frame)
|
||||
painted_frames.append(painted_frame)
|
||||
# ----------------------------------------------
|
||||
# 3/4: clear memory in XMEM for the next video
|
||||
tracker.clear_memory()
|
||||
# ----------------------------------------------
|
||||
# end
|
||||
# ----------------------------------------------
|
||||
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
|
||||
# set saving path
|
||||
save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
# save
|
||||
for painted_frame in progressbar.progressbar(painted_frames):
|
||||
painted_frame = Image.fromarray(painted_frame)
|
||||
painted_frame.save(f'{save_path}/{ti:05d}.png')
|
||||
|
||||
# tracker.clear_memory()
|
||||
# for ti, frame in enumerate(frames):
|
||||
@@ -241,6 +258,3 @@ if __name__ == '__main__':
|
||||
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
|
||||
|
||||
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from inference.memory_manager import MemoryManager
|
||||
from model.network import XMem
|
||||
from model.aggregate import aggregate
|
||||
|
||||
from util.tensor_util import pad_divide_by, unpad
|
||||
from tracker.util.tensor_util import pad_divide_by, unpad
|
||||
|
||||
|
||||
class InferenceCore:
|
||||
|
||||
Reference in New Issue
Block a user