mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2025-12-16 08:27:49 +01:00
Merge branch 'master' of https://github.com/gaomingqi/VOS-Anything
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
|
||||
***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is developed upon [Segment Anything](https://github.com/facebookresearch/segment-anything), can specify anything to track and segment via user clicks only. During tracking, users can flexibly change the objects they wanna track or correct the region of interest if there are any ambiguities. These characteristics enable ***Track-Anything*** to be suitable for:
|
||||
- Video object tracking and segmentation with shot changes.
|
||||
- Data annnotation for video object tracking and segmentation.
|
||||
- Visualized development and data annnotation for video object tracking and segmentation.
|
||||
- Object-centric downstream video tasks, such as video inpainting and editing.
|
||||
|
||||
## Demo
|
||||
|
||||
245
app.py
245
app.py
@@ -17,7 +17,7 @@ import torchvision
|
||||
import torch
|
||||
import concurrent.futures
|
||||
import queue
|
||||
|
||||
from tools.painter import mask_painter, point_painter
|
||||
# download checkpoints
|
||||
def download_checkpoint(url, folder, filename):
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
|
||||
"masks": [None]*len(frames),
|
||||
"logits": [None]*len(frames),
|
||||
"select_frame_number": 0,
|
||||
"fps": 30
|
||||
"fps": fps
|
||||
}
|
||||
return video_state, gr.update(visible=True, maximum=len(frames), value=1)
|
||||
video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
|
||||
|
||||
model.samcontroler.sam_controler.reset_image()
|
||||
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
|
||||
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True), gr.update(visible=True), \
|
||||
gr.update(visible=True)
|
||||
|
||||
# get the select frame from gradio slider
|
||||
def select_template(image_selection_slider, video_state):
|
||||
def select_template(image_selection_slider, video_state, interactive_state):
|
||||
|
||||
# images = video_state[1]
|
||||
image_selection_slider -= 1
|
||||
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
|
||||
model.samcontroler.sam_controler.reset_image()
|
||||
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
|
||||
|
||||
# # clear multi mask
|
||||
# interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
|
||||
|
||||
return video_state["painted_images"][image_selection_slider], video_state
|
||||
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
|
||||
|
||||
def get_end_number(track_pause_number_slider, interactive_state):
|
||||
interactive_state["track_end_number"] = track_pause_number_slider
|
||||
return interactive_state
|
||||
|
||||
# use sam to get the mask
|
||||
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
|
||||
@@ -133,17 +148,65 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
|
||||
|
||||
return painted_image, video_state, interactive_state
|
||||
|
||||
def add_multi_mask(video_state, interactive_state, mask_dropdown):
|
||||
mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
interactive_state["multi_mask"]["masks"].append(mask)
|
||||
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
||||
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
|
||||
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
|
||||
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
|
||||
|
||||
def clear_click(video_state, click_state):
|
||||
click_state = [[],[]]
|
||||
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
||||
return template_frame, click_state
|
||||
|
||||
def remove_multi_mask(interactive_state):
|
||||
interactive_state["multi_mask"]["mask_names"]= []
|
||||
interactive_state["multi_mask"]["masks"] = []
|
||||
return interactive_state
|
||||
|
||||
def show_mask(video_state, interactive_state, mask_dropdown):
|
||||
mask_dropdown.sort()
|
||||
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
|
||||
|
||||
for i in range(len(mask_dropdown)):
|
||||
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
||||
mask = interactive_state["multi_mask"]["masks"][mask_number]
|
||||
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
|
||||
|
||||
return select_frame
|
||||
|
||||
# tracking vos
|
||||
def vos_tracking_video(video_state, interactive_state):
|
||||
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
|
||||
model.xmem.clear_memory()
|
||||
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
|
||||
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
if interactive_state["track_end_number"]:
|
||||
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
|
||||
else:
|
||||
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
|
||||
|
||||
if interactive_state["multi_mask"]["masks"]:
|
||||
if len(mask_dropdown) == 0:
|
||||
mask_dropdown = ["mask_001"]
|
||||
mask_dropdown.sort()
|
||||
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
|
||||
for i in range(1,len(mask_dropdown)):
|
||||
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
|
||||
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
|
||||
video_state["masks"][video_state["select_frame_number"]]= template_mask
|
||||
else:
|
||||
template_mask = video_state["masks"][video_state["select_frame_number"]]
|
||||
fps = video_state["fps"]
|
||||
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
|
||||
|
||||
video_state["masks"][video_state["select_frame_number"]:] = masks
|
||||
video_state["logits"][video_state["select_frame_number"]:] = logits
|
||||
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
|
||||
if interactive_state["track_end_number"]:
|
||||
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
|
||||
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
|
||||
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
|
||||
else:
|
||||
video_state["masks"][video_state["select_frame_number"]:] = masks
|
||||
video_state["logits"][video_state["select_frame_number"]:] = logits
|
||||
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
|
||||
|
||||
video_output = generate_video_from_frames(video_state["painted_images"], output_path="./result/{}".format(video_state["video_name"]), fps=fps) # import video_input to name the output video
|
||||
interactive_state["inference_times"] += 1
|
||||
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
|
||||
output_path (str): The path to save the generated video.
|
||||
fps (int, optional): The frame rate of the output video. Defaults to 30.
|
||||
"""
|
||||
# height, width, layers = frames[0].shape
|
||||
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
|
||||
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
|
||||
# print(output_path)
|
||||
# for frame in frames:
|
||||
# video.write(frame)
|
||||
|
||||
# video.release()
|
||||
frames = torch.from_numpy(np.asarray(frames))
|
||||
if not os.path.exists(os.path.dirname(output_path)):
|
||||
os.makedirs(os.path.dirname(output_path))
|
||||
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
|
||||
|
||||
# args, defined in track_anything.py
|
||||
args = parse_augment()
|
||||
# args.port = 12212
|
||||
# args.device = "cuda:4"
|
||||
# args.port = 12315
|
||||
# args.device = "cuda:1"
|
||||
# args.mask_save = True
|
||||
|
||||
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
|
||||
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
|
||||
"inference_times": 0,
|
||||
"negative_click_times" : 0,
|
||||
"positive_click_times": 0,
|
||||
"mask_save": args.mask_save
|
||||
})
|
||||
"mask_save": args.mask_save,
|
||||
"multi_mask": {
|
||||
"mask_names": [],
|
||||
"masks": []
|
||||
},
|
||||
"track_end_number": None
|
||||
}
|
||||
)
|
||||
|
||||
video_state = gr.State(
|
||||
{
|
||||
"video_name": "",
|
||||
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
|
||||
with gr.Row():
|
||||
|
||||
# for user video input
|
||||
with gr.Column(scale=1.0):
|
||||
video_input = gr.Video().style(height=360)
|
||||
with gr.Column():
|
||||
with gr.Row(scale=0.4):
|
||||
video_input = gr.Video(autosize=True)
|
||||
video_info = gr.Textbox()
|
||||
|
||||
|
||||
|
||||
with gr.Row(scale=1):
|
||||
with gr.Row():
|
||||
# put the template frame under the radio button
|
||||
with gr.Column(scale=0.5):
|
||||
with gr.Column():
|
||||
# extract frames
|
||||
with gr.Column():
|
||||
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
|
||||
|
||||
# click points settins, negative or positive, mode continuous or single
|
||||
with gr.Row():
|
||||
with gr.Row(scale=0.5):
|
||||
with gr.Row():
|
||||
point_prompt = gr.Radio(
|
||||
choices=["Positive", "Negative"],
|
||||
value="Positive",
|
||||
label="Point Prompt",
|
||||
interactive=True)
|
||||
interactive=True,
|
||||
visible=False)
|
||||
click_mode = gr.Radio(
|
||||
choices=["Continuous", "Single"],
|
||||
value="Continuous",
|
||||
label="Clicking Mode",
|
||||
interactive=True)
|
||||
with gr.Row(scale=0.5):
|
||||
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
|
||||
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
||||
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
|
||||
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False)
|
||||
interactive=True,
|
||||
visible=False)
|
||||
with gr.Row():
|
||||
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
|
||||
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
|
||||
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
|
||||
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
|
||||
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
|
||||
|
||||
|
||||
|
||||
|
||||
with gr.Column(scale=0.5):
|
||||
video_output = gr.Video().style(height=360)
|
||||
tracking_video_predict_button = gr.Button(value="Tracking")
|
||||
with gr.Column():
|
||||
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
|
||||
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
|
||||
video_output = gr.Video(autosize=True, visible=False).style(height=360)
|
||||
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
|
||||
|
||||
# first step: get the video information
|
||||
extract_frames_button.click(
|
||||
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
|
||||
inputs=[
|
||||
video_input, video_state
|
||||
],
|
||||
outputs=[video_state, image_selection_slider],
|
||||
outputs=[video_state, video_info, template_frame,
|
||||
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
|
||||
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
|
||||
)
|
||||
|
||||
# second step: select images from slider
|
||||
image_selection_slider.release(fn=select_template,
|
||||
inputs=[image_selection_slider, video_state],
|
||||
outputs=[template_frame, video_state], api_name="select_image")
|
||||
|
||||
inputs=[image_selection_slider, video_state, interactive_state],
|
||||
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
|
||||
track_pause_number_slider.release(fn=get_end_number,
|
||||
inputs=[track_pause_number_slider, interactive_state],
|
||||
outputs=[interactive_state], api_name="end_image")
|
||||
|
||||
# click select image to get mask using sam
|
||||
template_frame.select(
|
||||
fn=sam_refine,
|
||||
inputs=[video_state, point_prompt, click_state, interactive_state],
|
||||
outputs=[template_frame, video_state, interactive_state]
|
||||
)
|
||||
|
||||
# add different mask
|
||||
Add_mask_button.click(
|
||||
fn=add_multi_mask,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
|
||||
)
|
||||
|
||||
remove_mask_button.click(
|
||||
fn=remove_multi_mask,
|
||||
inputs=[interactive_state],
|
||||
outputs=[interactive_state]
|
||||
)
|
||||
|
||||
# tracking video from select image and mask
|
||||
tracking_video_predict_button.click(
|
||||
fn=vos_tracking_video,
|
||||
inputs=[video_state, interactive_state],
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[video_output, video_state, interactive_state]
|
||||
)
|
||||
|
||||
# click to get mask
|
||||
mask_dropdown.change(
|
||||
fn=show_mask,
|
||||
inputs=[video_state, interactive_state, mask_dropdown],
|
||||
outputs=[template_frame]
|
||||
)
|
||||
|
||||
# clear input
|
||||
video_input.clear(
|
||||
@@ -306,55 +413,41 @@ with gr.Blocks() as iface:
|
||||
"inference_times": 0,
|
||||
"negative_click_times" : 0,
|
||||
"positive_click_times": 0,
|
||||
"mask_save": args.mask_save
|
||||
"mask_save": args.mask_save,
|
||||
"multi_mask": {
|
||||
"mask_names": [],
|
||||
"masks": []
|
||||
},
|
||||
[[],[]]
|
||||
),
|
||||
"track_end_number": 0
|
||||
},
|
||||
[[],[]],
|
||||
None,
|
||||
None,
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
|
||||
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
|
||||
|
||||
),
|
||||
[],
|
||||
[
|
||||
video_state,
|
||||
interactive_state,
|
||||
click_state,
|
||||
video_output,
|
||||
template_frame,
|
||||
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
|
||||
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
|
||||
],
|
||||
queue=False,
|
||||
show_progress=False
|
||||
)
|
||||
clear_button_image.click(
|
||||
lambda: (
|
||||
{
|
||||
"origin_images": None,
|
||||
"painted_images": None,
|
||||
"masks": None,
|
||||
"logits": None,
|
||||
"select_frame_number": 0,
|
||||
"fps": 30
|
||||
},
|
||||
{
|
||||
"inference_times": 0,
|
||||
"negative_click_times" : 0,
|
||||
"positive_click_times": 0,
|
||||
"mask_save": args.mask_save
|
||||
},
|
||||
[[],[]]
|
||||
),
|
||||
[],
|
||||
[
|
||||
video_state,
|
||||
interactive_state,
|
||||
click_state,
|
||||
],
|
||||
show_progress=False)
|
||||
|
||||
queue=False,
|
||||
show_progress=False
|
||||
# points clear
|
||||
clear_button_click.click(
|
||||
fn = clear_click,
|
||||
inputs = [video_state, click_state,],
|
||||
outputs = [template_frame,click_state],
|
||||
|
||||
)
|
||||
clear_button_clike.click(
|
||||
lambda: ([[],[]]),
|
||||
[],
|
||||
[click_state],
|
||||
queue=False,
|
||||
show_progress=False
|
||||
)
|
||||
iface.queue(concurrency_count=1)
|
||||
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")
|
||||
|
||||
|
||||
61
app_test.py
61
app_test.py
@@ -1,23 +1,46 @@
|
||||
# import gradio as gr
|
||||
|
||||
# def update_iframe(slider_value):
|
||||
# return f'''
|
||||
# <script>
|
||||
# window.addEventListener('message', function(event) {{
|
||||
# if (event.data.sliderValue !== undefined) {{
|
||||
# var iframe = document.getElementById("text_iframe");
|
||||
# iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
|
||||
# }}
|
||||
# }}, false);
|
||||
# </script>
|
||||
# <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
|
||||
# '''
|
||||
|
||||
# iface = gr.Interface(
|
||||
# fn=update_iframe,
|
||||
# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
|
||||
# outputs=gr.outputs.HTML(),
|
||||
# allow_flagging=False,
|
||||
# )
|
||||
|
||||
# iface.launch(server_name='0.0.0.0', server_port=12212)
|
||||
|
||||
import gradio as gr
|
||||
|
||||
def update_iframe(slider_value):
|
||||
return f'''
|
||||
<script>
|
||||
window.addEventListener('message', function(event) {{
|
||||
if (event.data.sliderValue !== undefined) {{
|
||||
var iframe = document.getElementById("text_iframe");
|
||||
iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
|
||||
}}
|
||||
}}, false);
|
||||
</script>
|
||||
<iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
|
||||
'''
|
||||
|
||||
iface = gr.Interface(
|
||||
fn=update_iframe,
|
||||
inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
|
||||
outputs=gr.outputs.HTML(),
|
||||
allow_flagging=False,
|
||||
)
|
||||
def change_mask(drop):
|
||||
return gr.update(choices=["hello", "kitty"])
|
||||
|
||||
iface.launch(server_name='0.0.0.0', server_port=12212)
|
||||
with gr.Blocks() as iface:
|
||||
drop = gr.Dropdown(
|
||||
choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
|
||||
)
|
||||
radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
|
||||
multi_drop = gr.Dropdown(
|
||||
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
|
||||
)
|
||||
|
||||
multi_drop.change(
|
||||
fn=change_mask,
|
||||
inputs = multi_drop,
|
||||
outputs=multi_drop
|
||||
)
|
||||
|
||||
iface.launch(server_name='0.0.0.0', server_port=1223)
|
||||
0
test_beta.txt
Normal file
0
test_beta.txt
Normal file
@@ -37,16 +37,16 @@ class SamControler():
|
||||
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
|
||||
|
||||
|
||||
def seg_again(self, image: np.ndarray):
|
||||
'''
|
||||
it is used when interact in video
|
||||
'''
|
||||
self.sam_controler.reset_image()
|
||||
self.sam_controler.set_image(image)
|
||||
return
|
||||
# def seg_again(self, image: np.ndarray):
|
||||
# '''
|
||||
# it is used when interact in video
|
||||
# '''
|
||||
# self.sam_controler.reset_image()
|
||||
# self.sam_controler.set_image(image)
|
||||
# return
|
||||
|
||||
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
|
||||
'''
|
||||
it is used in first frame in video
|
||||
return: mask, logit, painted image(mask+point)
|
||||
@@ -88,47 +88,47 @@ class SamControler():
|
||||
|
||||
return mask, logit, painted_image
|
||||
|
||||
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
origal_image = self.sam_controler.orignal_image
|
||||
if same:
|
||||
'''
|
||||
true; loop in the same image
|
||||
'''
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
'mask_input': logits[None, :, :]
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
# def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
# origal_image = self.sam_controler.orignal_image
|
||||
# if same:
|
||||
# '''
|
||||
# true; loop in the same image
|
||||
# '''
|
||||
# prompts = {
|
||||
# 'point_coords': points,
|
||||
# 'point_labels': labels,
|
||||
# 'mask_input': logits[None, :, :]
|
||||
# }
|
||||
# masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
|
||||
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
|
||||
return mask, logit, painted_image
|
||||
else:
|
||||
'''
|
||||
loop in the different image, interact in the video
|
||||
'''
|
||||
if image is None:
|
||||
raise('Image error')
|
||||
else:
|
||||
self.seg_again(image)
|
||||
prompts = {
|
||||
'point_coords': points,
|
||||
'point_labels': labels,
|
||||
}
|
||||
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
# return mask, logit, painted_image
|
||||
# else:
|
||||
# '''
|
||||
# loop in the different image, interact in the video
|
||||
# '''
|
||||
# if image is None:
|
||||
# raise('Image error')
|
||||
# else:
|
||||
# self.seg_again(image)
|
||||
# prompts = {
|
||||
# 'point_coords': points,
|
||||
# 'point_labels': labels,
|
||||
# }
|
||||
# masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
|
||||
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
|
||||
|
||||
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
|
||||
return mask, logit, painted_image
|
||||
# return mask, logit, painted_image
|
||||
|
||||
|
||||
|
||||
@@ -226,31 +226,31 @@ class SamControler():
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
points = np.array([[500, 375], [1125, 625]])
|
||||
labels = np.array([1, 1])
|
||||
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
# if __name__ == "__main__":
|
||||
# points = np.array([[500, 375], [1125, 625]])
|
||||
# labels = np.array([1, 1])
|
||||
# image = cv2.imread('/hhd3/gaoshang/truck.jpg')
|
||||
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
sam_controler = initialize()
|
||||
mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
|
||||
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
|
||||
painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
|
||||
# sam_controler = initialize()
|
||||
# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
|
||||
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
|
||||
# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
|
||||
# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
|
||||
|
||||
mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
|
||||
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
|
||||
painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
|
||||
# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
|
||||
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
|
||||
# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
|
||||
|
||||
mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
|
||||
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
|
||||
painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
|
||||
# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
|
||||
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
|
||||
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
|
||||
# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
|
||||
# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -15,26 +15,26 @@ class TrackingAnything():
|
||||
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
|
||||
|
||||
|
||||
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
if first_flag:
|
||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||
return mask, logit, painted_image
|
||||
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
|
||||
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
# if first_flag:
|
||||
# mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||
# return mask, logit, painted_image
|
||||
|
||||
if interact_flag:
|
||||
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||
return mask, logit, painted_image
|
||||
# if interact_flag:
|
||||
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||
# return mask, logit, painted_image
|
||||
|
||||
mask, logit, painted_image = self.xmem.track(image, logit)
|
||||
return mask, logit, painted_image
|
||||
# mask, logit, painted_image = self.xmem.track(image, logit)
|
||||
# return mask, logit, painted_image
|
||||
|
||||
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
|
||||
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
|
||||
return mask, logit, painted_image
|
||||
|
||||
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||
return mask, logit, painted_image
|
||||
# def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
|
||||
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
|
||||
# return mask, logit, painted_image
|
||||
|
||||
def generator(self, images: list, template_mask:np.ndarray):
|
||||
|
||||
@@ -53,6 +53,7 @@ class TrackingAnything():
|
||||
masks.append(mask)
|
||||
logits.append(logit)
|
||||
painted_images.append(painted_image)
|
||||
print("tracking image {}".format(i))
|
||||
return masks, logits, painted_images
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ class BaseTracker:
|
||||
logit: numpy arrays, probability map (H, W)
|
||||
painted_image: numpy array (H, W, 3)
|
||||
"""
|
||||
|
||||
if first_frame_annotation is not None: # first frame mask
|
||||
# initialisation
|
||||
mask, labels = self.mapper.convert_mask(first_frame_annotation)
|
||||
@@ -87,12 +88,20 @@ class BaseTracker:
|
||||
out_mask = torch.argmax(probs, dim=0)
|
||||
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
|
||||
|
||||
num_objs = out_mask.max()
|
||||
final_mask = np.zeros_like(out_mask)
|
||||
|
||||
# map back
|
||||
for k, v in self.mapper.remappings.items():
|
||||
final_mask[out_mask == v] = k
|
||||
|
||||
num_objs = final_mask.max()
|
||||
painted_image = frame
|
||||
for obj in range(1, num_objs+1):
|
||||
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||
if np.max(final_mask==obj) == 0:
|
||||
continue
|
||||
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
|
||||
|
||||
return out_mask, out_mask, painted_image
|
||||
return final_mask, final_mask, painted_image
|
||||
|
||||
@torch.no_grad()
|
||||
def sam_refinement(self, frame, logits, ti):
|
||||
@@ -142,34 +151,38 @@ if __name__ == '__main__':
|
||||
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
|
||||
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
|
||||
|
||||
# test for storage efficiency
|
||||
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
||||
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
||||
# # test for storage efficiency
|
||||
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
|
||||
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
|
||||
|
||||
first_frame_annotation[first_frame_annotation==1] = 15
|
||||
first_frame_annotation[first_frame_annotation==2] = 20
|
||||
|
||||
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
|
||||
if not os.path.exists(save_path):
|
||||
os.mkdir(save_path)
|
||||
|
||||
for ti, frame in enumerate(frames):
|
||||
print(ti)
|
||||
if ti > 200:
|
||||
break
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||
painted_image.save(f'{save_path}/{ti:05d}.png')
|
||||
|
||||
tracker.clear_memory()
|
||||
for ti, frame in enumerate(frames):
|
||||
print(ti)
|
||||
# if ti > 200:
|
||||
# break
|
||||
if ti == 0:
|
||||
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
else:
|
||||
mask, prob, painted_image = tracker.track(frame)
|
||||
# save
|
||||
painted_image = Image.fromarray(painted_image)
|
||||
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||
# tracker.clear_memory()
|
||||
# for ti, frame in enumerate(frames):
|
||||
# print(ti)
|
||||
# # if ti > 200:
|
||||
# # break
|
||||
# if ti == 0:
|
||||
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
|
||||
# else:
|
||||
# mask, prob, painted_image = tracker.track(frame)
|
||||
# # save
|
||||
# painted_image = Image.fromarray(painted_image)
|
||||
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
|
||||
|
||||
# # track anything given in the first frame annotation
|
||||
# for ti, frame in enumerate(frames):
|
||||
|
||||
Reference in New Issue
Block a user