mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
default parameter for template image segment for the first test - watchtower
This commit is contained in:
76
app.py
76
app.py
@@ -5,11 +5,45 @@ import cv2
|
|||||||
import time
|
import time
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append(sys.path[0]+"/tracker")
|
||||||
|
sys.path.append(sys.path[0]+"/tracker/model")
|
||||||
|
from track_anything import TrackingAnything
|
||||||
|
from track_anything import parse_augment
|
||||||
|
import requests
|
||||||
|
|
||||||
|
def download_checkpoint(url, folder, filename):
    """Return the local path of a checkpoint, downloading it first if absent.

    Args:
        url: HTTP(S) location the checkpoint is fetched from.
        folder: Directory that holds checkpoints; created when missing.
        filename: Name of the checkpoint file inside ``folder``.

    Returns:
        ``os.path.join(folder, filename)``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    # An existing file is trusted as-is; nothing is fetched in that case.
    if os.path.exists(filepath):
        return filepath

    print("download checkpoints ......")
    # Stream into a sibling temporary file and rename it only once the whole
    # body has been written. Otherwise an interrupted download would leave a
    # truncated file at ``filepath`` that later runs would accept as complete.
    partial_path = filepath + ".part"
    try:
        with requests.get(url, stream=True) as response:
            # Fail loudly instead of saving an HTML error page as a checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, filepath)
    finally:
        # After a successful rename the temporary file no longer exists, so
        # this only cleans up after a failed or interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("download successfully!")
    return filepath
|
||||||
|
|
||||||
|
|
||||||
from tools.interact_tools import SamControler
|
# check and download checkpoints if needed
|
||||||
|
SAM_checkpoint = "sam_vit_h_4b8939.pth"
|
||||||
|
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||||
|
xmem_checkpoint = "XMem-s012.pth"
|
||||||
|
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
||||||
|
folder ="./checkpoints"
|
||||||
|
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
|
||||||
|
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
||||||
|
|
||||||
samc = SamControler()
|
# args, defined in track_anything.py
|
||||||
|
args = parse_augment()
|
||||||
|
args.port=12212
|
||||||
|
|
||||||
|
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -25,6 +59,22 @@ def play_video(play_state):
|
|||||||
play_state.append(time.time())
|
play_state.append(time.time())
|
||||||
return play_state
|
return play_state
|
||||||
|
|
||||||
|
# convert points input to prompt state
|
||||||
|
def get_prompt(inputs, click_state):
    """Convert click inputs into a prompt dict and record them in click_state.

    Args:
        inputs: Iterable of clicks; the first two items of each click are
            stored as the point and the third item as its label.
        click_state: Mutable two-slot container; slot 0 receives the points
            and slot 1 receives the labels. It is updated in place.

    Returns:
        A dict with the keys ``prompt_type``, ``input_point``,
        ``input_label`` and ``multimask_output``.
    """
    points = [click[:2] for click in inputs]
    labels = [click[2] for click in inputs]

    # Store both halves of the click. Previously only the points were
    # written, so the labels in the prompt could be stale or differ in length
    # from the points.
    click_state[0] = points
    click_state[1] = labels

    return {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        # NOTE(review): this is the string "True", not the boolean; kept
        # unchanged because the consumer of this dict is not visible here.
        "multimask_output": "True",
    }
|
||||||
|
|
||||||
def get_frames_from_video(video_input, play_state):
|
def get_frames_from_video(video_input, play_state):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -56,6 +106,15 @@ def get_frames_from_video(video_input, play_state):
|
|||||||
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
|
||||||
return frames, nearest_frame
|
return frames, nearest_frame
|
||||||
|
|
||||||
|
def inference_all(template_frame, evt:gr.SelectData):
    """Run one segmentation step on the template frame from a single click.

    Args:
        template_frame: Image passed in from the template component; it is
            converted with ``np.asarray`` before inference.
        evt: Gradio select event. ``evt.index[0]`` and ``evt.index[1]`` are
            used as the coordinates of the clicked point. The type annotation
            must stay, since Gradio uses it to inject the event data.

    Returns:
        The painted image produced by ``model.inference_step``; the mask and
        logit it also returns are discarded.
    """
    # A single click with label 1 (presumably "foreground" -- confirm against
    # the segmenter's label convention).
    points = np.array([[evt.index[0], evt.index[1]]])
    labels = np.array([1])

    _mask, _logit, painted_image = model.inference_step(
        first_flag=True,
        interact_flag=False,
        image=np.asarray(template_frame),
        same_image_flag=False,
        points=points,
        labels=labels,
        logits=None,
        multimask=True,
    )
    return painted_image
|
||||||
|
|
||||||
|
|
||||||
with gr.Blocks() as iface:
|
with gr.Blocks() as iface:
|
||||||
state = gr.State([])
|
state = gr.State([])
|
||||||
@@ -125,6 +184,17 @@ with gr.Blocks() as iface:
|
|||||||
outputs=[video_state, template_frame],
|
outputs=[video_state, template_frame],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
template_frame.select(
|
||||||
|
fn=inference_all,
|
||||||
|
inputs=[
|
||||||
|
template_frame
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
template_frame
|
||||||
|
]
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
# clear
|
# clear
|
||||||
# clear_button_clike.click(
|
# clear_button_clike.click(
|
||||||
# lambda x: ([[], [], []], x, ""),
|
# lambda x: ([[], [], []], x, ""),
|
||||||
@@ -149,7 +219,7 @@ with gr.Blocks() as iface:
|
|||||||
)
|
)
|
||||||
|
|
||||||
iface.queue(concurrency_count=1)
|
iface.queue(concurrency_count=1)
|
||||||
iface.launch(debug=True, enable_queue=True, server_port=12200, server_name="0.0.0.0")
|
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,33 +26,14 @@ point_radius = 15
|
|||||||
contour_color = 2
|
contour_color = 2
|
||||||
contour_width = 5
|
contour_width = 5
|
||||||
|
|
||||||
def download_checkpoint(url, folder, filename):
|
|
||||||
os.makedirs(folder, exist_ok=True)
|
|
||||||
filepath = os.path.join(folder, filename)
|
|
||||||
|
|
||||||
if not os.path.exists(filepath):
|
|
||||||
print("download sam checkpoints ......")
|
|
||||||
response = requests.get(url, stream=True)
|
|
||||||
with open(filepath, "wb") as f:
|
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
|
||||||
|
|
||||||
print("download successfully!")
|
|
||||||
return filepath
|
|
||||||
|
|
||||||
class SamControler():
|
class SamControler():
|
||||||
def __init__(self, sam_checkpoint, model_type, device):
|
def __init__(self, SAM_checkpoint, model_type, device):
|
||||||
'''
|
'''
|
||||||
initialize sam controler
|
initialize sam controler
|
||||||
'''
|
'''
|
||||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
|
||||||
folder ="checkpoints"
|
|
||||||
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
|
||||||
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
|
||||||
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
|
||||||
model_type = 'vit_h'
|
|
||||||
device = "cuda:0"
|
|
||||||
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,15 @@
|
|||||||
from tools.interact_tools import SamControler
|
from tools.interact_tools import SamControler
|
||||||
from tracker.base_tracker import BaseTracker
|
from tracker.base_tracker import BaseTracker
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TrackingAnything():
|
class TrackingAnything():
|
||||||
def __init__(self, cfg):
|
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
||||||
self.cfg = cfg
|
self.args = args
|
||||||
self.samcontroler = SamControler(cfg.sam_checkpoint, cfg.model_type, cfg.device)
|
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
||||||
self.xmem = BaseTracker(cfg.device, cfg.xmem_checkpoint)
|
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, )
|
||||||
|
|
||||||
|
|
||||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||||
@@ -25,4 +26,14 @@ class TrackingAnything():
|
|||||||
return mask, logit, painted_image
|
return mask, logit, painted_image
|
||||||
|
|
||||||
|
|
||||||
|
def parse_augment(argv=None):
    """Parse the command-line options shared by the tracking applications.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the behaviour existing
            callers rely on.

    Returns:
        An ``argparse.Namespace`` with ``device``, ``sam_model_type``,
        ``port`` and ``debug``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--sam_model_type', type=str, default="vit_h")
    parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    args = parser.parse_args(argv)

    # Echo the resolved options when debugging is requested.
    if args.debug:
        print(args)
    return args
|
||||||
@@ -15,7 +15,7 @@ from dataset.range_transform import im_normalization
|
|||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
class BaseTracker:
|
||||||
def __init__(self, device, xmem_checkpoint) -> None:
|
def __init__(self, xmem_checkpoint, device) -> None:
|
||||||
"""
|
"""
|
||||||
device: model device
|
device: model device
|
||||||
xmem_checkpoint: checkpoint of XMem model
|
xmem_checkpoint: checkpoint of XMem model
|
||||||
|
|||||||
Reference in New Issue
Block a user