mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 00:17:50 +01:00
Merge branch 'master' of https://github.com/gaomingqi/Track-Anything
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
This is Demo
|
||||
## Get Started
|
||||
|
||||
|
||||
This is Get Started.
|
||||
## Acknowledgement
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ class SamControler():
|
||||
initialize sam controler
|
||||
'''
|
||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
folder ="segmenter/checkpoints"
|
||||
folder ="checkpoints"
|
||||
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
||||
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
|
||||
@@ -9,6 +9,9 @@ import torch
|
||||
import yaml
|
||||
from model.network import XMem
|
||||
from inference.inference_core import InferenceCore
|
||||
# for data transormation
|
||||
from torchvision import transforms
|
||||
from dataset.range_transform import im_normalization
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
@@ -24,14 +27,27 @@ class BaseTracker:
|
||||
network = XMem(config, xmem_checkpoint).to(device).eval()
|
||||
# initialise IncerenceCore
|
||||
self.tracker = InferenceCore(network, config)
|
||||
# set data transformation
|
||||
# self.data_transform =
|
||||
# data transformation
|
||||
self.im_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
im_normalization,
|
||||
])
|
||||
|
||||
def track(self, frames, first_frame_annotation):
|
||||
"""
|
||||
Input:
|
||||
frames: numpy arrays: T, H, W, 3 (T: number of frames)
|
||||
first_frame_annotation: numpy array: H, W
|
||||
|
||||
Output:
|
||||
masks: numpy arrays: T, H, W
|
||||
"""
|
||||
# data transformation
|
||||
for frame in frames:
|
||||
frame = self.im_transform(frame)
|
||||
|
||||
# tracking
|
||||
pass
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -42,7 +58,10 @@ if __name__ == '__main__':
|
||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
||||
|
||||
# load frames
|
||||
frames = []
|
||||
<<<<<<< HEAD
|
||||
=======
|
||||
frames = ["test_confict"]
|
||||
>>>>>>> a5606340a199569856ffa1585aeeff5a40cc34ba
|
||||
for video_path in video_path_list:
|
||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||
frames = np.stack(frames, 0) # N, H, W, C
|
||||
|
||||
Reference in New Issue
Block a user