This commit is contained in:
gaomingqi
2023-04-14 03:19:00 +08:00
3 changed files with 122 additions and 49 deletions

130
app.py
View File

@@ -3,18 +3,36 @@ from demo import automask_image_app, automask_video_app, sahi_autoseg_app
import argparse
import cv2
import time
def pause_video():
    """Print the current wall-clock time (seconds since the epoch).

    Debug hook for the video component's pause event; it takes no inputs
    and returns nothing.
    """
    paused_at = time.time()
    print(paused_at)
def play_video():
    """Print a marker line followed by the current wall-clock time.

    Debug hook for the video component's play event; it takes no inputs
    and returns nothing.
    """
    print("play video")
    # Call time.time(); printing the bare name would emit the function
    # object's repr ("<built-in function time>") instead of a timestamp.
    print(time.time())
from PIL import Image
import numpy as np
def get_frames_from_video(video_path, timestamp):
from tools.interact_tools import initialize
initialize()
def pause_video(play_state):
    """Record the moment the user paused the video.

    Args:
        play_state: list of timestamps; mutated in place.

    Returns:
        The same ``play_state`` list with ``time.time()`` appended.
    """
    print("user pause_video")
    paused_at = time.time()
    play_state.append(paused_at)
    return play_state
def play_video(play_state):
    """Record the moment the user started playing the video.

    Args:
        play_state: list of timestamps; mutated in place.

    Returns:
        The same ``play_state`` list with ``time.time()`` appended.
    """
    print("user play_video")
    started_at = time.time()
    play_state.append(started_at)
    return play_state
def get_frames_from_video(video_input, play_state):
"""
Args:
video_path:str
timestamp:float64
return [[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
Return
[[0:nearest_frame-1], [nearest_frame+1], nearest_frame]
"""
video_path = video_input
timestamp = play_state[1] - play_state[0]
frames = []
try:
cap = cv2.VideoCapture(video_path)
@@ -27,52 +45,49 @@ def get_frames_from_video(video_path, timestamp):
break
except (OSError, TypeError, ValueError, KeyError, SyntaxError) as e:
print("read_frame_source:{} error. {}\n".format(video_path, str(e)))
for frame in frames:
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
key_frame_index = int(timestamp * fps)
nearest_frame = frames[key_frame_index]
frames = [frames[:key_frame_index], frames[key_frame_index:], nearest_frame]
return frames
return frames, nearest_frame
with gr.Blocks() as iface:
state = gr.State([])
play_state = gr.State([])
video_state = gr.State([[],[],[]])
with gr.Row():
with gr.Column(scale=1.0):
seg_automask_video_file = gr.Video().style(height=720)
seg_automask_video_file.play(fn=play_video)
seg_automask_video_file.pause(fn=pause_video)
video_input = gr.Video().style(height=720)
# listen to the user action for play and pause input video
video_input.play(fn=play_video, inputs=play_state, outputs=play_state)
video_input.pause(fn=pause_video, inputs=play_state, outputs=play_state)
with gr.Row():
with gr.Column():
seg_automask_video_model_type = gr.Dropdown(
choices=[
"vit_h",
"vit_l",
"vit_b",
],
value="vit_l",
label="Model Type",
)
seg_automask_video_min_area = gr.Number(
value=1000,
label="Min Area",
)
with gr.Row():
with gr.Column():
seg_automask_video_points_per_side = gr.Slider(
minimum=0,
maximum=32,
step=2,
value=16,
label="Points per Side",
)
seg_automask_video_points_per_batch = gr.Slider(
minimum=0,
maximum=64,
step=2,
value=64,
label="Points per Batch",
)
with gr.Column(scale=0.5):
template_frame = gr.Image(type="pil", interactive=True, elem_id="template_frame")
with gr.Column():
template_select_button = gr.Button(value="Template select", interactive=True, variant="primary")
with gr.Column(scale=0.5):
with gr.Row(scale=0.4):
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
clear_button_image = gr.Button(value="Clear Image", interactive=True)
# seg_automask_video_points_per_batch = gr.Slider(
# minimum=0,
# maximum=64,
# step=2,
# value=64,
# label="Points per Batch",
# )
seg_automask_video_predict = gr.Button(value="Generator")
@@ -99,9 +114,40 @@ with gr.Blocks() as iface:
# ],
# outputs=[output_video],
# )
template_select_button.click(
fn=get_frames_from_video,
inputs=[
video_input,
play_state
],
outputs=[video_state, template_frame],
)
# clear
# clear_button_clike.click(
# lambda x: ([[], [], []], x, ""),
# [origin_image],
# [click_state, image_input, wiki_output],
# queue=False,
# show_progress=False
# )
# clear_button_image.click(
# lambda: (None, [], [], [[], [], []], "", ""),
# [],
# [image_input, chatbot, state, click_state, wiki_output, origin_image],
# queue=False,
# show_progress=False
# )
video_input.clear(
lambda: (None, [], [], [[], [], []], None),
[],
[video_input, state, play_state, video_state, template_frame],
queue=False,
show_progress=False
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=12212, server_name="0.0.0.0")
iface.launch(debug=True, enable_queue=True, server_port=122, server_name="0.0.0.0")

