This commit is contained in:
gaomingqi
2023-04-20 20:51:30 +08:00
8 changed files with 330 additions and 200 deletions

View File

@@ -5,7 +5,7 @@
***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is built upon [Segment Anything](https://github.com/facebookresearch/segment-anything) and lets users specify anything to track and segment via clicks only. During tracking, users can flexibly change the objects they want to track or correct the region of interest if there are any ambiguities. These characteristics make ***Track-Anything*** suitable for: ***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is built upon [Segment Anything](https://github.com/facebookresearch/segment-anything) and lets users specify anything to track and segment via clicks only. During tracking, users can flexibly change the objects they want to track or correct the region of interest if there are any ambiguities. These characteristics make ***Track-Anything*** suitable for:
- Video object tracking and segmentation with shot changes. - Video object tracking and segmentation with shot changes.
- Data annotation for video object tracking and segmentation. - Visualized development and data annotation for video object tracking and segmentation.
- Object-centric downstream video tasks, such as video inpainting and editing. - Object-centric downstream video tasks, such as video inpainting and editing.
## Demo ## Demo

233
app.py
View File

@@ -17,7 +17,7 @@ import torchvision
import torch import torch
import concurrent.futures import concurrent.futures
import queue import queue
from tools.painter import mask_painter, point_painter
# download checkpoints # download checkpoints
def download_checkpoint(url, folder, filename): def download_checkpoint(url, folder, filename):
os.makedirs(folder, exist_ok=True) os.makedirs(folder, exist_ok=True)
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
"masks": [None]*len(frames), "masks": [None]*len(frames),
"logits": [None]*len(frames), "logits": [None]*len(frames),
"select_frame_number": 0, "select_frame_number": 0,
"fps": 30 "fps": fps
} }
return video_state, gr.update(visible=True, maximum=len(frames), value=1) video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True)
# get the select frame from gradio slider # get the select frame from gradio slider
def select_template(image_selection_slider, video_state): def select_template(image_selection_slider, video_state, interactive_state):
# images = video_state[1] # images = video_state[1]
image_selection_slider -= 1 image_selection_slider -= 1
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
# # clear multi mask
# interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
return video_state["painted_images"][image_selection_slider], video_state return video_state["painted_images"][image_selection_slider], video_state, interactive_state
def get_end_number(track_pause_number_slider, interactive_state):
    """Store the "Track end frames" slider value in the interactive state.

    Args:
        track_pause_number_slider: Current value of the end-frame slider.
        interactive_state: Per-session state dict; updated in place.

    Returns:
        The same ``interactive_state`` object, with ``"track_end_number"``
        set to the slider value.
    """
    interactive_state.update(track_end_number=track_pause_number_slider)
    return interactive_state
# use sam to get the mask # use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData): def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
@@ -133,14 +148,62 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
return painted_image, video_state, interactive_state return painted_image, video_state, interactive_state
def add_multi_mask(video_state, interactive_state, mask_dropdown):
    """Save the mask of the selected frame as a new named mask.

    The mask is appended to ``interactive_state["multi_mask"]`` under the
    name ``mask_NNN`` (1-based, zero-padded to three digits) and that name is
    added to the current dropdown selection.

    Args:
        video_state: Per-session video dict; reads ``"masks"`` and
            ``"select_frame_number"``.
        interactive_state: Per-session state dict; ``"multi_mask"`` is
            updated in place.
        mask_dropdown: List of currently selected mask names; extended in
            place.

    Returns:
        A 4-tuple ``(interactive_state, dropdown_update, select_frame,
        click_state)`` where ``click_state`` is reset to ``[[], []]``.
    """
    multi_mask = interactive_state["multi_mask"]
    mask = video_state["masks"][video_state["select_frame_number"]]
    # Frames start with a None mask until the user has clicked on them;
    # storing None would make show_mask fail on ``mask.astype``.
    if mask is not None:
        multi_mask["masks"].append(mask)
        mask_name = "mask_{:03d}".format(len(multi_mask["masks"]))
        multi_mask["mask_names"].append(mask_name)
        mask_dropdown.append(mask_name)
    select_frame = show_mask(video_state, interactive_state, mask_dropdown)
    dropdown_update = gr.update(choices=multi_mask["mask_names"], value=mask_dropdown)
    return interactive_state, dropdown_update, select_frame, [[], []]
