mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 00:17:50 +01:00
update base_tracker.py
This commit is contained in:
4
app.py
4
app.py
@@ -31,7 +31,7 @@ def get_frames_from_video(video_input, play_state):
|
||||
video_path:str
|
||||
timestamp:float64
|
||||
Return
|
||||
[[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
|
||||
[[0:nearest_frame], [nearest_frame:], nearest_frame]
|
||||
"""
|
||||
video_path = video_input
|
||||
timestamp = play_state[1] - play_state[0]
|
||||
@@ -149,7 +149,7 @@ with gr.Blocks() as iface:
|
||||
)
|
||||
|
||||
iface.queue(concurrency_count=1)
|
||||
iface.launch(debug=True, enable_queue=True, server_port=122, server_name="0.0.0.0")
|
||||
iface.launch(debug=True, enable_queue=True, server_port=12200, server_name="0.0.0.0")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -31,21 +31,23 @@ def download_checkpoint(url, folder, filename):
|
||||
filepath = os.path.join(folder, filename)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
print("download sam checkpoints ......")
|
||||
response = requests.get(url, stream=True)
|
||||
with open(filepath, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
print("download successfully!")
|
||||
return filepath
|
||||
|
||||
class SamControler():
|
||||
def __init__(self):
|
||||
def __init__(self, sam_checkpoint, model_type, device):
|
||||
'''
|
||||
initialize sam controler
|
||||
'''
|
||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
folder ="segmenter/checkpoints"
|
||||
folder ="checkpoints"
|
||||
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
||||
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
from tools.interact_tools import SamControler
|
||||
from tracker.xmem import XMem
|
||||
|
||||
from tracker.base_tracker import BaseTracker
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class TrackingAnything():
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.samcontroler = SamControler()
|
||||
self.xmem =
|
||||
pass
|
||||
self.samcontroler = SamControler(cfg.sam_checkpoint, cfg.model_type, cfg.device)
|
||||
self.xmem = BaseTracker(cfg.device, cfg.xmem_checkpoint)
|
||||
|
||||
|
||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
if first_flag:
|
||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||
return mask, logit, painted_image
|
||||
|
||||
if interact_flag:
|
||||
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||
return mask, logit, painted_image
|
||||
|
||||
mask, logit, painted_image = self.xmem.track(image, logit)
|
||||
return mask, logit, painted_image
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,11 @@ if __name__ == '__main__':
|
||||
# first frame
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
||||
# load frames
|
||||
<<<<<<< HEAD
|
||||
frames = []
|
||||
=======
|
||||
frames = ["test_confict"]
|
||||
>>>>>>> 5ca44baea36b7c66043342afc9ffb966e6d24417
|
||||
for video_path in video_path_list:
|
||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||
frames = np.stack(frames, 0) # N, H, W, C
|
||||
|
||||
Reference in New Issue
Block a user