0
tools/__init__.py Normal file
View File

View File

@@ -7,27 +7,48 @@ from typing import Union
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
import matplotlib.pyplot as plt
import PIL
from mask_painter import mask_painter as mask_painter2
from tools.mask_painter import mask_painter as mask_painter2
from base_segmenter import BaseSegmenter
from painter import mask_painter, point_painter
import os
import requests
# Rendering parameters passed to mask_painter / point_painter by the
# click-handling functions further down in this module.
# NOTE(review): the integer colour values are presumably palette indices
# interpreted by the painter helpers — confirm against the painter module.
mask_color = 3
mask_alpha = 0.7
contour_color = 1
contour_width = 5
point_color = 5
# NOTE(review): point_color_ne is passed for points whose label is > 0 and
# point_color_ps for points whose label is < 1. If "ne"/"ps" are meant to
# read as negative/positive, the two are swapped at the call sites — confirm.
point_color_ne = 8
point_color_ps = 50
point_alpha = 0.9
point_radius = 15
# contour_color / contour_width are assigned a second time here; if both
# assignments live in the same module, these later values are the ones used.
contour_color = 2
contour_width = 5
def download_checkpoint(url, folder, filename):
    """Download ``url`` to ``folder/filename`` unless that file already exists.

    Args:
        url: HTTP(S) location of the checkpoint.
        folder: directory the checkpoint is stored under; created if missing.
        filename: file name relative to ``folder``; may contain
            sub-directories, which are created as needed.

    Returns:
        The path ``os.path.join(folder, filename)``.

    Raises:
        requests.RequestException: if the request fails, times out, or the
            server answers with an error status.
        OSError: if the file cannot be written.
    """
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if os.path.exists(filepath):
        return filepath

    # ``filename`` can carry its own sub-path (e.g. "./checkpoints/x.pth"),
    # so make sure the file's actual parent directory exists as well.
    os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)

    # Stream into a temporary sibling and rename at the end, so that an
    # interrupted download never leaves a truncated file that later runs
    # would mistake for a complete checkpoint.
    partial_path = filepath + ".part"
    try:
        with requests.get(url, stream=True, timeout=30) as response:
            # Fail loudly instead of saving an HTML error page as weights.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, filepath)
    finally:
        # Only present here if the download or rename did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return filepath
def initialize():
'''
initialize sam controler
'''
SAM_checkpoint= '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
folder = "segmenter"
SAM_checkpoint= './checkpoints/sam_vit_h_4b8939.pth'
download_checkpoint(checkpoint_url, folder, SAM_checkpoint)
model_type = 'vit_h'
device = "cuda:0"
sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
@@ -56,8 +77,11 @@ def first_frame_click(sam_controler, image: np.ndarray, points:np.ndarray, label
masks, scores, logits = sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
assert len(points)==len(labels)
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
@@ -76,7 +100,8 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
@@ -96,7 +121,8 @@ def interact_loop(sam_controler, image:np.ndarray, same: bool, points:np.ndarray
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, points, point_color, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
@@ -115,6 +141,7 @@ if __name__ == "__main__":
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)