update base_tracker.py

This commit is contained in:
gaomingqi
2023-04-14 05:38:02 +08:00
4 changed files with 31 additions and 9 deletions

4
app.py
View File

@@ -31,7 +31,7 @@ def get_frames_from_video(video_input, play_state):
video_path:str
timestamp:float64
Return
[[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
[[0:nearest_frame], [nearest_frame:], nearest_frame]
"""
video_path = video_input
timestamp = play_state[1] - play_state[0]
@@ -149,7 +149,7 @@ with gr.Blocks() as iface:
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=122, server_name="0.0.0.0")
iface.launch(debug=True, enable_queue=True, server_port=12200, server_name="0.0.0.0")

View File

@@ -31,21 +31,23 @@ def download_checkpoint(url, folder, filename):
filepath = os.path.join(folder, filename)
if not os.path.exists(filepath):
print("download sam checkpoints ......")
response = requests.get(url, stream=True)
with open(filepath, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
print("download successfully!")
return filepath
class SamControler():
def __init__(self):
def __init__(self, sam_checkpoint, model_type, device):
'''
initialize sam controler
'''
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder ="segmenter/checkpoints"
folder ="checkpoints"
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'

View File

@@ -1,12 +1,28 @@
from tools.interact_tools import SamControler
from tracker.xmem import XMem
from tracker.base_tracker import BaseTracker
import numpy as np
class TrackingAnything():
def __init__(self, cfg):
self.cfg = cfg
self.samcontroler = SamControler()
self.xmem =
pass
self.samcontroler = SamControler(cfg.sam_checkpoint, cfg.model_type, cfg.device)
self.xmem = BaseTracker(cfg.device, cfg.xmem_checkpoint)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
if first_flag:
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
if interact_flag:
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image
mask, logit, painted_image = self.xmem.track(image, logit)
return mask, logit, painted_image

View File

@@ -94,7 +94,11 @@ if __name__ == '__main__':
# first frame
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
# load frames
<<<<<<< HEAD
frames = []
=======
frames = ["test_confict"]
>>>>>>> 5ca44baea36b7c66043342afc9ffb966e6d24417
for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C