update gradio example -- li

This commit is contained in:
memoryunreal
2023-04-24 17:07:46 +00:00
parent 50e9bfa79c
commit 8e503db441
7 changed files with 30 additions and 7 deletions

View File

@@ -44,6 +44,15 @@ pip install madgrad
python app.py --device cuda:0 --sam_model_type vit_h --port 12212
```
## Citation
If you find this work useful for your research or applications, please cite using this BibTeX:
```bibtex
@misc{gao2023track,
title = {}
}
```
## Acknowledgements
The project is based on [Segment Anything](https://github.com/facebookresearch/segment-anything), [XMem](https://github.com/hkchengrex/XMem), and [E2FGVI](https://github.com/MCG-NKU/E2FGVI). Thanks for the authors for their efforts.

22
app.py
View File

@@ -97,6 +97,8 @@ def get_frames_from_video(video_input, video_state):
gr.update(visible=True), gr.update(visible=True), \
gr.update(visible=True)
def run_example(example):
return video_input
# get the select frame from gradio slider
def select_template(image_selection_slider, video_state, interactive_state):
@@ -109,11 +111,14 @@ def select_template(image_selection_slider, video_state, interactive_state):
model.samcontroler.sam_controler.reset_image()
model.samcontroler.sam_controler.set_image(video_state["origin_images"][image_selection_slider])
# # clear multi mask
# interactive_state["multi_mask"] = {"masks":[], "mask_names":[]}
# update the masks when select a new template frame
# if video_state["masks"][image_selection_slider] is not None:
# video_state["painted_images"][image_selection_slider] = mask_painter(video_state["origin_images"][image_selection_slider], video_state["masks"][image_selection_slider])
return video_state["painted_images"][image_selection_slider], video_state, interactive_state
# set the tracking end frame
def get_end_number(track_pause_number_slider, interactive_state):
interactive_state["track_end_number"] = track_pause_number_slider
return interactive_state
@@ -446,7 +451,18 @@ with gr.Blocks() as iface:
fn = clear_click,
inputs = [video_state, click_state,],
outputs = [template_frame,click_state],
)
# set example
gr.Markdown("## Examples")
gr.Examples(
examples=[os.path.join(os.path.dirname(__file__), "./test_sample/", test_sample) for test_sample in ["test-sample8.mp4","test-sample4.mp4", \
"test-sample2.mp4","test-sample13.mp4"]],
fn=run_example,
inputs=[
video_input
],
outputs=[video_input],
# cache_examples=True,
)
iface.queue(concurrency_count=1)
iface.launch(debug=True, enable_queue=True, server_port=args.port, server_name="0.0.0.0")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -12,9 +12,7 @@ class TrackingAnything():
def __init__(self, sam_checkpoint, xmem_checkpoint, args):
self.args = args
self.samcontroler = SamControler(sam_checkpoint, args.sam_model_type, args.device)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
self.xmem = BaseTracker(xmem_checkpoint, device=args.device)
# def inference_step(self, first_flag: bool, interact_flag: bool, image: np.ndarray,
# same_image_flag: bool, points:np.ndarray, labels: np.ndarray, logits: np.ndarray=None, multimask=True):
# if first_flag:
@@ -63,7 +61,7 @@ def parse_augment():
parser.add_argument('--sam_model_type', type=str, default="vit_h")
parser.add_argument('--port', type=int, default=6080, help="only useful when running gradio applications")
parser.add_argument('--debug', action="store_true")
parser.add_argument('--mask_save', default=True)
parser.add_argument('--mask_save', default=False)
args = parser.parse_args()
if args.debug: