diff --git a/app.py b/app.py index 9a7440a..488d3d7 100644 --- a/app.py +++ b/app.py @@ -5,6 +5,14 @@ import cv2 import time from PIL import Image import numpy as np + + +from tools.interact_tools import initialize + + +initialize() + + def pause_video(play_state): print("user pause_video") play_state.append(time.time()) diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tools/interact_tools.py b/tools/interact_tools.py index 88c6580..997fcdf 100644 --- a/tools/interact_tools.py +++ b/tools/interact_tools.py @@ -7,10 +7,11 @@ from typing import Union from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator import matplotlib.pyplot as plt import PIL -from mask_painter import mask_painter as mask_painter2 +from tools.mask_painter import mask_painter as mask_painter2 from base_segmenter import BaseSegmenter from painter import mask_painter, point_painter - +import os +import requests mask_color = 3 mask_alpha = 0.7 @@ -24,11 +25,30 @@ contour_color = 2 contour_width = 5 +def download_checkpoint(url, folder, filename): + os.makedirs(folder, exist_ok=True) + filepath = os.path.join(folder, filename) + + if not os.path.exists(filepath): + response = requests.get(url, stream=True) + with open(filepath, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + f.write(chunk) + + return filepath + + def initialize(): ''' initialize sam controler ''' - SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth' + checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" + folder = "segmenter" + SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth' + download_checkpoint(checkpoint_url, folder, SAM_checkpoint) + + model_type = 'vit_h' device = "cuda:0" sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)