From ebfb0c00f94300fbd8f0b7a936f443ba0005f529 Mon Sep 17 00:00:00 2001
From: gaomingqi
Date: Fri, 14 Apr 2023 03:13:58 +0800
Subject: [PATCH 1/2] add base_tracker

---
 .gitignore                     |  1 +
 requirements.txt               |  1 +
 tracker/base_tracker.py        | 59 ++++++++++++++++++++++++++++++++++
 tracker/config/config.yaml     | 15 +++++++
 tracker/dataset/vos_dataset.py |  2 +-
 tracker/xmem.py                | 29 -----------------
 6 files changed, 77 insertions(+), 30 deletions(-)
 create mode 100644 tracker/base_tracker.py
 create mode 100644 tracker/config/config.yaml
 delete mode 100644 tracker/xmem.py

diff --git a/.gitignore b/.gitignore
index d3d66d2..22b2b22 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ docs/
 *.mp4
 debug_images/
 *.png
+*.jpg
diff --git a/requirements.txt b/requirements.txt
index ca365f5..4bdca8b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,4 @@ matplotlib
 onnxruntime
 onnx
 metaseg
+pyyaml
diff --git a/tracker/base_tracker.py b/tracker/base_tracker.py
new file mode 100644
index 0000000..67940e1
--- /dev/null
+++ b/tracker/base_tracker.py
@@ -0,0 +1,59 @@
+# input: frame list, first frame mask
+# output: segmentation results on all frames
+import os
+import glob
+import numpy as np
+from PIL import Image
+
+import torch
+import yaml
+from model.network import XMem
+from inference.inference_core import InferenceCore
+
+
+class BaseTracker:
+    def __init__(self, device, xmem_checkpoint) -> None:
+        """
+        device: model device
+        xmem_checkpoint: checkpoint of XMem model
+        """
+        # load configurations
+        with open("tracker/config/config.yaml", 'r') as stream:
+            config = yaml.safe_load(stream)
+        # initialise XMem
+        network = XMem(config, xmem_checkpoint).to(device).eval()
+        # initialise InferenceCore
+        self.tracker = InferenceCore(network, config)
+        # set data transformation
+        # self.data_transform =
+
+    def track(self, frames, first_frame_annotation):
+        # data transformation
+
+        # tracking
+        pass
+
+
+if __name__ == '__main__':
+    # video frames
+    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
+    video_path_list.sort()
+    # first frame
+    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
+
+    # load frames
+    frames = []
+    for video_path in video_path_list:
+        frames.append(np.array(Image.open(video_path).convert('RGB')))
+    frames = np.stack(frames, 0)    # N, H, W, C
+
+    # load first frame annotation
+    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))    # H, W
+
+    # initialise tracker
+    device = 'cuda:0'
+    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
+    tracker = BaseTracker(device, XMEM_checkpoint)
+
+    # track anything given in the first frame annotation
+    tracker.track(frames, first_frame_annotation)
diff --git a/tracker/config/config.yaml b/tracker/config/config.yaml
new file mode 100644
index 0000000..82e494d
--- /dev/null
+++ b/tracker/config/config.yaml
@@ -0,0 +1,15 @@
+# config info for XMem
+benchmark: False
+disable_long_term: False
+max_mid_term_frames: 10
+min_mid_term_frames: 5
+max_long_term_elements: 10000
+num_prototypes: 128
+top_k: 30
+mem_every: 5
+deep_update_every: -1
+save_scores: False
+flip: False
+size: 480
+enable_long_term: True
+enable_long_term_count_usage: True
diff --git a/tracker/dataset/vos_dataset.py b/tracker/dataset/vos_dataset.py
index be0f8a1..2b5d365 100644
--- a/tracker/dataset/vos_dataset.py
+++ b/tracker/dataset/vos_dataset.py
@@ -213,4 +213,4 @@ class VOSDataset(Dataset):
         return data
 
     def __len__(self):
-        return len(self.videos)
\ No newline at end of file
+        return len(self.videos)
diff --git a/tracker/xmem.py b/tracker/xmem.py
deleted file mode 100644
index 1465843..0000000
--- a/tracker/xmem.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# input: frame list, first frame mask
-# output: segmentation results on all frames
-import os
-import glob
-import numpy as np
-from PIL import Image
-
-
-class XMem:
-    # based on https://github.com/hkchengrex/XMem
-    pass
-
-
-if __name__ == '__main__':
-    # video frames
-    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
-    video_path_list.sort()
-    # first frame
-    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
-
-    # load frames
-    frames = []
-    for video_path in video_path_list:
-        frames.append(np.array(Image.open(video_path).convert('RGB')))
-    frames = np.stack(frames, 0)    # N, H, W, C
-
-    # load first frame annotation
-    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))    # H, W, C
-

From 1fd5afa8489dadbe1359ab884012804e232fab56 Mon Sep 17 00:00:00 2001
From: memoryunreal <814514103@qq.com>
Date: Thu, 13 Apr 2023 19:25:59 +0000
Subject: [PATCH 2/2] rm app_test

---
 app.py      |  4 ++--
 app_test.py | 37 -------------------------------------
 2 files changed, 2 insertions(+), 39 deletions(-)
 delete mode 100644 app_test.py

diff --git a/app.py b/app.py
index 26f487b..79ba26b 100644
--- a/app.py
+++ b/app.py
@@ -7,10 +7,10 @@ from PIL import Image
 import numpy as np
 
 
-from tools.interact_tools import initialize
+# from tools.interact_tools import initialize
 
 
-initialize()
+# initialize()
 
 
 def pause_video(play_state):
diff --git a/app_test.py b/app_test.py
deleted file mode 100644
index 817f7f7..0000000
--- a/app_test.py
+++ /dev/null
@@ -1,37 +0,0 @@
-import gradio as gr
-import time
-
-def capture_frame(video):
-    frame = video.get_frame_at_sec(video.current_time)
-    return frame
-
-def capture_time(video):
-    while True:
-        if video.paused:
-            time_paused = video.current_time
-            return time_paused
-
-iface = gr.Interface(fn=capture_frame,
-                     inputs=[gr.inputs.Video(type="mp4", label="Input video",
-                     source="upload")],
-                     outputs=["image"],
-                     server_port=12212,
-                     server_name="0.0.0.0",
-                     capture_session=True)
-
-video_player = iface.video[0]
-video_player.pause = False
-
-time_interface = gr.Interface(fn=capture_time,
-                              inputs=[gr.inputs.Video(type="mp4", label="Input video",
-                              source="upload", max_duration=10)],
-                              outputs=["text"],
-                              server_port=12212,
-                              server_name="0.0.0.0",
-                              capture_session=True)
-
-time_interface.video[0].play = False
-time_interface.video[0].pause = False
-
-iface.launch()
-time_interface.launch()
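
Note on the `track` stub in PATCH 1/2: the method body is intentionally left as a placeholder. The sketch below shows one way it could be completed against the upstream XMem inference API (`InferenceCore.set_all_labels` and `InferenceCore.step`, with ImageNet-style normalisation). The `self.device` and `self.im_transform` attributes, the label handling, and the returned mask format are assumptions for illustration, not part of this patch.

# Illustrative sketch only -- not the implementation committed in this patch.
# Assumes BaseTracker.__init__ additionally stores:
#   self.device = device
#   self.im_transform = torchvision.transforms.Compose([
#       ToTensor(), Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
import numpy as np
import torch


@torch.no_grad()
def track(self, frames, first_frame_annotation):
    # object ids present in the palette mask; 0 is background
    labels = np.unique(first_frame_annotation)
    labels = labels[labels != 0].tolist()
    self.tracker.set_all_labels(labels)

    masks = []
    for ti, frame in enumerate(frames):
        frame_t = self.im_transform(frame).to(self.device)    # 3, H, W
        if ti == 0:
            # one-hot object masks (num_objects, H, W), first frame only
            mask = np.stack([first_frame_annotation == l for l in labels], 0)
            mask_t = torch.from_numpy(mask).float().to(self.device)
            prob = self.tracker.step(frame_t, mask_t, labels)
        else:
            prob = self.tracker.step(frame_t)
        # prob: (num_objects + 1, H, W); channel 0 is background,
        # channel i (> 0) corresponds to labels[i - 1]
        out = torch.argmax(prob, dim=0).cpu().numpy().astype(np.uint8)
        masks.append(out)
    return np.stack(masks, 0)    # N, H, W

Returning hard per-frame masks keeps the interface symmetric with the first-frame annotation; a caller that needs soft masks for blending could return `prob` alongside them.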