def clear_click(video_state, click_state):
    """Drop all recorded clicks and return the unpainted selected frame.

    Args:
        video_state: Per-session video dict; reads ``"origin_images"`` and
            ``"select_frame_number"``.
        click_state: Previous click record; ignored, since the record is
            always reset.

    Returns:
        ``(template_frame, click_state)`` where ``template_frame`` is the
        original image of the selected frame and ``click_state`` is
        ``[[], []]``.
    """
    frame_index = video_state["select_frame_number"]
    return video_state["origin_images"][frame_index], [[], []]
def remove_multi_mask(interactive_state):
    """Forget every stored mask and mask name.

    Args:
        interactive_state: Per-session state dict; its ``"multi_mask"`` entry
            gets fresh empty ``"mask_names"`` and ``"masks"`` lists.

    Returns:
        The same ``interactive_state`` object.
    """
    stored = interactive_state["multi_mask"]
    stored["mask_names"] = []
    stored["masks"] = []
    return interactive_state
def show_mask(video_state, interactive_state, mask_dropdown):
    """Paint the masks selected in the dropdown onto the selected frame.

    Args:
        video_state: Per-session video dict; reads ``"origin_images"`` and
            ``"select_frame_number"``.
        interactive_state: Per-session state dict; reads
            ``interactive_state["multi_mask"]["masks"]``.
        mask_dropdown: List of selected mask names of the form ``mask_NNN``
            (1-based). Sorted in place, as before.

    Returns:
        The original selected frame with every selected mask painted on it,
        or the unpainted frame when nothing valid is selected.
    """
    mask_dropdown.sort()
    select_frame = video_state["origin_images"][video_state["select_frame_number"]]
    stored_masks = interactive_state["multi_mask"]["masks"]
    for mask_name in mask_dropdown:
        mask_number = int(mask_name.split("_")[1]) - 1
        # The dropdown can still list names whose masks were removed;
        # skip them instead of raising IndexError (or wrapping around on -1).
        if not 0 <= mask_number < len(stored_masks):
            continue
        mask = stored_masks[mask_number]
        if mask is None:
            continue
        select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
    return select_frame
