mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-15 16:07:51 +01:00
shanggao edit the tools class
This commit is contained in:
8
app.py
8
app.py
@@ -5,6 +5,14 @@ import cv2
|
||||
import time
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
from tools.interact_tools import initialize
|
||||
|
||||
|
||||
initialize()
|
||||
|
||||
|
||||
def pause_video(play_state):
|
||||
print("user pause_video")
|
||||
play_state.append(time.time())
|
||||
|
||||
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
@@ -7,10 +7,11 @@ from typing import Union
|
||||
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
||||
import matplotlib.pyplot as plt
|
||||
import PIL
|
||||
from mask_painter import mask_painter as mask_painter2
|
||||
from tools.mask_painter import mask_painter as mask_painter2
|
||||
from base_segmenter import BaseSegmenter
|
||||
from painter import mask_painter, point_painter
|
||||
|
||||
import os
|
||||
import requests
|
||||
|
||||
mask_color = 3
|
||||
mask_alpha = 0.7
|
||||
@@ -24,11 +25,30 @@ contour_color = 2
|
||||
contour_width = 5
|
||||
|
||||
|
||||
def download_checkpoint(url, folder, filename):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
filepath = os.path.join(folder, filename)
|
||||
|
||||
if not os.path.exists(filepath):
|
||||
response = requests.get(url, stream=True)
|
||||
with open(filepath, "wb") as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
def initialize():
|
||||
'''
|
||||
initialize sam controler
|
||||
'''
|
||||
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
|
||||
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
||||
folder = "segmenter"
|
||||
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
|
||||
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
|
||||
|
||||
|
||||
model_type = 'vit_h'
|
||||
device = "cuda:0"
|
||||
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||
|
||||
Reference in New Issue
Block a user