verison before multi-mask --li

This commit is contained in:
memoryunreal
2023-04-17 12:58:19 +00:00
parent 1635a50e4f
commit 98a568ce34
4 changed files with 5 additions and 5 deletions

4
app.py
View File

@@ -219,8 +219,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py
args = parse_augment()
args.port = 12212
args.device = "cuda:5"
args.port = 12213
args.device = "cuda:4"
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)

View File

@@ -85,7 +85,7 @@ if __name__ == "__main__":
# initialise BaseSegmenter
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
model_type = 'vit_h'
device = "cuda:0"
device = "cuda:4"
base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)
# image embedding (once embedded, multiple prompts can be applied)

View File

@@ -12,7 +12,7 @@ class TrackingAnything():
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
self.args = args
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device, sam_checkpoint=sam_checkpoint, model_type=args.sam_model_type)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,

View File

@@ -20,7 +20,7 @@ from torchvision.transforms import Resize
class BaseTracker:
def __init__(self, xmem_checkpoint, device, sam_model, model_type=None) -> None:
def __init__(self, xmem_checkpoint, device, sam_model=None, model_type=None) -> None:
"""
device: model device
xmem_checkpoint: checkpoint of XMem model