# tracking vos # tracking vos
def vos_tracking_video(video_state, interactive_state): def vos_tracking_video(video_state, interactive_state, mask_dropdown):
model.xmem.clear_memory() model.xmem.clear_memory()
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
else:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:] following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
if interactive_state["multi_mask"]["masks"]:
if len(mask_dropdown) == 0:
mask_dropdown = ["mask_001"]
mask_dropdown.sort()
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
for i in range(1,len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
video_state["masks"][video_state["select_frame_number"]]= template_mask
else:
template_mask = video_state["masks"][video_state["select_frame_number"]] template_mask = video_state["masks"][video_state["select_frame_number"]]
fps = video_state["fps"] fps = video_state["fps"]
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask) masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
if interactive_state["track_end_number"]:
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
else:
video_state["masks"][video_state["select_frame_number"]:] = masks video_state["masks"][video_state["select_frame_number"]:] = masks
video_state["logits"][video_state["select_frame_number"]:] = logits video_state["logits"][video_state["select_frame_number"]:] = logits
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
output_path (str): The path to save the generated video. output_path (str): The path to save the generated video.
fps (int, optional): The frame rate of the output video. Defaults to 30. fps (int, optional): The frame rate of the output video. Defaults to 30.
""" """
# height, width, layers = frames[0].shape
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
# print(output_path)
# for frame in frames:
# video.write(frame)
# video.release()
frames = torch.from_numpy(np.asarray(frames)) frames = torch.from_numpy(np.asarray(frames))
if not os.path.exists(os.path.dirname(output_path)): if not os.path.exists(os.path.dirname(output_path)):
os.makedirs(os.path.dirname(output_path)) os.makedirs(os.path.dirname(output_path))
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py # args, defined in track_anything.py
args = parse_augment() args = parse_augment()
# args.port = 12212 # args.port = 12315
# args.device = "cuda:4" # args.device = "cuda:1"
# args.mask_save = True # args.mask_save = True
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args) model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
"inference_times": 0, "inference_times": 0,
"negative_click_times" : 0, "negative_click_times" : 0,
"positive_click_times": 0, "positive_click_times": 0,
"mask_save": args.mask_save "mask_save": args.mask_save,
}) "multi_mask": {
"mask_names": [],
"masks": []
},
"track_end_number": None
}
)
video_state = gr.State( video_state = gr.State(
{ {
"video_name": "", "video_name": "",
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
with gr.Row(): with gr.Row():
# for user video input # for user video input
with gr.Column(scale=1.0): with gr.Column():
video_input = gr.Video().style(height=360) with gr.Row(scale=0.4):
video_input = gr.Video(autosize=True)
video_info = gr.Textbox()
with gr.Row(scale=1): with gr.Row():
# put the template frame under the radio button # put the template frame under the radio button
with gr.Column(scale=0.5): with gr.Column():
# extract frames # extract frames
with gr.Column(): with gr.Column():
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary") extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
# click point settings: negative or positive, mode continuous or single # click point settings: negative or positive, mode continuous or single
with gr.Row(): with gr.Row():
with gr.Row(scale=0.5): with gr.Row():
point_prompt = gr.Radio( point_prompt = gr.Radio(
choices=["Positive", "Negative"], choices=["Positive", "Negative"],
value="Positive", value="Positive",
label="Point Prompt", label="Point Prompt",
interactive=True) interactive=True,
visible=False)
click_mode = gr.Radio( click_mode = gr.Radio(
choices=["Continuous", "Single"], choices=["Continuous", "Single"],
value="Continuous", value="Continuous",
label="Clicking Mode", label="Clicking Mode",
interactive=True) interactive=True,
with gr.Row(scale=0.5): visible=False)
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160) with gr.Row():
clear_button_image = gr.Button(value="Clear Image", interactive=True) clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360) Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False) template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
with gr.Column():
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
with gr.Column(scale=0.5): video_output = gr.Video(autosize=True, visible=False).style(height=360)
video_output = gr.Video().style(height=360) tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
tracking_video_predict_button = gr.Button(value="Tracking")
# first step: get the video information # first step: get the video information
extract_frames_button.click( extract_frames_button.click(
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
inputs=[ inputs=[
video_input, video_state video_input, video_state
], ],
outputs=[video_state, image_selection_slider], outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
) )
# second step: select images from slider # second step: select images from slider
image_selection_slider.release(fn=select_template, image_selection_slider.release(fn=select_template,
inputs=[image_selection_slider, video_state], inputs=[image_selection_slider, video_state, interactive_state],
outputs=[template_frame, video_state], api_name="select_image") outputs=[template_frame, video_state, interactive_state], api_name="select_image")
track_pause_number_slider.release(fn=get_end_number,
inputs=[track_pause_number_slider, interactive_state],
outputs=[interactive_state], api_name="end_image")
# click select image to get mask using sam
template_frame.select( template_frame.select(
fn=sam_refine, fn=sam_refine,
inputs=[video_state, point_prompt, click_state, interactive_state], inputs=[video_state, point_prompt, click_state, interactive_state],
outputs=[template_frame, video_state, interactive_state] outputs=[template_frame, video_state, interactive_state]
) )
# add different mask
Add_mask_button.click(
fn=add_multi_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
)
remove_mask_button.click(
fn=remove_multi_mask,
inputs=[interactive_state],
outputs=[interactive_state]
)
# tracking video from select image and mask
tracking_video_predict_button.click( tracking_video_predict_button.click(
fn=vos_tracking_video, fn=vos_tracking_video,
inputs=[video_state, interactive_state], inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output, video_state, interactive_state] outputs=[video_output, video_state, interactive_state]
) )
# click to get mask
mask_dropdown.change(
fn=show_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[template_frame]
)
# clear input # clear input
video_input.clear( video_input.clear(
@@ -306,55 +413,41 @@ with gr.Blocks() as iface:
"inference_times": 0, "inference_times": 0,
"negative_click_times" : 0, "negative_click_times" : 0,
"positive_click_times": 0, "positive_click_times": 0,
"mask_save": args.mask_save "mask_save": args.mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
}, },
[[],[]] "track_end_number": 0
},
[[],[]],
None,
None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
), ),
[], [],
[ [
video_state, video_state,
interactive_state, interactive_state,
click_state, click_state,
video_output,
template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
], ],
queue=False, queue=False,
show_progress=False show_progress=False)
)
clear_button_image.click(
lambda: (
{
"origin_images": None,
"painted_images": None,
"masks": None,
"logits": None,
"select_frame_number": 0,
"fps": 30
},
{
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": args.mask_save
},
[[],[]]
),
[],
[
video_state,
interactive_state,
click_state,
],
queue=False, # points clear
show_progress=False clear_button_click.click(
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state],
) )
clear_button_clike.click(
lambda: ([[],[]]),
[],
[click_state],
queue=False,
show_progress=False
)
iface.queue(concurrency_count=1) iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")

