mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 16:37:58 +01:00
shanggao edit the tools class
This commit is contained in:
8
app.py
8
app.py
@@ -5,6 +5,14 @@ import cv2
|
|||||||
import time
|
import time
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from tools.interact_tools import initialize
|
||||||
|
|
||||||
|
|
||||||
|
initialize()
|
||||||
|
|
||||||
|
|
||||||
def pause_video(play_state):
|
def pause_video(play_state):
|
||||||
print("user pause_video")
|
print("user pause_video")
|
||||||
play_state.append(time.time())
|
play_state.append(time.time())
|
||||||
|
|||||||
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
@@ -7,10 +7,11 @@ from typing import Union
|
|||||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import PIL
|
import PIL
|
||||||
from mask_painter import mask_painter as mask_painter2
|
from tools.mask_painter import mask_painter as mask_painter2
|
||||||
from base_segmenter import BaseSegmenter
|
from base_segmenter import BaseSegmenter
|
||||||
from painter import mask_painter, point_painter
|
from painter import mask_painter, point_painter
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
|
||||||
mask_color = 3
|
mask_color = 3
|
||||||
mask_alpha = 0.7
|
mask_alpha = 0.7
|
||||||
@@ -24,11 +25,30 @@ contour_color = 2
|
|||||||
contour_width = 5
|
contour_width = 5
|
||||||
|
|
||||||
|
|
||||||
|
def download_checkpoint(url, folder, filename):
|
||||||
|
os.makedirs(folder, exist_ok=True)
|
||||||
|
filepath = os.path.join(folder, filename)
|
||||||
|
|
||||||
|
if not os.path.exists(filepath):
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
with open(filepath, "wb") as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
if chunk:
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
def initialize():
|
def initialize():
|
||||||
'''
|
'''
|
||||||
initialize sam controler
|
initialize sam controler
|
||||||
'''
|
'''
|
||||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||||
|
folder = "segmenter"
|
||||||
|
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
|
||||||
|
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
model_type = 'vit_h'
|
model_type = 'vit_h'
|
||||||
device = "cuda:0"
|
device = "cuda:0"
|
||||||
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||||
|
|||||||
Reference in New Issue
Block a user