fix base_tracker

This commit is contained in:
gaomingqi
2023-04-25 20:59:44 +08:00
parent 2d05a3a8e5
commit 6d9c051661
4 changed files with 55 additions and 44 deletions

View File

@@ -28,6 +28,8 @@
<!-- ![avengers]() --> <!-- ![avengers]() -->
## :rocket: Updates ## :rocket: Updates
- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, a versatile image processing tool that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:! - 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
## Demo ## Demo
@@ -56,13 +58,6 @@ cd Track-Anything
# Install dependencies: # Install dependencies:
pip install -r requirements.txt pip install -r requirements.txt
# Install dependencies for inpainting:
pip install -U openmim
mim install mmcv
# Install dependencies for editing
pip install madgrad
# Run the Track-Anything gradio demo. # Run the Track-Anything gradio demo.
python app.py --device cuda:0 --sam_model_type vit_h --port 12212 python app.py --device cuda:0 --sam_model_type vit_h --port 12212
``` ```

View File

@@ -15,3 +15,5 @@ onnx
metaseg==0.6.1 metaseg==0.6.1
pyyaml pyyaml
av av
mmcv-full
mmengine

View File

@@ -9,14 +9,14 @@ import yaml
import torch.nn.functional as F import torch.nn.functional as F
from model.network import XMem from model.network import XMem
from inference.inference_core import InferenceCore from inference.inference_core import InferenceCore
from util.mask_mapper import MaskMapper from tracker.util.mask_mapper import MaskMapper
from torchvision import transforms from torchvision import transforms
from util.range_transform import im_normalization from tracker.util.range_transform import im_normalization
import sys
sys.path.insert(0, sys.path[0]+"/../")
from tools.painter import mask_painter from tools.painter import mask_painter
from tools.base_segmenter import BaseSegmenter from tools.base_segmenter import BaseSegmenter
from torchvision.transforms import Resize from torchvision.transforms import Resize
import progressbar
class BaseTracker: class BaseTracker:
@@ -101,6 +101,8 @@ class BaseTracker:
continue continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1) painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
return final_mask, final_mask, painted_image return final_mask, final_mask, painted_image
@torch.no_grad() @torch.no_grad()
@@ -126,50 +128,65 @@ class BaseTracker:
self.mapper.clear_labels() self.mapper.clear_labels()
## how to use:
## 1/3) prepare device and xmem_checkpoint
# device = 'cuda:2'
# XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
## 2/3) initialise Base Tracker
# tracker = BaseTracker(XMEM_checkpoint, device, None, device) # leave an interface for sam model (currently set None)
## 3/3)
if __name__ == '__main__': if __name__ == '__main__':
# video frames (multiple objects) # video frames (take videos from DAVIS-2017 as examples)
video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg')) video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/horsejump-high', '*.jpg'))
video_path_list.sort() video_path_list.sort()
# first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
# load frames # load frames
frames = [] frames = []
for video_path in video_path_list: for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB'))) frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C frames = np.stack(frames, 0) # T, H, W, C
# load first frame annotation # load first frame annotation
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/horsejump-high/00000.png'
first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W
# ---------------------------------------------------------- # ------------------------------------------------------------------------------------
# initialise tracker # how to use
# ---------------------------------------------------------- # ------------------------------------------------------------------------------------
device = 'cuda:4' # 1/4: set checkpoint and device
device = 'cuda:2'
XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth' XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' # SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h' # model_type = 'vit_h'
# ------------------------------------------------------------------------------------
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device) # 2/4: initialise tracker
tracker = BaseTracker(XMEM_checkpoint, device, None, device) tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# ------------------------------------------------------------------------------------
# # test for storage efficiency # 3/4: for each frame, get tracking results by tracker.track(frame, first_frame_annotation)
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy') # frame: numpy array (H, W, C), first_frame_annotation: numpy array (H, W), pass it only with the first frame and omit it for all subsequent frames
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png')) painted_frames = []
first_frame_annotation[first_frame_annotation==1] = 15
first_frame_annotation[first_frame_annotation==2] = 20
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
if not os.path.exists(save_path):
os.mkdir(save_path)
for ti, frame in enumerate(frames): for ti, frame in enumerate(frames):
if ti == 0: if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation) mask, prob, painted_frame = tracker.track(frame, first_frame_annotation)
# mask:
else: else:
mask, prob, painted_image = tracker.track(frame) mask, prob, painted_frame = tracker.track(frame)
# save painted_frames.append(painted_frame)
painted_image = Image.fromarray(painted_image) # ----------------------------------------------
painted_image.save(f'{save_path}/{ti:05d}.png') # 4/4: clear memory in XMEM for the next video
tracker.clear_memory()
# ----------------------------------------------
# end
# ----------------------------------------------
print(f'max memory allocated: {torch.cuda.max_memory_allocated()/(2**20)} MB')
# set saving path
save_path = '/ssd1/gaomingqi/results/TAM/blackswan'
if not os.path.exists(save_path):
os.mkdir(save_path)
# save
for painted_frame in progressbar.progressbar(painted_frames):
painted_frame = Image.fromarray(painted_frame)
painted_frame.save(f'{save_path}/{ti:05d}.png')
# tracker.clear_memory() # tracker.clear_memory()
# for ti, frame in enumerate(frames): # for ti, frame in enumerate(frames):
@@ -241,6 +258,3 @@ if __name__ == '__main__':
# prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8')) # prob = Image.fromarray((probs[1].cpu().numpy()*255).astype('uint8'))
# # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png') # # prob.save(f'/ssd1/gaomingqi/failure/probs/{ti:05d}.png')

View File

@@ -2,7 +2,7 @@ from inference.memory_manager import MemoryManager
from model.network import XMem from model.network import XMem
from model.aggregate import aggregate from model.aggregate import aggregate
from util.tensor_util import pad_divide_by, unpad from tracker.util.tensor_util import pad_divide_by, unpad
class InferenceCore: class InferenceCore: