fix base_tracker

This commit is contained in:
gaomingqi
2023-04-25 20:59:44 +08:00
parent 2d05a3a8e5
commit 6d9c051661
4 changed files with 55 additions and 44 deletions

View File

@@ -28,6 +28,8 @@
<!-- ![avengers]() -->
## :rocket: Updates
- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, a versatile image processing tool that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
## Demo
@@ -56,13 +58,6 @@ cd Track-Anything
# Install dependencies:
pip install -r requirements.txt
# Install dependencies for inpainting:
pip install -U openmim
mim install mmcv
# Install dependencies for editing
pip install madgrad
# Run the Track-Anything gradio demo.
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
```

View File

@@ -15,3 +15,5 @@ onnx
metaseg==0.6.1
pyyaml
av
mmcv-full
mmengine

View File

@@ -9,14 +9,14 @@ import yaml
import torch.nn.functional as F
from model.network import XMem
from inference.inference_core import InferenceCore
from util.mask_mapper import MaskMapper
from tracker.util.mask_mapper import MaskMapper
from torchvision import transforms
from util.range_transform import im_normalization
import sys
sys.path.insert(0, sys.path[0]+"/../")
from tracker.util.range_transform import im_normalization
from tools.painter import mask_painter
from tools.base_segmenter import BaseSegmenter
from torchvision.transforms import Resize
import progressbar
class BaseTracker:
@@ -101,6 +101,8 @@ class BaseTracker:
continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
return final_mask, final_mask, painted_image
@torch.no_grad()
@@ -126,50 +128,65 @@ class BaseTracker:
self.mapper.clear_labels()
## how to use:
## 1/3) prepare device and xmem_checkpoint
# device = 'cuda:2'
# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
## 2/3) initialise Base Tracker
# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None)
## 3/3)
if __name__ == '__main__':
# video frames (multiple objects)
# video frames (take videos from DAVIS-2017 as examples)
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
video_path_list.sort()
# first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
# load frames
frames = []
for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C
frames = np.stack(frames, 0) # T, H, W, C
# load first frame annotation
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C
# ----------------------------------------------------------
# initalise tracker
# ----------------------------------------------------------
device = 'cuda:4'
# ------------------------------------------------------------------------------------
# how to use
# ------------------------------------------------------------------------------------
# 1/4: set checkpoint and device
device = 'cuda:2'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
# SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
# model_type = 'vit_h'
# ------------------------------------------------------------------------------------
# 2/4: initialise inpainter
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# # test for storage efficiency
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
first_frame_annotation[first_frame_annotation==1] = 15
first_frame_annotation[first_frame_annotation==2] = 20
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
if not os.path.exists(save_path):
os.mkdir(save_path)
# ------------------------------------------------------------------------------------
# 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
# frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), leave it blank when tracking begins
painted_frames = []
for ti, frame in enumerate(frames):
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
mask, prob, painted_frame = tracker.track(frame, first_frame_annotation)
# mask:
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'{save_path}/{ti:05d}.png')
mask, prob, painted_frame = tracker.track(frame)
painted_frames.append(painted_frame)
# ----------------------------------------------
# 3/4: clear memory in XMEM for the next video
tracker.clear_memory()
# ----------------------------------------------
# end
# ----------------------------------------------
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
# set saving path
save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
if not os.path.exists(save_path):
os.mkdir(save_path)
# save
for painted_frame in progressbar.progressbar(painted_frames):
painted_frame = Image.fromarray(painted_frame)
painted_frame.save(f'{save_path}/{ti:05d}.png')
# tracker.clear_memory()
# for ti, frame in enumerate(frames):
@@ -241,6 +258,3 @@ if __name__ == '__main__':
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')

View File

@@ -2,7 +2,7 @@ from inference.memory_manager import MemoryManager
from model.network import XMem
from model.aggregate import aggregate
from util.tensor_util import pad_divide_by, unpad
from tracker.util.tensor_util import pad_divide_by, unpad
class InferenceCore: