mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
verison before multi-mask --li
This commit is contained in:
4
app.py
4
app.py
@@ -219,8 +219,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
args.port = 12212
|
||||
args.device = "cuda:5"
|
||||
args.port = 12213
|
||||
args.device = "cuda:4"
|
||||
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ if __name__ == "__main__":
|
||||
# initialise BaseSegmenter
|
||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
model_type = 'vit_h'
|
||||
device = "cuda:0"
|
||||
device = "cuda:4"
|
||||
base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
|
||||
|
||||
# image embedding (once embedded, multiple prompts can be applied)
|
||||
|
||||
@@ -12,7 +12,7 @@ class TrackingAnything():
|
||||
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
|
||||
self.args = args
|
||||
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
||||
|
||||
|
||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
|
||||
@@ -20,7 +20,7 @@ from torchvision.transforms import Resize
|
||||
|
||||
|
||||
class BaseTracker:
|
||||
def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None:
|
||||
def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
|
||||
"""
|
||||
device: model device
|
||||
xmem_checkpoint: checkpoint of XMem model
|
||||
|
||||
Reference in New Issue
Block a user