mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
Merge branch 'master' of https://github.com/gaomingqi/Track-Anything
This commit is contained in:
@@ -8,6 +8,7 @@
|
|||||||
This is Demo
|
This is Demo
|
||||||
## Get Started
|
## Get Started
|
||||||
|
|
||||||
|
|
||||||
This is Get Started.
|
This is Get Started.
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class SamControler():
|
|||||||
initialize sam controler
|
initialize sam controler
|
||||||
'''
|
'''
|
||||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||||
folder ="segmenter/checkpoints"
|
folder ="checkpoints"
|
||||||
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
SAM_checkpoint= 'sam_vit_h_4b8939.pth'
|
||||||
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||||
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import torch
|
|||||||
import yaml
|
import yaml
|
||||||
from model.network import XMem
|
from model.network import XMem
|
||||||
from inference.inference_core import InferenceCore
|
from inference.inference_core import InferenceCore
|
||||||
|
# for data transormation
|
||||||
|
from torchvision import transforms
|
||||||
|
from dataset.range_transform import im_normalization
|
||||||
|
|
||||||
|
|
||||||
class BaseTracker:
|
class BaseTracker:
|
||||||
@@ -24,14 +27,27 @@ class BaseTracker:
|
|||||||
network = XMem(config, xmem_checkpoint).to(device).eval()
|
network = XMem(config, xmem_checkpoint).to(device).eval()
|
||||||
# initialise IncerenceCore
|
# initialise IncerenceCore
|
||||||
self.tracker = InferenceCore(network, config)
|
self.tracker = InferenceCore(network, config)
|
||||||
# set data transformation
|
# data transformation
|
||||||
# self.data_transform =
|
self.im_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
im_normalization,
|
||||||
|
])
|
||||||
|
|
||||||
def track(self, frames, first_frame_annotation):
|
def track(self, frames, first_frame_annotation):
|
||||||
|
"""
|
||||||
|
Input:
|
||||||
|
frames: numpy arrays: T, H, W, 3 (T: number of frames)
|
||||||
|
first_frame_annotation: numpy array: H, W
|
||||||
|
|
||||||
|
Output:
|
||||||
|
masks: numpy arrays: T, H, W
|
||||||
|
"""
|
||||||
# data transformation
|
# data transformation
|
||||||
|
for frame in frames:
|
||||||
|
frame = self.im_transform(frame)
|
||||||
|
|
||||||
# tracking
|
# tracking
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@@ -42,7 +58,10 @@ if __name__ == '__main__':
|
|||||||
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
|
||||||
|
|
||||||
# load frames
|
# load frames
|
||||||
frames = []
|
<<<<<<< HEAD
|
||||||
|
=======
|
||||||
|
frames = ["test_confict"]
|
||||||
|
>>>>>>> a5606340a199569856ffa1585aeeff5a40cc34ba
|
||||||
for video_path in video_path_list:
|
for video_path in video_path_list:
|
||||||
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
frames.append(np.array(Image.open(video_path).convert('RGB')))
|
||||||
frames = np.stack(frames, 0) # N, H, W, C
|
frames = np.stack(frames, 0) # N, H, W, C
|
||||||
|
|||||||
Reference in New Issue
Block a user