shanggao edit the tools class

This commit is contained in:
memoryunreal
2023-04-13 19:04:57 +00:00
parent ebd63a5152
commit 6d477b3ae3
3 changed files with 31 additions and 3 deletions

8
app.py
View File

@@ -5,6 +5,14 @@ import cv2
import time
from PIL import Image
import numpy as np
from tools.interact_tools import initialize
initialize()
def pause_video(play_state):
print("user pause_video")
play_state.append(time.time())

0
tools/__init__.py Normal file
View File

View File

@@ -7,10 +7,11 @@ from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from mask_painter import mask_painter as mask_painter2
from tools.mask_painter import mask_painter as mask_painter2
from base_segmenter import BaseSegmenter
from painter import mask_painter, point_painter
import os
import requests
mask_color = 3
mask_alpha = 0.7
@@ -24,11 +25,30 @@ contour_color = 2
contour_width = 5
def download_checkpoint(url, folder, filename):
os.makedirs(folder, exist_ok=True)
filepath = os.path.join(folder, filename)
if not os.path.exists(filepath):
response = requests.get(url, stream=True)
with open(filepath, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
return filepath
def initialize():
'''
initialize sam controler
'''
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
model_type = 'vit_h'
device = "cuda:0"
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)