diff --git a/README.md b/README.md index c9d3332..afcf541 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,15 @@ pip install madgrad python app.py --device cuda:0 --sam_model_type vit_h --port 12212 ``` +## Citation +If you find this work useful for your research or applications, please cite using this BibTeX: +```bibtex +@misc{gao2023track, + title = {} + +} +``` + ## Acknowledgements The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks for the authors for their efforts. diff --git a/app.py b/app.py index acf7df2..7330c7a 100644 --- a/app.py +++ b/app.py @@ -97,6 +97,8 @@ def get_frames_from_video(video_input, video_state): gr.update(visible=True), gr.update(visible=True), \ gr.update(visible=True) +def run_example(example): + return video_input # get the select frame from gradio slider def select_template(image_selection_slider, video_state, interactive_state): @@ -109,11 +111,14 @@ def select_template(image_selection_slider, video_state, interactive_state): model.samcontroler.sam_controler.reset_image() model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider]) - # # clear multi mask - # interactive_state["multi_mask"] = {"masks":[], "mask_names":[]} + # update the masks when selecting a new template frame + # if video_state["masks"][image_selection_slider] is not None: + # video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider]) + return video_state["painted_images"][image_selection_slider], video_state, interactive_state +# set the tracking end frame def get_end_number(track_pause_number_slider, interactive_state): interactive_state["track_end_number"] = track_pause_number_slider return interactive_state @@ -446,7 +451,18 @@ with gr.Blocks() as iface: fn = clear_click, 
inputs = [video_state, click_state,], outputs = [template_frame,click_state], - + ) + # set example + gr.Markdown("## Examples") + gr.Examples( + examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \ + "test-sample2.mp4","test-sample13.mp4"]], + fn=run_example, + inputs=[ + video_input + ], + outputs=[video_input], + # cache_examples=True, ) iface.queue(concurrency_count=1) iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0") diff --git a/test_sample/test-sample13.mp4 b/test_sample/test-sample13.mp4 new file mode 100644 index 0000000..04aa0df Binary files /dev/null and b/test_sample/test-sample13.mp4 differ diff --git a/test_sample/test-sample2.mp4 b/test_sample/test-sample2.mp4 new file mode 100644 index 0000000..cf6c1b1 Binary files /dev/null and b/test_sample/test-sample2.mp4 differ diff --git a/test_sample/test-sample4.mp4 b/test_sample/test-sample4.mp4 new file mode 100644 index 0000000..848988d Binary files /dev/null and b/test_sample/test-sample4.mp4 differ diff --git a/test_sample/test-sample8.mp4 b/test_sample/test-sample8.mp4 new file mode 100644 index 0000000..15f05e1 Binary files /dev/null and b/test_sample/test-sample8.mp4 differ diff --git a/track_anything.py b/track_anything.py index ab6c1f5..77d3436 100644 --- a/track_anything.py +++ b/track_anything.py @@ -12,9 +12,7 @@ class TrackingAnything(): def __init__(self, sam_checkpoint, xmem_checkpoint, args): self.args = args self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device) - self.xmem = BaseTracker(xmem_checkpoint, device=args.device) - - + self.xmem = BaseTracker(xmem_checkpoint, device=args.device) # def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray, # same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True): # if first_flag: @@ -63,7 +61,7 @@ def parse_augment(): 
parser.add_argument('--sam_model_type', type=str, default="vit_h") parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications") parser.add_argument('--debug', action="store_true") - parser.add_argument('--mask_save', default=True) + parser.add_argument('--mask_save', default=False) args = parser.parse_args() if args.debug: