mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
add multi-type sam model support args.sam_model_type -- li
This commit is contained in:
23
app.py
23
app.py
@@ -308,20 +308,33 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
||||
torchvision.io.write_video(output_path, frames, fps=fps, video_codec="libx264")
|
||||
return output_path
|
||||
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
|
||||
# check and download checkpoints if needed
|
||||
SAM_checkpoint = "sam_vit_h_4b8939.pth"
|
||||
sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
SAM_checkpoint_dict = {
|
||||
'vit_h': "sam_vit_h_4b8939.pth",
|
||||
'vit_l': "sam_vit_l_0b3195.pth",
|
||||
"vit_b": "sam_vit_b_01ec64.pth"
|
||||
}
|
||||
SAM_checkpoint_url_dict = {
|
||||
'vit_h': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
|
||||
'vit_l': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
|
||||
'vit_b': "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
|
||||
}
|
||||
sam_checkpoint = SAM_checkpoint_dict[args.sam_model_type]
|
||||
sam_checkpoint_url = SAM_checkpoint_url_dict[args.sam_model_type]
|
||||
xmem_checkpoint = "XMem-s012.pth"
|
||||
xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
|
||||
e2fgvi_checkpoint = "E2FGVI-HQ-CVPR22.pth"
|
||||
e2fgvi_checkpoint_id = "10wGdKSUOie0XmCr8SQ2A2FeDe-mfn5w3"
|
||||
|
||||
|
||||
folder ="./checkpoints"
|
||||
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, SAM_checkpoint)
|
||||
SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
|
||||
xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
|
||||
e2fgvi_checkpoint = download_checkpoint_from_google_drive(e2fgvi_checkpoint_id, folder, e2fgvi_checkpoint)
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
# args.port = 12315
|
||||
# args.device = "cuda:2"
|
||||
# args.mask_save = True
|
||||
|
||||
Reference in New Issue
Block a user