Mirror of https://github.com/gaomingqi/Track-Anything.git
Synced 2025-12-16 00:17:50 +01:00
Merge branch 'master' of https://github.com/gaomingqi/VOS-Anything
app.py (130 changed lines)
@@ -3,18 +3,36 @@ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
import argparse
import cv2
import time

def pause_video():
    print(time.time())

def play_video():
    print("play video")
    print(time.time)

from PIL import Image
import numpy as np

def get_frames_from_video(video_path, timestamp):

from tools.interact_tools import initialize

initialize()

def pause_video(play_state):
    print("user pause_video")
    play_state.append(time.time())
    return play_state

def play_video(play_state):
    print("user play_video")
    play_state.append(time.time())
    return play_state

def get_frames_from_video(video_input, play_state):
    """
    Args:
        video_path:str
        timestamp:float64
    return [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
    Return
        [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
    """
    video_path = video_input
    timestamp = play_state[1] - play_state[0]
    frames = []
    try:
        cap = cv2.VideoCapture(video_path)
@@ -27,52 +45,49 @@ def get_frames_from_video(video_path, timestamp):
                break
    except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
        print("read_frame_source:{} error. {}\n".format(video_path, str(e)))

    for frame in frames:
        frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    key_frame_index = int(timestamp * fps)
    nearest_frame = frames[key_frame_index]
    frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
    return frames
    return frames, nearest_frame

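A minimal sketch of how the two timestamps collected in play_state are meant to select the template frame, assuming fps is read from the opened capture via cv2.CAP_PROP_FPS (that line is elided from the hunk above); the file name and timestamp values are illustrative only.

import cv2

cap = cv2.VideoCapture("example.mp4")        # illustrative path
fps = cap.get(cv2.CAP_PROP_FPS)              # frames per second of the source video
play_state = [1681000000.0, 1681000003.5]    # [time.time() at play, time.time() at pause]
timestamp = play_state[1] - play_state[0]    # wall-clock seconds the user let the video run
key_frame_index = int(timestamp * fps)       # index of the frame used as the template frame

Note that this equates wall-clock playback time with video position, which only holds when playback starts at frame 0 and runs in real time.
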
with gr.Blocks() as iface:
    state = gr.State([])
    play_state = gr.State([])
    video_state = gr.State([[],[],[]])

    with gr.Row():
        with gr.Column(scale=1.0):
            seg_automask_video_file = gr.Video().style(height=720)
            seg_automask_video_file.play(fn=play_video)
            seg_automask_video_file.pause(fn=pause_video)
            video_input = gr.Video().style(height=720)

            # listen to the user action for play and pause input video
            video_input.play(fn=play_video, inputs=play_state, outputs=play_state)
            video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)

            with gr.Row():
                with gr.Column():
                    seg_automask_video_model_type = gr.Dropdown(
                        choices=[
                            "vit_h",
                            "vit_l",
                            "vit_b",
                        ],
                        value="vit_l",
                        label="Model Type",
                    )
                    seg_automask_video_min_area = gr.Number(
                        value=1000,
                        label="Min Area",
                    )

                    with gr.Row():
                        with gr.Column():
                            seg_automask_video_points_per_side = gr.Slider(
                                minimum=0,
                                maximum=32,
                                step=2,
                                value=16,
                                label="Points per Side",
                            )

                            seg_automask_video_points_per_batch = gr.Slider(
                                minimum=0,
                                maximum=64,
                                step=2,
                                value=64,
                                label="Points per Batch",
                            )
        with gr.Column(scale=0.5):
            template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame")
            with gr.Column():
                template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")

        with gr.Column(scale=0.5):
            with gr.Row(scale=0.4):
                clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
                clear_button_image = gr.Button(value="Clear Image", interactive=True)

                # seg_automask_video_points_per_batch = gr.Slider(
                #     minimum=0,
                #     maximum=64,
                #     step=2,
                #     value=64,
                #     label="Points per Batch",
                # )

                seg_automask_video_predict = gr.Button(value="Generator")

@@ -99,9 +114,40 @@ with gr.Blocks() as iface:
    # ],
    # outputs=[output_video],
    # )
    template_select_button.click(
        fn=get_frames_from_video,
        inputs=[
            video_input,
            play_state
        ],
        outputs=[video_state, template_frame],
    )

    # clear
    # clear_button_clike.click(
    #     lambda x: ([[], [], []], x, ""),
    #     [origin_image],
    #     [click_state, image_input, wiki_output],
    #     queue=False,
    #     show_progress=False
    # )
    # clear_button_image.click(
    #     lambda: (None, [], [], [[], [], []], "", ""),
    #     [],
    #     [image_input, chatbot, state, click_state, wiki_output, origin_image],
    #     queue=False,
    #     show_progress=False
    # )
    video_input.clear(
        lambda: (None, [], [], [[], [], []], None),
        [],
        [video_input, state, play_state, video_state, template_frame],
        queue=False,
        show_progress=False
    )

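For orientation, a condensed sketch of the event wiring the layout above builds; the component names (video_in, frame_out, pick) are illustrative, the layout and styling are omitted, and it assumes the get_frames_from_video defined earlier in this diff together with the Gradio 3.x-style API used above. The return tuple (frames, nearest_frame) is assigned to outputs in order, which is standard Gradio behaviour rather than something the diff adds.

import time
import gradio as gr

def record_event(play_state):
    # both play and pause append a wall-clock timestamp; their difference
    # approximates the position at which the user paused
    play_state.append(time.time())
    return play_state

with gr.Blocks() as demo:
    play_state = gr.State([])
    video_state = gr.State([[], [], []])
    video_in = gr.Video()
    frame_out = gr.Image()
    pick = gr.Button("Template select")

    video_in.play(fn=record_event, inputs=play_state, outputs=play_state)
    video_in.pause(fn=record_event, inputs=play_state, outputs=play_state)
    # (frames, nearest_frame) maps onto [video_state, frame_out] in order
    pick.click(fn=get_frames_from_video, inputs=[video_in, play_state], outputs=[video_state, frame_out])
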
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=12212, server_name="0.0.0.0")
iface.launch(debug=True, enable_queue=True, server_port=122, server_name="0.0.0.0")

tools/__init__.py (new file, 0 lines)
@@ -7,27 +7,48 @@ from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from mask_painter import mask_painter as mask_painter2
from tools.mask_painter import mask_painter as mask_painter2
from base_segmenter import BaseSegmenter
from painter import mask_painter, point_painter

import os
import requests

mask_color = 3
mask_alpha = 0.7
contour_color = 1
contour_width = 5
point_color = 5
point_color_ne = 8
point_color_ps = 50
point_alpha = 0.9
point_radius = 15
contour_color = 2
contour_width = 5

def download_checkpoint(url, folder, filename):
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)

    if not os.path.exists(filepath):
        response = requests.get(url, stream=True)
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    f.write(chunk)

    return filepath

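A hedged usage sketch for download_checkpoint as defined above; the URL is the one initialize() uses below, while the folder and filename here are illustrative.

# Fetches the SAM ViT-H weights once; later calls return the cached path immediately.
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
filepath = download_checkpoint(checkpoint_url, "./checkpoints", "sam_vit_h_4b8939.pth")
print(filepath)  # ./checkpoints/sam_vit_h_4b8939.pth
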
def initialize():
    '''
    initialize sam controler
    '''
    SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    folder = "segmenter"
    SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
    download_checkpoint(checkpoint_url, folder, SAM_checkpoint)

    model_type = 'vit_h'
    device = "cuda:0"
    sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
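The hunks that follow call sam_controler.predict(prompts, 'point', multimask), but the contents of prompts are not visible in this diff; the dictionary below, which follows the segment_anything point-prompt convention (point_coords/point_labels), is an assumption, as are the click coordinates.

import numpy as np

points = np.array([[500, 375]])   # one user click at (x, y); illustrative values
labels = np.array([1])            # 1 = positive (foreground) click, 0 = negative
prompts = {"point_coords": points, "point_labels": labels}   # assumed key names
masks, scores, logits = sam_controler.predict(prompts, "point", True)   # True = multimask output
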
@@ -56,8 +77,11 @@ def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, label
    masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
    mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

    assert len(points)==len(labels)

    painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
    painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
    painted_image = Image.fromarray(painted_image)

    return mask, logit, painted_image

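A usage sketch for first_frame_click as shown above; its full signature is truncated in the hunk header, so treating labels and a multimask flag as the remaining parameters is an assumption (the __main__ block further down calls interact_loop with the same pattern), and the image path and click coordinates are illustrative.

image = cv2.cvtColor(cv2.imread("first_frame.jpg"), cv2.COLOR_BGR2RGB)  # RGB frame, illustrative file
points = np.array([[500, 375], [200, 120]])
labels = np.array([1, 0])                     # one positive and one negative click
mask, logit, painted_image = first_frame_click(sam_controler, image, points, labels, multimask=True)
painted_image.save("first_click_result.png")  # PIL image with the mask and clicks painted on
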
@@ -76,7 +100,8 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
    mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

    painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
    painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
    painted_image = Image.fromarray(painted_image)

    return mask, logit, painted_image

@@ -96,7 +121,8 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
    mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

    painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
    painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
    painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
    painted_image = Image.fromarray(painted_image)

    return mask, logit, painted_image

@@ -115,6 +141,7 @@ if __name__ == "__main__":
    painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # numpy array (h, w, 3)
    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
    cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
    painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')

    mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)