mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
Merge remote-tracking branch 'origin/master'
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
<a src="https://img.shields.io/badge/%F0%9F%93%96-Open_in_Spaces-informational.svg?style=flat-square" href="https://arxiv.org/abs/2304.11968">
|
||||
<img src="https://img.shields.io/badge/%F0%9F%93%96-Arxiv_2304.11968-red.svg?style=flat-square">
|
||||
</a>
|
||||
<a src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square" href="https://huggingface.co/spaces/watchtowerss/Track-Anything">
|
||||
<a src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square" href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=trueg">
|
||||
<img src="https://img.shields.io/badge/%F0%9F%A4%97-Open_in_Spaces-informational.svg?style=flat-square">
|
||||
</a>
|
||||
<a src="https://img.shields.io/badge/%F0%9F%9A%80-SUSTech_VIP_Lab-important.svg?style=flat-square" href="https://zhengfenglab.com/">
|
||||
@@ -30,8 +30,7 @@
|
||||
## :rocket: Updates
|
||||
- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.
|
||||
|
||||
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything) on Hugging Face :hugs:!
|
||||
|
||||
- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=trueg) on Hugging Face :hugs:!
|
||||
## Demo
|
||||
|
||||
https://user-images.githubusercontent.com/28050374/232842703-8395af24-b13e-4b8e-aafb-e94b61e6c449.MP4
|
||||
|
||||
15
app.py
15
app.py
@@ -13,6 +13,8 @@ import requests
|
||||
import json
|
||||
import torchvision
|
||||
import torch
|
||||
from tools.interact_tools import SamControler
|
||||
from tracker.base_tracker import BaseTracker
|
||||
from tools.painter import mask_painter
|
||||
try:
|
||||
from mmcv.cnn import ConvModule
|
||||
@@ -204,6 +206,7 @@ def show_mask(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
# tracking vos
|
||||
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
model.xmem.clear_memory()
|
||||
if interactive_state["track_end_number"]:
|
||||
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
||||
@@ -223,6 +226,8 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
fps = video_state["fps"]
|
||||
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
||||
# clear GPU memory
|
||||
model.xmem.clear_memory()
|
||||
|
||||
if interactive_state["track_end_number"]:
|
||||
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
||||
@@ -262,6 +267,7 @@ def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
# inpaint
|
||||
def inpaint_video(video_state, interactive_state, mask_dropdown):
|
||||
|
||||
frames = np.asarray(video_state["origin_images"])
|
||||
fps = video_state["fps"]
|
||||
inpaint_masks = np.asarray(video_state["masks"])
|
||||
@@ -342,6 +348,12 @@ e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id,
|
||||
# initialize sam, xmem, e2fgvi models
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, e2fgvi_checkpoint,args)
|
||||
|
||||
|
||||
title = """<p><h1 align="center">Track-Anything</h1></p>
|
||||
"""
|
||||
description = """<p>Gradio demo for Track Anything, a flexible and interactive tool for video object tracking, segmentation, and inpainting. I To use it, simply upload your video, or click one of the examples to load them. Code: <a href="https://github.com/gaomingqi/Track-Anything">https://github.com/gaomingqi/Track-Anything</a> <a href="https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
|
||||
|
||||
|
||||
with gr.Blocks() as iface:
|
||||
"""
|
||||
state for
|
||||
@@ -373,7 +385,8 @@ with gr.Blocks() as iface:
|
||||
"fps": 30
|
||||
}
|
||||
)
|
||||
|
||||
gr.Markdown(title)
|
||||
gr.Markdown(description)
|
||||
with gr.Row():
|
||||
|
||||
# for user video input
|
||||
|
||||
@@ -69,10 +69,11 @@ class BaseInpainter:
|
||||
size = None
|
||||
else:
|
||||
size = [int(W*ratio), int(H*ratio)]
|
||||
if size[0] % 2 > 0:
|
||||
size[0] += 1
|
||||
if size[1] % 2 > 0:
|
||||
size[1] += 1
|
||||
size = [si+1 if si%2>0 else si for si in size] # only consider even values
|
||||
# shortest side should be larger than 50
|
||||
if min(size) < 50:
|
||||
ratio = 50. / min(H, W)
|
||||
size = [int(W*ratio), int(H*ratio)]
|
||||
|
||||
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
|
||||
binary_masks = resize_masks(masks, tuple(size))
|
||||
@@ -156,7 +157,7 @@ if __name__ == '__main__':
|
||||
base_inpainter = BaseInpainter(checkpoint, device)
|
||||
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
|
||||
# ratio: (0, 1], ratio for down sample, default value is 1
|
||||
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=1) # numpy array, T, H, W, 3
|
||||
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.01) # numpy array, T, H, W, 3
|
||||
# ----------------------------------------------
|
||||
# end
|
||||
# ----------------------------------------------
|
||||
|
||||
@@ -12,9 +12,12 @@ import argparse
|
||||
class TrackingAnything():
|
||||
def __init__(self, sam_checkpoint, xmem_checkpoint, e2fgvi_checkpoint, args):
|
||||
self.args = args
|
||||
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
||||
self.baseinpainter = BaseInpainter(e2fgvi_checkpoint, args.device)
|
||||
self.sam_checkpoint = sam_checkpoint
|
||||
self.xmem_checkpoint = xmem_checkpoint
|
||||
self.e2fgvi_checkpoint = e2fgvi_checkpoint
|
||||
self.samcontroler = SamControler(self.sam_checkpoint, args.sam_model_type, args.device)
|
||||
self.xmem = BaseTracker(self.xmem_checkpoint, device=args.device)
|
||||
self.baseinpainter = BaseInpainter(self.e2fgvi_checkpoint, args.device)
|
||||
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
# if first_flag:
|
||||
|
||||
@@ -126,6 +126,7 @@ class BaseTracker:
|
||||
def clear_memory(self):
|
||||
self.tracker.clear_memory()
|
||||
self.mapper.clear_labels()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
## how to use:
|
||||
|
||||
Reference in New Issue
Block a user