View File

@@ -1,23 +1,46 @@
# import gradio as gr
# def update_iframe(slider_value):
# return f'''
# <script>
# window.addEventListener('message', function(event) {{
# if (event.data.sliderValue !== undefined) {{
# var iframe = document.getElementById("text_iframe");
# iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
# }}
# }}, false);
# </script>
# <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
# '''
# iface = gr.Interface(
# fn=update_iframe,
# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
# outputs=gr.outputs.HTML(),
# allow_flagging=False,
# )
# iface.launch(server_name='0.0.0.0', server_port=12212)
import gradio as gr import gradio as gr
def update_iframe(slider_value):
return f'''
<script>
window.addEventListener('message', function(event) {{
if (event.data.sliderValue !== undefined) {{
var iframe = document.getElementById("text_iframe");
iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
}}
}}, false);
</script>
<iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
'''
iface = gr.Interface( def change_mask(drop):
fn=update_iframe, return gr.update(choices=["hello", "kitty"])
inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
outputs=gr.outputs.HTML(), with gr.Blocks() as iface:
allow_flagging=False, drop = gr.Dropdown(
choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
)
radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
multi_drop = gr.Dropdown(
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
) )
iface.launch(server_name='0.0.0.0', server_port=12212) multi_drop.change(
fn=change_mask,
inputs = multi_drop,
outputs=multi_drop
)
iface.launch(server_name='0.0.0.0', server_port=1223)

0
test.txt Normal file
View File

0
test_beta.txt Normal file
View File

View File

@@ -37,16 +37,16 @@ class SamControler():
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device) self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
def seg_again(self, image: np.ndarray): # def seg_again(self, image: np.ndarray):
''' # '''
it is used when interact in video # it is used when interact in video
''' # '''
self.sam_controler.reset_image() # self.sam_controler.reset_image()
self.sam_controler.set_image(image) # self.sam_controler.set_image(image)
return # return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
''' '''
it is used in first frame in video it is used in first frame in video
return: mask, logit, painted image(mask+point) return: mask, logit, painted image(mask+point)
@@ -88,47 +88,47 @@ class SamControler():
return mask, logit, painted_image return mask, logit, painted_image
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): # def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
origal_image = self.sam_controler.orignal_image # origal_image = self.sam_controler.orignal_image
if same: # if same:
''' # '''
true; loop in the same image # true; loop in the same image
''' # '''
prompts = { # prompts = {
'point_coords': points, # 'point_coords': points,
'point_labels': labels, # 'point_labels': labels,
'mask_input': logits[None, :, :] # 'mask_input': logits[None, :, :]
} # }
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask) # masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image) # painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image # return mask, logit, painted_image
else: # else:
''' # '''
loop in the different image, interact in the video # loop in the different image, interact in the video
''' # '''
if image is None: # if image is None:
raise('Image error') # raise('Image error')
else: # else:
self.seg_again(image) # self.seg_again(image)
prompts = { # prompts = {
'point_coords': points, # 'point_coords': points,
'point_labels': labels, # 'point_labels': labels,
} # }
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask) # masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :] # mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width) # painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width) # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width) # painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image) # painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image # return mask, logit, painted_image
@@ -226,31 +226,31 @@ class SamControler():
if __name__ == "__main__": # if __name__ == "__main__":
points = np.array([[500, 375], [1125, 625]]) # points = np.array([[500, 375], [1125, 625]])
labels = np.array([1, 1]) # labels = np.array([1, 1])
image = cv2.imread('/hhd3/gaoshang/truck.jpg') # image = cv2.imread('/hhd3/gaoshang/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
sam_controler = initialize() # sam_controler = initialize()
mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True) # mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image) # cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image) # cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg') # painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True) # mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image) # cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg') # painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True) # mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8) # painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3) # painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image) # cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg') # painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')

