Files
Track-Anything/inference/interact/resource_manager.py
gaomingqi 9f30e59c45 add xmem
2023-04-12 08:24:08 +08:00

207 lines
6.8 KiB
Python

import os
from os import path
import shutil
import collections
import cv2
from PIL import Image
# Pillow < 9.1 has no `Image.Resampling` enum; alias it to the module itself so
# `Image.Resampling.NEAREST` / `.BICUBIC` resolve to the legacy module constants.
Image.Resampling = getattr(Image, 'Resampling', Image)
import numpy as np
from util.palette import davis_palette
import progressbar
# https://bugs.python.org/issue28178
# functools.lru_cache offers no way to evict a single entry, so this minimal
# LRU wrapper is hand-rolled to support per-key invalidation.
class LRU:
    """Memoize ``func`` on its positional arguments with LRU eviction.

    Unlike ``functools.lru_cache``, individual entries can be dropped with
    :meth:`invalidate`.

    Args:
        func: Callable to memoize. Only positional arguments are supported;
            they must be hashable because the argument tuple is the cache key.
        maxsize: Maximum number of cached results. ``None`` means unbounded
            (mirroring ``functools.lru_cache``).
    """

    def __init__(self, func, maxsize=128):
        self.cache = collections.OrderedDict()
        self.func = func
        self.maxsize = maxsize

    def __call__(self, *args):
        cache = self.cache
        if args in cache:
            # Hit: mark the entry as most recently used.
            cache.move_to_end(args)
            return cache[args]
        result = self.func(*args)
        cache[args] = result
        # Evict the least recently used entry (the front) once over capacity.
        if self.maxsize is not None and len(cache) > self.maxsize:
            cache.popitem(last=False)
        return result

    def invalidate(self, key):
        """Drop the cached result for ``key`` (a tuple of args), if present."""
        self.cache.pop(key, None)
