This commit is contained in:
gaomingqi
2023-04-20 20:51:30 +08:00
8 changed files with 330 additions and 200 deletions

View File

@@ -5,7 +5,7 @@
***Track-Anything*** is a flexible and interactive tool for video object tracking and segmentation. It is developed upon [Segment Anything](https://github.com/facebookresearch/segment-anything), can specify anything to track and segment via user clicks only. During tracking, users can flexibly change the objects they wanna track or correct the region of interest if there are any ambiguities. These characteristics enable ***Track-Anything*** to be suitable for:
- Video object tracking and segmentation with shot changes.
- Data annnotation for video object tracking and segmentation.
- Visualized development and data annnotation for video object tracking and segmentation.
- Object-centric downstream video tasks, such as video inpainting and editing.
## Demo

233
app.py
View File

@@ -17,7 +17,7 @@ import torchvision
import torch
import concurrent.futures
import queue
from tools.painter import mask_painter, point_painter
# download checkpoints
def download_checkpoint(url, folder, filename):
os.makedirs(folder, exist_ok=True)
@@ -84,12 +84,21 @@ def get_frames_from_video(video_input, video_state):
"masks": [None]*len(frames),
"logits": [None]*len(frames),
"select_frame_number": 0,
"fps": 30
"fps": fps
}
return video_state, gr.update(visible=True, maximum=len(frames), value=1)
video_info = "Video Name: {}, FPS: {}, Total Frames: {}".format(video_state["video_name"], video_state["fps"], len(frames))
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][0])
return video_state, video_info, video_state["origin_images"][0], gr.update(visible=True, maximum=len(frames), value=1), gr.update(visible=True, maximum=len(frames), value=len(frames)), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True)
# get the select frame from gradio slider
def select_template(image_selection_slider, video_state):
def select_template(image_selection_slider, video_state, interactive_state):
# images = video_state[1]
image_selection_slider -= 1
@@ -100,8 +109,14 @@ def select_template(image_selection_slider, video_state):
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
# # clear multi mask
# interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
return video_state["painted_images"][image_selection_slider], video_state
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
def get_end_number(track_pause_number_slider, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider
return interactive_state
# use sam to get the mask
def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr.SelectData):
@@ -133,14 +148,62 @@ def sam_refine(video_state, point_prompt, click_state, interactive_state, evt:gr
return painted_image, video_state, interactive_state
def add_multi_mask(video_state, interactive_state, mask_dropdown):
mask = video_state["masks"][video_state["select_frame_number"]]
interactive_state["multi_mask"]["masks"].append(mask)
interactive_state["multi_mask"]["mask_names"].append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
mask_dropdown.append("mask_{:03d}".format(len(interactive_state["multi_mask"]["masks"])))
select_frame = show_mask(video_state, interactive_state, mask_dropdown)
return interactive_state, gr.update(choices=interactive_state["multi_mask"]["mask_names"], value=mask_dropdown), select_frame, [[],[]]
def clear_click(video_state, click_state):
click_state = [[],[]]
template_frame = video_state["origin_images"][video_state["select_frame_number"]]
return template_frame, click_state
def remove_multi_mask(interactive_state):
interactive_state["multi_mask"]["mask_names"]= []
interactive_state["multi_mask"]["masks"] = []
return interactive_state
def show_mask(video_state, interactive_state, mask_dropdown):
mask_dropdown.sort()
select_frame = video_state["origin_images"][video_state["select_frame_number"]]
for i in range(len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
mask = interactive_state["multi_mask"]["masks"][mask_number]
select_frame = mask_painter(select_frame, mask.astype('uint8'), mask_color=mask_number+2)
return select_frame
# tracking vos
def vos_tracking_video(video_state, interactive_state):
def vos_tracking_video(video_state, interactive_state, mask_dropdown):
model.xmem.clear_memory()
if interactive_state["track_end_number"]:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]]
else:
following_frames = video_state["origin_images"][video_state["select_frame_number"]:]
if interactive_state["multi_mask"]["masks"]:
if len(mask_dropdown) == 0:
mask_dropdown = ["mask_001"]
mask_dropdown.sort()
template_mask = interactive_state["multi_mask"]["masks"][int(mask_dropdown[0].split("_")[1]) - 1] * (int(mask_dropdown[0].split("_")[1]))
for i in range(1,len(mask_dropdown)):
mask_number = int(mask_dropdown[i].split("_")[1]) - 1
template_mask = np.clip(template_mask+interactive_state["multi_mask"]["masks"][mask_number]*(mask_number+1), 0, mask_number+1)
video_state["masks"][video_state["select_frame_number"]]= template_mask
else:
template_mask = video_state["masks"][video_state["select_frame_number"]]
fps = video_state["fps"]
masks, logits, painted_images = model.generator(images=following_frames, template_mask=template_mask)
if interactive_state["track_end_number"]:
video_state["masks"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = masks
video_state["logits"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = logits
video_state["painted_images"][video_state["select_frame_number"]:interactive_state["track_end_number"]] = painted_images
else:
video_state["masks"][video_state["select_frame_number"]:] = masks
video_state["logits"][video_state["select_frame_number"]:] = logits
video_state["painted_images"][video_state["select_frame_number"]:] = painted_images
@@ -176,6 +239,14 @@ def generate_video_from_frames(frames, output_path, fps=30):
output_path (str): The path to save the generated video.
fps (int, optional): The frame rate of the output video. Defaults to 30.
"""
# height, width, layers = frames[0].shape
# fourcc = cv2.VideoWriter_fourcc(*"mp4v")
# video = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
# print(output_path)
# for frame in frames:
# video.write(frame)
# video.release()
frames = torch.from_numpy(np.asarray(frames))
if not os.path.exists(os.path.dirname(output_path)):
os.makedirs(os.path.dirname(output_path))
@@ -193,8 +264,8 @@ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoi
# args, defined in track_anything.py
args = parse_augment()
# args.port = 12212
# args.device = "cuda:4"
# args.port = 12315
# args.device = "cuda:1"
# args.mask_save = True
model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
@@ -208,8 +279,15 @@ with gr.Blocks() as iface:
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": args.mask_save
})
"mask_save": args.mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
},
"track_end_number": None
}
)
video_state = gr.State(
{
"video_name": "",
@@ -225,43 +303,47 @@ with gr.Blocks() as iface:
with gr.Row():
# for user video input
with gr.Column(scale=1.0):
video_input = gr.Video().style(height=360)
with gr.Column():
with gr.Row(scale=0.4):
video_input = gr.Video(autosize=True)
video_info = gr.Textbox()
with gr.Row(scale=1):
with gr.Row():
# put the template frame under the radio button
with gr.Column(scale=0.5):
with gr.Column():
# extract frames
with gr.Column():
extract_frames_button = gr.Button(value="Get video info", interactive=True, variant="primary")
# click points settins, negative or positive, mode continuous or single
with gr.Row():
with gr.Row(scale=0.5):
with gr.Row():
point_prompt = gr.Radio(
choices=["Positive", "Negative"],
value="Positive",
label="Point Prompt",
interactive=True)
interactive=True,
visible=False)
click_mode = gr.Radio(
choices=["Continuous", "Single"],
value="Continuous",
label="Clicking Mode",
interactive=True)
with gr.Row(scale=0.5):
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True).style(height=160)
clear_button_image = gr.Button(value="Clear Image", interactive=True)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame").style(height=360)
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", invisible=False)
interactive=True,
visible=False)
with gr.Row():
clear_button_click = gr.Button(value="Clear Clicks", interactive=True, visible=False).style(height=160)
Add_mask_button = gr.Button(value="Add mask", interactive=True, visible=False)
template_frame = gr.Image(type="pil",interactive=True, elem_id="template_frame", visible=False).style(height=360)
image_selection_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Image Selection", visible=False)
track_pause_number_slider = gr.Slider(minimum=1, maximum=100, step=1, value=1, label="Track end frames", visible=False)
with gr.Column(scale=0.5):
video_output = gr.Video().style(height=360)
tracking_video_predict_button = gr.Button(value="Tracking")
with gr.Column():
mask_dropdown = gr.Dropdown(multiselect=True, value=[], label="Mask_select", info=".", visible=False)
remove_mask_button = gr.Button(value="Remove mask", interactive=True, visible=False)
video_output = gr.Video(autosize=True, visible=False).style(height=360)
tracking_video_predict_button = gr.Button(value="Tracking", visible=False)
# first step: get the video information
extract_frames_button.click(
@@ -269,27 +351,52 @@ with gr.Blocks() as iface:
inputs=[
video_input, video_state
],
outputs=[video_state, image_selection_slider],
outputs=[video_state, video_info, template_frame,
image_selection_slider, track_pause_number_slider,point_prompt, click_mode, clear_button_click, Add_mask_button, template_frame,
tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button]
)
# second step: select images from slider
image_selection_slider.release(fn=select_template,
inputs=[image_selection_slider, video_state],
outputs=[template_frame, video_state], api_name="select_image")
inputs=[image_selection_slider, video_state, interactive_state],
outputs=[template_frame, video_state, interactive_state], api_name="select_image")
track_pause_number_slider.release(fn=get_end_number,
inputs=[track_pause_number_slider, interactive_state],
outputs=[interactive_state], api_name="end_image")
# click select image to get mask using sam
template_frame.select(
fn=sam_refine,
inputs=[video_state, point_prompt, click_state, interactive_state],
outputs=[template_frame, video_state, interactive_state]
)
# add different mask
Add_mask_button.click(
fn=add_multi_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[interactive_state, mask_dropdown, template_frame, click_state]
)
remove_mask_button.click(
fn=remove_multi_mask,
inputs=[interactive_state],
outputs=[interactive_state]
)
# tracking video from select image and mask
tracking_video_predict_button.click(
fn=vos_tracking_video,
inputs=[video_state, interactive_state],
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[video_output, video_state, interactive_state]
)
# click to get mask
mask_dropdown.change(
fn=show_mask,
inputs=[video_state, interactive_state, mask_dropdown],
outputs=[template_frame]
)
# clear input
video_input.clear(
@@ -306,55 +413,41 @@ with gr.Blocks() as iface:
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": args.mask_save
"mask_save": args.mask_save,
"multi_mask": {
"mask_names": [],
"masks": []
},
[[],[]]
"track_end_number": 0
},
[[],[]],
None,
None,
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), \
gr.update(visible=False), gr.update(visible=False), gr.update(visible=False, value=[]), gr.update(visible=False) \
),
[],
[
video_state,
interactive_state,
click_state,
video_output,
template_frame,
tracking_video_predict_button, image_selection_slider , track_pause_number_slider,point_prompt, click_mode, clear_button_click,
Add_mask_button, template_frame, tracking_video_predict_button, video_output, mask_dropdown, remove_mask_button
],
queue=False,
show_progress=False
)
clear_button_image.click(
lambda: (
{
"origin_images": None,
"painted_images": None,
"masks": None,
"logits": None,
"select_frame_number": 0,
"fps": 30
},
{
"inference_times": 0,
"negative_click_times" : 0,
"positive_click_times": 0,
"mask_save": args.mask_save
},
[[],[]]
),
[],
[
video_state,
interactive_state,
click_state,
],
show_progress=False)
queue=False,
show_progress=False
# points clear
clear_button_click.click(
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state],
)
clear_button_clike.click(
lambda: ([[],[]]),
[],
[click_state],
queue=False,
show_progress=False
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")

View File

@@ -1,23 +1,46 @@
# import gradio as gr
# def update_iframe(slider_value):
# return f'''
# <script>
# window.addEventListener('message', function(event) {{
# if (event.data.sliderValue !== undefined) {{
# var iframe = document.getElementById("text_iframe");
# iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
# }}
# }}, false);
# </script>
# <iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
# '''
# iface = gr.Interface(
# fn=update_iframe,
# inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
# outputs=gr.outputs.HTML(),
# allow_flagging=False,
# )
# iface.launch(server_name='0.0.0.0', server_port=12212)
import gradio as gr
def update_iframe(slider_value):
return f'''
<script>
window.addEventListener('message', function(event) {{
if (event.data.sliderValue !== undefined) {{
var iframe = document.getElementById("text_iframe");
iframe.src = "http://localhost:5001/get_text?slider_value=" + event.data.sliderValue;
}}
}}, false);
</script>
<iframe id="text_iframe" src="http://localhost:5001/get_text?slider_value={slider_value}" style="width: 100%; height: 100%; border: none;"></iframe>
'''
iface = gr.Interface(
fn=update_iframe,
inputs=gr.inputs.Slider(minimum=0, maximum=100, step=1, default=50),
outputs=gr.outputs.HTML(),
allow_flagging=False,
def change_mask(drop):
return gr.update(choices=["hello", "kitty"])
with gr.Blocks() as iface:
drop = gr.Dropdown(
choices=["cat", "dog", "bird"], label="Animal", info="Will add more animals later!"
)
radio = gr.Radio(["park", "zoo", "road"], label="Location", info="Where did they go?")
multi_drop = gr.Dropdown(
["ran", "swam", "ate", "slept"], value=["swam", "slept"], multiselect=True, label="Activity", info="Lorem ipsum dolor sit amet, consectetur adipiscing elit. Sed auctor, nisl eget ultricies aliquam, nunc nisl aliquet nunc, eget aliquam nisl nunc vel nisl."
)
iface.launch(server_name='0.0.0.0', server_port=12212)
multi_drop.change(
fn=change_mask,
inputs = multi_drop,
outputs=multi_drop
)
iface.launch(server_name='0.0.0.0', server_port=1223)

0
test.txt Normal file
View File

0
test_beta.txt Normal file
View File

View File

@@ -37,16 +37,16 @@ class SamControler():
self.sam_controler = BaseSegmenter(SAM_checkpoint, model_type, device)
def seg_again(self, image: np.ndarray):
'''
it is used when interact in video
'''
self.sam_controler.reset_image()
self.sam_controler.set_image(image)
return
# def seg_again(self, image: np.ndarray):
# '''
# it is used when interact in video
# '''
# self.sam_controler.reset_image()
# self.sam_controler.set_image(image)
# return
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True,mask_color=3):
'''
it is used in first frame in video
return: mask, logit, painted image(mask+point)
@@ -88,47 +88,47 @@ class SamControler():
return mask, logit, painted_image
def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
origal_image = self.sam_controler.orignal_image
if same:
'''
true; loop in the same image
'''
prompts = {
'point_coords': points,
'point_labels': labels,
'mask_input': logits[None, :, :]
}
masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
# def interact_loop(self, image:np.ndarray, same: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
# origal_image = self.sam_controler.orignal_image
# if same:
# '''
# true; loop in the same image
# '''
# prompts = {
# 'point_coords': points,
# 'point_labels': labels,
# 'mask_input': logits[None, :, :]
# }
# masks, scores, logits = self.sam_controler.predict(prompts, 'both', multimask)
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
# painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
else:
'''
loop in the different image, interact in the video
'''
if image is None:
raise('Image error')
else:
self.seg_again(image)
prompts = {
'point_coords': points,
'point_labels': labels,
}
masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
# return mask, logit, painted_image
# else:
# '''
# loop in the different image, interact in the video
# '''
# if image is None:
# raise('Image error')
# else:
# self.seg_again(image)
# prompts = {
# 'point_coords': points,
# 'point_labels': labels,
# }
# masks, scores, logits = self.sam_controler.predict(prompts, 'point', multimask)
# mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
painted_image = Image.fromarray(painted_image)
# painted_image = mask_painter(image, mask.astype('uint8'), mask_color, mask_alpha, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels>0)],axis = 1), point_color_ne, point_alpha, point_radius, contour_color, contour_width)
# painted_image = point_painter(painted_image, np.squeeze(points[np.argwhere(labels<1)],axis = 1), point_color_ps, point_alpha, point_radius, contour_color, contour_width)
# painted_image = Image.fromarray(painted_image)
return mask, logit, painted_image
# return mask, logit, painted_image
@@ -226,31 +226,31 @@ class SamControler():
if __name__ == "__main__":
points = np.array([[500, 375], [1125, 625]])
labels = np.array([1, 1])
image = cv2.imread('/hhd3/gaoshang/truck.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# if __name__ == "__main__":
# points = np.array([[500, 375], [1125, 625]])
# labels = np.array([1, 1])
# image = cv2.imread('/hhd3/gaoshang/truck.jpg')
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
sam_controler = initialize()
mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
# sam_controler = initialize()
# mask, logit, painted_image_full = first_frame_click(sam_controler,image, points, labels, multimask=True)
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
# cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)
# cv2.imwrite('/hhd3/gaoshang/truck_change.jpg', image)
# painted_image_full.save('/hhd3/gaoshang/truck_point_full.jpg')
mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
# mask, logit, painted_image_full = interact_loop(sam_controler,image,True, points, np.array([1, 0]), logit, multimask=True)
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
# cv2.imwrite('/hhd3/gaoshang/truck_same.jpg', painted_image)
# painted_image_full.save('/hhd3/gaoshang/truck_same_full.jpg')
mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')
# mask, logit, painted_image_full = interact_loop(sam_controler,image, False, points, labels, multimask=True)
# painted_image = mask_painter2(image, mask.astype('uint8'), background_alpha=0.8)
# painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR) # numpy array (h, w, 3)
# cv2.imwrite('/hhd3/gaoshang/truck_diff.jpg', painted_image)
# painted_image_full.save('/hhd3/gaoshang/truck_diff_full.jpg')

View File

@@ -15,26 +15,26 @@ class TrackingAnything():
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
if first_flag:
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
# if first_flag:
# mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
# return mask, logit, painted_image
if interact_flag:
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image
# if interact_flag:
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
# return mask, logit, painted_image
mask, logit, painted_image = self.xmem.track(image, logit)
return mask, logit, painted_image
# mask, logit, painted_image = self.xmem.track(image, logit)
# return mask, logit, painted_image
def first_frame_click(self, image: np.ndarray, points:np.ndarray, labels: np.ndarray, multimask=True):
mask, logit, painted_image = self.samcontroler.first_frame_click(image, points, labels, multimask)
return mask, logit, painted_image
def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
return mask, logit, painted_image
# def interact(self, image: np.ndarray, same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
# mask, logit, painted_image = self.samcontroler.interact_loop(image, same_image_flag, points, labels, logits, multimask)
# return mask, logit, painted_image
def generator(self, images: list, template_mask:np.ndarray):
@@ -53,6 +53,7 @@ class TrackingAnything():
masks.append(mask)
logits.append(logit)
painted_images.append(painted_image)
print("tracking image {}".format(i))
return masks, logits, painted_images

View File

@@ -67,6 +67,7 @@ class BaseTracker:
logit: numpy arrays, probability map (H, W)
painted_image: numpy array (H, W, 3)
"""
if first_frame_annotation is not None: # first frame mask
# initialisation
mask, labels = self.mapper.convert_mask(first_frame_annotation)
@@ -87,12 +88,20 @@ class BaseTracker:
out_mask = torch.argmax(probs, dim=0)
out_mask = (out_mask.detach().cpu().numpy()).astype(np.uint8)
num_objs = out_mask.max()
final_mask = np.zeros_like(out_mask)
# map back
for k, v in self.mapper.remappings.items():
final_mask[out_mask == v] = k
num_objs = final_mask.max()
painted_image = frame
for obj in range(1, num_objs+1):
painted_image = mask_painter(painted_image, (out_mask==obj).astype('uint8'), mask_color=obj+1)
if np.max(final_mask==obj) == 0:
continue
painted_image = mask_painter(painted_image, (final_mask==obj).astype('uint8'), mask_color=obj+1)
return out_mask, out_mask, painted_image
return final_mask, final_mask, painted_image
@torch.no_grad()
def sam_refinement(self, frame, logits, ti):
@@ -142,34 +151,38 @@ if __name__ == '__main__':
# sam_model = BaseSegmenter(SAM_checkpoint, model_type, device=device)
tracker = BaseTracker(XMEM_checkpoint, device, None, device)
# test for storage efficiency
frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
# # test for storage efficiency
# frames = np.load('/ssd1/gaomingqi/efficiency/efficiency.npy')
# first_frame_annotation = np.array(Image.open('/ssd1/gaomingqi/efficiency/template_mask.png'))
first_frame_annotation[first_frame_annotation==1] = 15
first_frame_annotation[first_frame_annotation==2] = 20
save_path = '/ssd1/gaomingqi/results/TrackA/multi-change1'
if not os.path.exists(save_path):
os.mkdir(save_path)
for ti, frame in enumerate(frames):
print(ti)
if ti > 200:
break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
painted_image.save(f'{save_path}/{ti:05d}.png')
tracker.clear_memory()
for ti, frame in enumerate(frames):
print(ti)
# if ti > 200:
# break
if ti == 0:
mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
else:
mask, prob, painted_image = tracker.track(frame)
# save
painted_image = Image.fromarray(painted_image)
painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# tracker.clear_memory()
# for ti, frame in enumerate(frames):
# print(ti)
# # if ti > 200:
# # break
# if ti == 0:
# mask, prob, painted_image = tracker.track(frame, first_frame_annotation)
# else:
# mask, prob, painted_image = tracker.track(frame)
# # save
# painted_image = Image.fromarray(painted_image)
# painted_image.save(f'/ssd1/gaomingqi/results/TrackA/gsw/{ti:05d}.png')
# # track anything given in the first frame annotation
# for ti, frame in enumerate(frames):