View File

@@ -15,26 +15,26 @@ class TrackingAnything():
self.xmem = BaseTracker(xmem_checkpoint, device=args.device) self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray, # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
if first_flag: # if first_flag:
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) # mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image # return mask, logit, painted_image
if interact_flag: # if interact_flag:
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image # return mask, logit, painted_image
mask, logit, painted_image = self.xmem.track(image, logit) # mask, logit, painted_image = self.xmem.track(image, logit)
return mask, logit, painted_image # return mask, logit, painted_image
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True): def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask) mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image return mask, logit, painted_image
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): # def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask) # mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image # return mask, logit, painted_image
def generator(self, images: list, template_mask:np.ndarray): def generator(self, images: list, template_mask:np.ndarray):
@@ -53,6 +53,7 @@ class TrackingAnything():
masks.append(mask) masks.append(mask)
logits.append(logit) logits.append(logit)
painted_images.append(painted_image) painted_images.append(painted_image)
print("tracking image {}".format(i))
return masks, logits, painted_images return masks, logits, painted_images

View File

@@ -67,6 +67,7 @@ class BaseTracker:
logit: numpy arrays, probability map (H, W) logit: numpy arrays, probability map (H, W)
painted_image: numpy array (H, W, 3) painted_image: numpy array (H, W, 3)
""" """
if first_frame_annotation is not None: # first frame mask if first_frame_annotation is not None: # first frame mask
# initialisation # initialisation
mask, labels = self.mapper.convert_mask(first_frame_annotation) mask, labels = self.mapper.convert_mask(first_frame_annotation)
@@ -87,12 +88,20 @@ class BaseTracker:
out_mask = torch.argmax(probs, dim=0) out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8) out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
num_objs = out_mask.max() final_mask = np.zeros_like(out_mask)
# map back
for k, v in self.mapper.remappings.items():
final_mask[out_mask == v] = k
num_objs = final_mask.max()
painted_image = frame painted_image = frame
for obj in range(1, num_objs+1): for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1) if np.max(final_mask==obj) == 0:
continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, out_mask, painted_image return final_mask, final_mask, painted_image
@torch.no_grad() @torch.no_grad()
def sam_refinement(self, frame, logits, ti): def sam_refinement(self, frame, logits, ti):
@@ -142,34 +151,38 @@ if __name__ == '__main__':
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device) # sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
tracker = BaseTracker(XMEM_checkpoint, device, None, device) tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# test for storage efficiency # # test for storage efficiency
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy') # frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png')) # first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
first_frame_annotation[first_frame_annotation==1] = 15
first_frame_annotation[first_frame_annotation==2] = 20
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
if not os.path.exists(save_path):
os.mkdir(save_path)
for ti, frame in enumerate(frames): for ti, frame in enumerate(frames):
print(ti)
if ti > 200:
break
if ti == 0: if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation) mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else: else:
mask, prob, painted_image = tracker.track(frame) mask, prob, painted_image = tracker.track(frame)
# save # save
painted_image = Image.fromarray(painted_image) painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png') painted_image.save(f'{save_path}/{ti:05d}.png')
tracker.clear_memory() # tracker.clear_memory()
for ti, frame in enumerate(frames): # for ti, frame in enumerate(frames):
print(ti) # print(ti)
# if ti > 200: # # if ti > 200:
# break # # break
if ti == 0: # if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation) # mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else: # else:
mask, prob, painted_image = tracker.track(frame) # mask, prob, painted_image = tracker.track(frame)
# save # # save
painted_image = Image.fromarray(painted_image) # painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png') # painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# # track anything given in the first frame annotation # # track anything given in the first frame annotation
# for ti, frame in enumerate(frames): # for ti, frame in enumerate(frames):