app.py edit shanggao

ShangGaoG committed 2023-04-14 04:04:04 +08:00
7 changed files with 77 additions and 67 deletions

.gitignore

@@ -5,3 +5,4 @@ docs/
*.mp4
debug_images/
*.png
*.jpg

(deleted file)

@@ -1,37 +0,0 @@
import gradio as gr
import time

def capture_frame(video):
    frame = video.get_frame_at_sec(video.current_time)
    return frame

def capture_time(video):
    while True:
        if video.paused:
            time_paused = video.current_time
            return time_paused

iface = gr.Interface(fn=capture_frame,
                     inputs=[gr.inputs.Video(type="mp4", label="Input video",
                                             source="upload")],
                     outputs=["image"],
                     server_port=12212,
                     server_name="0.0.0.0",
                     capture_session=True)
video_player = iface.video[0]
video_player.pause = False

time_interface = gr.Interface(fn=capture_time,
                              inputs=[gr.inputs.Video(type="mp4", label="Input video",
                                                      source="upload", max_duration=10)],
                              outputs=["text"],
                              server_port=12212,
                              server_name="0.0.0.0",
                              capture_session=True)
time_interface.video[0].play = False
time_interface.video[0].pause = False

iface.launch()
time_interface.launch()

(requirements file)

@@ -13,3 +13,4 @@ matplotlib
onnxruntime
onnx
metaseg
pyyaml
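The new pyyaml entry backs the YAML config loading introduced in tracker/base_tracker.py below; note that the package installs as pyyaml but is imported as yaml. A minimal sanity check (hypothetical, not part of this commit):

import yaml  # provided by the pyyaml dependency added above

# parse a one-line document to confirm the dependency resolves
assert yaml.safe_load("top_k: 30") == {"top_k": 30}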

tracker/base_tracker.py (new file)

@@ -0,0 +1,59 @@
# input: frame list, first frame mask
# output: segmentation results on all frames
import os
import glob
import numpy as np
from PIL import Image
import torch
import yaml

from model.network import XMem
from inference.inference_core import InferenceCore


class BaseTracker:
    def __init__(self, device, xmem_checkpoint) -> None:
        """
        device: model device
        xmem_checkpoint: checkpoint of XMem model
        """
        # load configurations
        with open("tracker/config/config.yaml", 'r') as stream:
            config = yaml.safe_load(stream)
        # initialise XMem
        network = XMem(config, xmem_checkpoint).to(device).eval()
        # initialise InferenceCore
        self.tracker = InferenceCore(network, config)
        # set data transformation
        # self.data_transform =

    def track(self, frames, first_frame_annotation):
        # data transformation
        # tracking
        pass


if __name__ == '__main__':
    # video frames
    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
    video_path_list.sort()
    # first frame
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
    # load frames
    frames = []
    for video_path in video_path_list:
        frames.append(np.array(Image.open(video_path).convert('RGB')))
    frames = np.stack(frames, 0)  # N, H, W, C
    # load first frame annotation (palette-mode label map)
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))  # H, W
    # initialise tracker
    device = 'cuda:0'
    XMEM_checkpoint = '/ssd1/gaomingqi/checkpoints/XMem-s012.pth'
    tracker = BaseTracker(device, XMEM_checkpoint)
    # track everything given in the first-frame annotation
    tracker.track(frames, first_frame_annotation)
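BaseTracker.track is left as a stub above. As a rough illustration only, the sketch below shows how a frame loop could drive the wrapped InferenceCore; the step and set_all_labels calls, the ImageNet normalization, and the one-hot mask layout are assumptions based on the upstream XMem repository (hkchengrex/XMem), not code from this commit.

# hypothetical sketch, not part of this commit
import numpy as np
import torch
from torchvision import transforms

@torch.no_grad()
def track_sketch(core, frames, first_frame_annotation, device='cuda:0'):
    """core: an XMem InferenceCore; frames: (N, H, W, 3) uint8;
    first_frame_annotation: (H, W) integer label map."""
    # assumed ImageNet normalization, as used in the upstream XMem repo
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    object_ids = [int(i) for i in np.unique(first_frame_annotation) if i != 0]
    core.set_all_labels(object_ids)  # assumed InferenceCore API
    outputs = []
    for ti, frame in enumerate(frames):
        image = to_tensor(frame).to(device)  # (3, H, W) float
        if ti == 0:
            # one-hot mask per object id, shape (num_objects, H, W)
            mask = torch.from_numpy(
                np.stack([first_frame_annotation == oid for oid in object_ids])
            ).float().to(device)
            prob = core.step(image, mask, object_ids)
        else:
            prob = core.step(image)
        # prob: (num_objects + 1, H, W); channel 0 is background
        outputs.append(torch.argmax(prob, dim=0).cpu().numpy())
    return np.stack(outputs, 0)  # (N, H, W) label maps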

(new config file, loaded as tracker/config/config.yaml in base_tracker.py)

@@ -0,0 +1,15 @@
# config info for XMem
benchmark: False
disable_long_term: False
max_mid_term_frames: 10
min_mid_term_frames: 5
max_long_term_elements: 10000
num_prototypes: 128
top_k: 30
mem_every: 5
deep_update_every: -1
save_scores: False
flip: False
size: 480
enable_long_term: True
enable_long_term_count_usage: True
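For reference, the new BaseTracker consumes this file via yaml.safe_load (see tracker/base_tracker.py above); a minimal, hypothetical check of the parsed values:

import yaml

# parse the config exactly as BaseTracker.__init__ does
with open("tracker/config/config.yaml", "r") as stream:
    config = yaml.safe_load(stream)

# a few of the values defined above
assert config["top_k"] == 30
assert config["mem_every"] == 5
assert config["enable_long_term"] is True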

(deleted file)

@@ -1,29 +0,0 @@
# input: frame list, first frame mask
# output: segmentation results on all frames
import os
import glob
import numpy as np
from PIL import Image

class XMem:
    # based on https://github.com/hkchengrex/XMem
    pass

if __name__ == '__main__':
    # video frames
    video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg'))
    video_path_list.sort()
    # first frame
    first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png'
    # load frames
    frames = []
    for video_path in video_path_list:
        frames.append(np.array(Image.open(video_path).convert('RGB')))
    frames = np.stack(frames, 0)  # N, H, W, C
    # load first frame annotation
    first_frame_annotation = np.array(Image.open(first_frame_path).convert('P'))  # H, W, C