class ResourceManager:
    """Manage the on-disk workspace of frames and masks.

    The workspace directory holds ``images/`` (frames), ``masks/`` (palettised
    PNG masks) and, once something is saved there, ``visualization/``.
    Frames are decoded from ``config['video']`` or copied from
    ``config['images']`` unless ``images/`` already exists in the workspace,
    in which case it is reused as-is.

    Config keys read: ``images``, ``video``, ``workspace``, ``size`` (target
    length of the frame's shorter side; ``<= 0`` disables resizing) and
    ``buffer_size`` (capacity of the image/mask read caches).
    """

    # File extensions (lower-case) accepted as frames when listing a directory.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff', '.webp')

    def __init__(self, config):
        # determine inputs
        images = config['images']
        video = config['video']
        self.workspace = config['workspace']
        self.size = config['size']
        self.palette = davis_palette

        # create temporary workspace if not specified
        if self.workspace is None:
            if images is not None:
                # normpath strips a trailing separator, which would otherwise
                # make basename() return an empty string.
                basename = path.basename(path.normpath(images))
            elif video is not None:
                # splitext copes with extensions of any length (e.g. ".webm").
                basename = path.splitext(path.basename(video))[0]
            else:
                raise NotImplementedError(
                    'Either images, video, or workspace has to be specified')

            self.workspace = path.join('./workspace', basename)

        print(f'Workspace is in: {self.workspace}')

        # determine the location of input images
        need_decoding = False
        need_resizing = False
        if path.exists(path.join(self.workspace, 'images')):
            pass
        elif images is not None:
            need_resizing = True
        elif video is not None:
            # will decode video into frames later
            need_decoding = True

        # create workspace subdirectories
        self.image_dir = path.join(self.workspace, 'images')
        self.mask_dir = path.join(self.workspace, 'masks')
        os.makedirs(self.image_dir, exist_ok=True)
        os.makedirs(self.mask_dir, exist_ok=True)

        # convert read functions to be buffered
        self.get_image = LRU(self._get_image_unbuffered, maxsize=config['buffer_size'])
        self.get_mask = LRU(self._get_mask_unbuffered, maxsize=config['buffer_size'])

        # extract frames from video
        if need_decoding:
            self._extract_frames(video)

        # copy/resize existing images to the workspace
        if need_resizing:
            self._copy_resize_frames(images)

        # Read all frame files. Only image files count, so stray entries such
        # as .DS_Store do not become bogus frames. The real file names are
        # kept so frames stored as .png (or any listed extension) can be read.
        self._image_files = self._list_frame_files(self.image_dir)
        # frame names are the file stems; masks/visualizations reuse them
        self.names = [path.splitext(f)[0] for f in self._image_files]
        self.length = len(self.names)

        assert self.length > 0, f'No images found! Check {self.workspace}/images. Remove folder if necessary.'

        print(f'{self.length} images found.')

        self.height, self.width = self.get_image(0).shape[:2]
        self.visualization_init = False

    @classmethod
    def _list_frame_files(cls, directory):
        """Return the sorted names of files in ``directory`` with an image extension."""
        return sorted(
            name for name in os.listdir(directory)
            if path.splitext(name)[1].lower() in cls._IMAGE_EXTENSIONS
        )

    def _resize_frame(self, frame):
        """Resize ``frame`` so its shorter side equals ``self.size``.

        The aspect ratio is kept, using integer division. ``frame`` is returned
        unchanged when resizing is disabled (``self.size <= 0``) or it already
        has the target size.
        """
        if self.size <= 0:
            return frame
        h, w = frame.shape[:2]
        short_side = min(w, h)
        new_w = w * self.size // short_side
        new_h = h * self.size // short_side
        if new_w == w and new_h == h:
            return frame
        return cv2.resize(frame, dsize=(new_w, new_h), interpolation=cv2.INTER_AREA)

    def _extract_frames(self, video):
        """Decode every frame of ``video`` into ``image_dir`` as ``<index:07d>.jpg``."""
        cap = cv2.VideoCapture(video)
        frame_index = 0
        print(f'Extracting frames from {video} into {self.image_dir}...')
        bar = progressbar.ProgressBar(max_value=progressbar.UnknownLength)
        try:
            while cap.isOpened():
                ok, frame = cap.read()
                if not ok or frame is None:
                    break
                frame = self._resize_frame(frame)
                cv2.imwrite(path.join(self.image_dir, f'{frame_index:07d}.jpg'), frame)
                frame_index += 1
                bar.update(frame_index)
        finally:
            # release the decoder even if writing a frame fails
            cap.release()
        bar.finish()
        print('Done!')

    def _copy_resize_frames(self, images):
        """Copy the image files in ``images`` into ``image_dir``.

        When ``self.size > 0``, each file is re-encoded after resizing its
        shorter side to ``self.size``. Non-image files are ignored, and files
        OpenCV cannot decode are skipped with a message.
        """
        image_list = self._list_frame_files(images)
        print(f'Copying/resizing frames into {self.image_dir}...')
        for image_name in progressbar.progressbar(image_list):
            src = path.join(images, image_name)
            if self.size <= 0:
                # just copy
                shutil.copy2(src, self.image_dir)
                continue
            frame = cv2.imread(src)
            if frame is None:
                # cv2.imread signals an undecodable file by returning None
                print(f'Skipping unreadable image: {src}')
                continue
            cv2.imwrite(path.join(self.image_dir, image_name), self._resize_frame(frame))
        print('Done!')

    def save_mask(self, ti, mask):
        """Write ``mask`` as the palettised PNG of frame ``ti`` and drop its cached copy.

        ``mask`` should be a uint8 H*W array without channels.
        """
        assert 0 <= ti < self.length
        assert isinstance(mask, np.ndarray)
        mask = Image.fromarray(mask)
        mask.putpalette(self.palette)
        mask.save(path.join(self.mask_dir, self.names[ti] + '.png'))
        # the buffered mask for this frame is now stale
        self.invalidate(ti)

    def save_visualization(self, ti, image):
        """Save ``image`` as ``visualization/<frame name>.jpg`` for frame ``ti``.

        ``image`` should be a uint8 H*W*3 array, the layout ``Image.fromarray``
        expects for RGB data.
        """
        assert 0 <= ti < self.length
        assert isinstance(image, np.ndarray)
        if not self.visualization_init:
            # create the output folder lazily, on first use
            self.visualization_dir = path.join(self.workspace, 'visualization')
            os.makedirs(self.visualization_dir, exist_ok=True)
            self.visualization_init = True

        image = Image.fromarray(image)
        image.save(path.join(self.visualization_dir, self.names[ti] + '.jpg'))

    def _get_image_unbuffered(self, ti):
        """Load frame ``ti`` from disk as an H*W*3 uint8 RGB array."""
        assert 0 <= ti < self.length
        with Image.open(path.join(self.image_dir, self._image_files[ti])) as image:
            # convert() guarantees 3 channels even for grayscale/CMYK inputs
            return np.array(image.convert('RGB'))

    def _get_mask_unbuffered(self, ti):
        """Load the saved mask of frame ``ti`` (H*W uint8), or None if none exists."""
        assert 0 <= ti < self.length
        mask_path = path.join(self.mask_dir, self.names[ti] + '.png')
        if not path.exists(mask_path):
            return None
        with Image.open(mask_path) as mask:
            return np.array(mask)

    def read_external_image(self, file_name, size=None):
        """Load an arbitrary image file as a numpy array.

        Args:
            file_name: path to the image.
            size: optional ``(height, width)`` to resize to. Images in PIL mode
                ``'L'`` or ``'P'`` are treated as masks and resized with
                nearest-neighbour so their values are not blended; other
                images use bicubic.
        """
        with Image.open(file_name) as image:
            is_mask = image.mode in ['L', 'P']
            if size is None:
                return np.array(image)
            resample = Image.Resampling.NEAREST if is_mask else Image.Resampling.BICUBIC
            # PIL uses (width, height)
            return np.array(image.resize((size[1], size[0]), resample=resample))

    def invalidate(self, ti):
        """Drop the buffered mask of frame ``ti``. The image buffer is never invalidated."""
        self.get_mask.invalidate((ti,))

    def __len__(self):
        return self.length

    @property
    def h(self):
        return self.height

    @property
    def w(self):
        return self.width