fix mapper color

This commit is contained in:
gaomingqi
2023-04-19 19:58:41 +08:00
parent 5ed65a6ee6
commit 5c8fb53563
2 changed files with 39 additions and 24 deletions

View File

@@ -1,8 +1,10 @@
# Track-Anything
<!-- [![](https://img.shields.io/badge/arxiv-23xx.xxxxx-red.svg?style=flat-square)](linkUrl) &nbsp; [![](https://img.shields.io/badge/:hugs:-Open_in_Spaces-informational.svg?style=flat-square)](linkUrl) &nbsp; [![](https://img.shields.io/badge/contributors-SUSTech_VIP_Lab-important.svg?style=flat-square)](https://zhengfenglab.com/) -->
***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is developed upon [Segment Anything](https://github.com/facebookresearch/segment-anything) and [XMem](https://github.com/hkchengrex/XMem), can specify anything to track and segment via user clicks only. During tracking, users can flexibly change the objects they wanna track or correct the region of interest if there are any ambiguities. These characteristics enable ***Track-Anything*** to be suitable for:
- Video object tracking and segmentation with shot changes.
- Data annnotation for video object tracking and segmentation.
- Visualized development and data annnotation for video object tracking and segmentation.
- Object-centric downstream video tasks, such as video inpainting and editing.
## Demo

View File

@@ -67,6 +67,7 @@ class BaseTracker:
logit: numpy arrays, probability map (H, W)
painted_image: numpy array (H, W, 3)
"""
if first_frame_annotation is not None: # first frame mask
# initialisation
mask, labels = self.mapper.convert_mask(first_frame_annotation)
@@ -87,12 +88,20 @@ class BaseTracker:
out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
num_objs = out_mask.max()
final_mask = np.zeros_like(out_mask)
# map back
for k, v in self.mapper.remappings.items():
final_mask[out_mask == v] = k
num_objs = final_mask.max()
painted_image = frame
for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, out_mask, painted_image
if np.max(final_mask==obj) == 0:
continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
return final_mask, final_mask, painted_image
@torch.no_grad()
def sam_refinement(self, frame, logits, ti):
@@ -142,34 +151,38 @@ if __name__ == '__main__':
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# test for storage efficiency
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
# # test for storage efficiency
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
first_frame_annotation[first_frame_annotation==1] = 15
first_frame_annotation[first_frame_annotation==2] = 20
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
if not os.path.exists(save_path):
os.mkdir(save_path)
for ti, frame in enumerate(frames):
print(ti)
if ti > 200:
break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
painted_image.save(f'{save_path}/{ti:05d}.png')
tracker.clear_memory()
for ti, frame in enumerate(frames):
print(ti)
# if ti > 200:
# break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# tracker.clear_memory()
# for ti, frame in enumerate(frames):
# print(ti)
# # if ti > 200:
# # break
# if ti == 0:
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
# else:
# mask, prob, painted_image = tracker.track(frame)
# # save
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# # track anything given in the first frame annotation
# for ti, frame in enumerate(frames):