This commit is contained in:
ShangGaoG
2023-04-14 04:40:18 +08:00
3 changed files with 25 additions and 5 deletions

View File

@@ -8,6 +8,7 @@
This is Demo This is Demo
## Get Started ## Get Started
This is Get Started. This is Get Started.
## Acknowledgement ## Acknowledgement

View File

@@ -45,7 +45,7 @@ class SamControler():
initialize sam controler initialize sam controler
''' '''
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder ="segmenter/checkpoints" folder ="checkpoints"
SAM_checkpoint= 'sam_vit_h_4b8939.pth' SAM_checkpoint= 'sam_vit_h_4b8939.pth'
SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint) SAM_checkpoint = download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
# SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' # SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'

View File

@@ -9,6 +9,9 @@ import torch
import yaml import yaml
from model.network import XMem from model.network import XMem
from inference.inference_core import InferenceCore from inference.inference_core import InferenceCore
# for data transformation
from torchvision import transforms
from dataset.range_transform import im_normalization
class BaseTracker: class BaseTracker:
@@ -24,14 +27,27 @@ class BaseTracker:
network = XMem(config, xmem_checkpoint).to(device).eval() network = XMem(config, xmem_checkpoint).to(device).eval()
# initialise InferenceCore # initialise InferenceCore
self.tracker = InferenceCore(network, config) self.tracker = InferenceCore(network, config)
# set data transformation # data transformation
# self.data_transform = self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
def track(self, frames, first_frame_annotation): def track(self, frames, first_frame_annotation):
"""
Input:
frames: numpy arrays: T, H, W, 3 (T: number of frames)
first_frame_annotation: numpy array: H, W
Output:
masks: numpy arrays: T, H, W
"""
# data transformation # data transformation
for frame in frames:
frame = self.im_transform(frame)
# tracking # tracking
pass
if __name__ == '__main__': if __name__ == '__main__':
@@ -42,7 +58,10 @@ if __name__ == '__main__':
first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
# load frames # load frames
frames = [] <<<<<<< HEAD
=======
frames = ["test_confict"]
>>>>>>> a5606340a199569856ffa1585aeeff5a40cc34ba
for video_path in video_path_list: for video_path in video_path_list:
frames.append(np.array(Image.open(video_path).convert('RGB'))) frames.append(np.array(Image.open(video_path).convert('RGB')))
frames = np.stack(frames, 0) # N, H, W, C frames = np.stack(frames, 0) # N, H, W, C