From cb491193d198c7d6b6b502e848fc8fb15383bc3d Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Fri, 14 Apr 2023 00:19:09 +0800 Subject: [PATCH] fix painter --- tools/painter.py | 4 ++-- tracker/xmem.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) create mode 100644 tracker/xmem.py diff --git a/tools/painter.py b/tools/painter.py index c1685d6..e154b60 100644 --- a/tools/painter.py +++ b/tools/painter.py @@ -129,9 +129,9 @@ def point_painter(input_image, input_points, point_color=5, point_alpha=0.9, poi contour_mask[contour_mask>0.5] = 1. # paint mask - painted_image = vis_add_mask(input_image, point_mask, point_color, point_alpha) + painted_image = vis_add_mask(input_image.copy(), point_mask, point_color, point_alpha) # paint contour - painted_image = vis_add_mask(painted_image, 1-contour_mask, contour_color, 1) + painted_image = vis_add_mask(painted_image.copy(), 1-contour_mask, contour_color, 1) return painted_image def mask_painter(input_image, input_mask, mask_color=5, mask_alpha=0.7, contour_color=1, contour_width=3): diff --git a/tracker/xmem.py b/tracker/xmem.py new file mode 100644 index 0000000..1465843 --- /dev/null +++ b/tracker/xmem.py @@ -0,0 +1,29 @@ +# input: frame list, first frame mask +# output: segmentation results on all frames +import os +import glob +import numpy as np +from PIL import Image + + +class XMem: + # based on https://github.com/hkchengrex/XMem + pass + + +if __name__ == '__main__': + # video frames + video_path_list = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/dance-twirl', '*.jpg')) + video_path_list.sort() + # first frame + first_frame_path = '/ssd1/gaomingqi/datasets/davis/Annotations/480p/dance-twirl/00000.png' + + # load frames + frames = [] + for video_path in video_path_list: + frames.append(np.array(Image.open(video_path).convert('RGB'))) + frames = np.stack(frames, 0) # N, H, W, C + + # load first frame annotation + first_frame_annotation = np.array(Image.open(first_frame_path).convert('P')) # H, W, C +