mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
upload marigold monocular depth estimation core files (#703)
* upload marigold monocular depth estimation core files
* fix lint
* remove unused files.
* update marigold model files
* update marigold core files to fix review comments
* fix lint
* fix lint
* fix lint
* format code
* format code
---------
Co-authored-by: 葭润 <ranqing.rq@alibaba-inc.com>
Co-authored-by: wenmeng zhou <wenmeng.zwm@alibaba-inc.com>
(cherry picked from commit 1ade52df17)
This commit is contained in:
@@ -462,6 +462,7 @@ class Pipelines(object):
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
vision_efficient_tuning = 'vision-efficient-tuning'
|
||||
image_bts_depth_estimation = 'image-bts-depth-estimation'
|
||||
image_depth_estimation_marigold = 'image-depth-estimation-marigold'
|
||||
pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image'
|
||||
text_to_360panorama_image = 'text-to-360panorama-image'
|
||||
image_try_on = 'image-try-on'
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .marigold import MarigoldDepthOutput
|
||||
from .marigold_utils import (chw2hwc, colorize_depth_maps, ensemble_depths,
|
||||
find_batch_size, inter_distances,
|
||||
resize_max_res)
|
||||
else:
|
||||
_import_structure = {
|
||||
'marigold': ['MarigoldDepthOutput'],
|
||||
'marigold_utils': [
|
||||
'find_batch_size', 'inter_distances', 'ensemble_depths',
|
||||
'colorize_depth_maps', 'chw2hwc', 'resize_max_res'
|
||||
]
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
||||
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
||||
# More information about the method can be found at https://marigoldmonodepth.github.io
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
from diffusers.utils import BaseOutput
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class MarigoldDepthOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Marigold monocular depth prediction pipeline.
|
||||
|
||||
Args:
|
||||
depth_np (`np.ndarray`):
|
||||
Predicted depth map, with depth values in the range of [0, 1].
|
||||
depth_colored (`PIL.Image.Image`):
|
||||
Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
|
||||
uncertainty (`None` or `np.ndarray`):
|
||||
Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling.
|
||||
"""
|
||||
|
||||
depth_np: np.ndarray
|
||||
depth_colored: Image.Image
|
||||
uncertainty: Union[None, np.ndarray]
|
||||
@@ -0,0 +1,364 @@
|
||||
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# --------------------------------------------------------------------------
|
||||
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
||||
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
||||
# More information about the method can be found at https://marigoldmonodepth.github.io
|
||||
|
||||
import math
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from scipy.optimize import minimize
|
||||
|
||||
# Search table for suggested max. inference batch size
|
||||
bs_search_table = [
|
||||
# tested on A100-PCIE-80GB
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 79,
|
||||
'bs': 35,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 79,
|
||||
'bs': 20,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
# tested on A100-PCIE-40GB
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 39,
|
||||
'bs': 15,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 39,
|
||||
'bs': 8,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 39,
|
||||
'bs': 30,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 39,
|
||||
'bs': 15,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
# tested on RTX3090, RTX4090
|
||||
{
|
||||
'res': 512,
|
||||
'total_vram': 23,
|
||||
'bs': 20,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 23,
|
||||
'bs': 7,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 23,
|
||||
'bs': 3,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 512,
|
||||
'total_vram': 23,
|
||||
'bs': 40,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 23,
|
||||
'bs': 18,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 23,
|
||||
'bs': 10,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
# tested on GTX1080Ti
|
||||
{
|
||||
'res': 512,
|
||||
'total_vram': 10,
|
||||
'bs': 5,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 10,
|
||||
'bs': 2,
|
||||
'dtype': torch.float32
|
||||
},
|
||||
{
|
||||
'res': 512,
|
||||
'total_vram': 10,
|
||||
'bs': 10,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
{
|
||||
'res': 768,
|
||||
'total_vram': 10,
|
||||
'bs': 5,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
{
|
||||
'res': 1024,
|
||||
'total_vram': 10,
|
||||
'bs': 3,
|
||||
'dtype': torch.float16
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def find_batch_size(ensemble_size: int, input_res: int,
|
||||
dtype: torch.dtype) -> int:
|
||||
"""
|
||||
Automatically search for suitable operating batch size.
|
||||
|
||||
Args:
|
||||
ensemble_size (`int`):
|
||||
Number of predictions to be ensembled.
|
||||
input_res (`int`):
|
||||
Operating resolution of the input image.
|
||||
|
||||
Returns:
|
||||
`int`: Operating batch size.
|
||||
"""
|
||||
if not torch.cuda.is_available():
|
||||
return 1
|
||||
|
||||
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
|
||||
filtered_bs_search_table = [
|
||||
s for s in bs_search_table if s['dtype'] == dtype
|
||||
]
|
||||
for settings in sorted(
|
||||
filtered_bs_search_table,
|
||||
key=lambda k: (k['res'], -k['total_vram']),
|
||||
):
|
||||
if input_res <= settings['res'] and total_vram >= settings[
|
||||
'total_vram']:
|
||||
bs = settings['bs']
|
||||
if bs > ensemble_size:
|
||||
bs = ensemble_size
|
||||
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
|
||||
bs = math.ceil(ensemble_size / 2)
|
||||
return bs
|
||||
|
||||
return 1
|
||||
|
||||
|
||||
def inter_distances(tensors: torch.Tensor):
|
||||
"""
|
||||
To calculate the distance between each two depth maps.
|
||||
"""
|
||||
distances = []
|
||||
for i, j in torch.combinations(torch.arange(tensors.shape[0])):
|
||||
arr1 = tensors[i:i + 1]
|
||||
arr2 = tensors[j:j + 1]
|
||||
distances.append(arr1 - arr2)
|
||||
dist = torch.concatenate(distances, dim=0)
|
||||
return dist
|
||||
|
||||
|
||||
def ensemble_depths(
|
||||
input_images: torch.Tensor,
|
||||
regularizer_strength: float = 0.02,
|
||||
max_iter: int = 2,
|
||||
tol: float = 1e-3,
|
||||
reduction: str = 'median',
|
||||
max_res: int = None,
|
||||
):
|
||||
"""
|
||||
To ensemble multiple affine-invariant depth images (up to scale and shift),
|
||||
by aligning estimating the scale and shift
|
||||
"""
|
||||
device = input_images.device
|
||||
dtype = input_images.dtype
|
||||
np_dtype = np.float32
|
||||
|
||||
original_input = input_images.clone()
|
||||
n_img = input_images.shape[0]
|
||||
ori_shape = input_images.shape
|
||||
|
||||
if max_res is not None:
|
||||
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
|
||||
if scale_factor < 1:
|
||||
downscaler = torch.nn.Upsample(
|
||||
scale_factor=scale_factor, mode='nearest')
|
||||
input_images = downscaler(torch.from_numpy(input_images)).numpy()
|
||||
|
||||
# init guess
|
||||
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
|
||||
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
|
||||
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
|
||||
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
|
||||
|
||||
input_images = input_images.to(device)
|
||||
|
||||
# objective function
|
||||
def closure(x):
|
||||
length = len(x)
|
||||
s = x[:int(length / 2)]
|
||||
t = x[int(length / 2):]
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
|
||||
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view(
|
||||
(-1, 1, 1))
|
||||
dists = inter_distances(transformed_arrays)
|
||||
sqrt_dist = torch.sqrt(torch.mean(dists**2))
|
||||
|
||||
if 'mean' == reduction:
|
||||
pred = torch.mean(transformed_arrays, dim=0)
|
||||
elif 'median' == reduction:
|
||||
pred = torch.median(transformed_arrays, dim=0).values
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
near_err = torch.sqrt((0 - torch.min(pred))**2)
|
||||
far_err = torch.sqrt((1 - torch.max(pred))**2)
|
||||
|
||||
err = sqrt_dist + (near_err + far_err) * regularizer_strength
|
||||
err = err.detach().cpu().numpy().astype(np_dtype)
|
||||
return err
|
||||
|
||||
res = minimize(
|
||||
closure,
|
||||
x,
|
||||
method='BFGS',
|
||||
tol=tol,
|
||||
options={
|
||||
'maxiter': max_iter,
|
||||
'disp': False
|
||||
})
|
||||
x = res.x
|
||||
length = len(x)
|
||||
s = x[:int(length / 2)]
|
||||
t = x[int(length / 2):]
|
||||
|
||||
# Prediction
|
||||
s = torch.from_numpy(s).to(dtype=dtype).to(device)
|
||||
t = torch.from_numpy(t).to(dtype=dtype).to(device)
|
||||
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
|
||||
if 'mean' == reduction:
|
||||
aligned_images = torch.mean(transformed_arrays, dim=0)
|
||||
std = torch.std(transformed_arrays, dim=0)
|
||||
uncertainty = std
|
||||
elif 'median' == reduction:
|
||||
aligned_images = torch.median(transformed_arrays, dim=0).values
|
||||
# MAD (median absolute deviation) as uncertainty indicator
|
||||
abs_dev = torch.abs(transformed_arrays - aligned_images)
|
||||
mad = torch.median(abs_dev, dim=0).values
|
||||
uncertainty = mad
|
||||
else:
|
||||
raise ValueError(f'Unknown reduction method: {reduction}')
|
||||
|
||||
# Scale and shift to [0, 1]
|
||||
_min = torch.min(aligned_images)
|
||||
_max = torch.max(aligned_images)
|
||||
aligned_images = (aligned_images - _min) / (_max - _min)
|
||||
uncertainty /= _max - _min
|
||||
|
||||
return aligned_images, uncertainty
|
||||
|
||||
|
||||
def colorize_depth_maps(depth_map,
|
||||
min_depth,
|
||||
max_depth,
|
||||
cmap='Spectral',
|
||||
valid_mask=None):
|
||||
"""
|
||||
Colorize depth maps.
|
||||
"""
|
||||
assert len(depth_map.shape) >= 2, 'Invalid dimension'
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
depth = depth_map.detach().clone().squeeze().numpy()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
depth = depth_map.copy().squeeze()
|
||||
# reshape to [ (B,) H, W ]
|
||||
if depth.ndim < 3:
|
||||
depth = depth[np.newaxis, :, :]
|
||||
|
||||
# colorize
|
||||
cm = matplotlib.colormaps[cmap]
|
||||
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
||||
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1
|
||||
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
||||
|
||||
if valid_mask is not None:
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
valid_mask = valid_mask.detach().numpy()
|
||||
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
||||
if valid_mask.ndim < 3:
|
||||
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
||||
else:
|
||||
valid_mask = valid_mask[:, np.newaxis, :, :]
|
||||
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
||||
img_colored_np[~valid_mask] = 0
|
||||
|
||||
if isinstance(depth_map, torch.Tensor):
|
||||
img_colored = torch.from_numpy(img_colored_np).float()
|
||||
elif isinstance(depth_map, np.ndarray):
|
||||
img_colored = img_colored_np
|
||||
|
||||
return img_colored
|
||||
|
||||
|
||||
def chw2hwc(chw):
|
||||
assert 3 == len(chw.shape)
|
||||
if isinstance(chw, torch.Tensor):
|
||||
hwc = torch.permute(chw, (1, 2, 0))
|
||||
elif isinstance(chw, np.ndarray):
|
||||
hwc = np.moveaxis(chw, 0, -1)
|
||||
return hwc
|
||||
|
||||
|
||||
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
|
||||
"""
|
||||
Resize image to limit maximum edge length while keeping aspect ratio.
|
||||
|
||||
Args:
|
||||
img (`Image.Image`):
|
||||
Image to be resized.
|
||||
max_edge_resolution (`int`):
|
||||
Maximum edge length (pixel).
|
||||
|
||||
Returns:
|
||||
`Image.Image`: Resized image.
|
||||
"""
|
||||
original_width, original_height = img.size
|
||||
downscale_factor = min(max_edge_resolution / original_width,
|
||||
max_edge_resolution / original_height)
|
||||
|
||||
new_width = int(original_width * downscale_factor)
|
||||
new_height = int(original_height * downscale_factor)
|
||||
|
||||
resized_img = img.resize((new_width, new_height))
|
||||
return resized_img
|
||||
@@ -121,6 +121,7 @@ if TYPE_CHECKING:
|
||||
from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline
|
||||
from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline
|
||||
from .anydoor_pipeline import AnydoorPipeline
|
||||
from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline
|
||||
from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline
|
||||
|
||||
else:
|
||||
@@ -305,6 +306,9 @@ else:
|
||||
'RIFEVideoFrameInterpolationPipeline'
|
||||
],
|
||||
'anydoor_pipeline': ['AnydoorPipeline'],
|
||||
'image_depth_estimation_marigold_pipeline': [
|
||||
'ImageDepthEstimationMarigoldPipeline'
|
||||
],
|
||||
'self_supervised_depth_completion_pipeline': [
|
||||
'SelfSupervisedDepthCompletionPipeline'
|
||||
],
|
||||
|
||||
@@ -0,0 +1,409 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import (AutoencoderKL, DDIMScheduler, DiffusionPipeline,
|
||||
UNet2DConditionModel)
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_depth_estimation_marigold import (
|
||||
MarigoldDepthOutput, chw2hwc, colorize_depth_maps, ensemble_depths,
|
||||
find_batch_size, inter_distances, resize_max_res)
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_depth_estimation,
|
||||
module_name=Pipelines.image_depth_estimation_marigold)
|
||||
class ImageDepthEstimationMarigoldPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model=str, **kwargs):
|
||||
r"""
|
||||
use `model` to create a image depth estimation pipeline for prediction
|
||||
Args:
|
||||
>>> model: modelscope model_id "Damo_XR_Lab/cv_marigold_monocular-depth-estimation"
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> from modelscope.utils.constant import Tasks
|
||||
>>> from modelscope.outputs import OutputKeys
|
||||
>>>
|
||||
>>> output_image_path = './result.png'
|
||||
>>> img = './test.jpg'
|
||||
>>>
|
||||
>>> pipe = pipeline(
|
||||
>>> Tasks.image_depth_estimation,
|
||||
>>> model='Damo_XR_Lab/cv_marigold_monocular-depth-estimation')
|
||||
>>>
|
||||
>>> depth_vis = pipe(input)[OutputKeys.DEPTHS_COLOR]
|
||||
>>> depth_vis.save(output_image_path)
|
||||
>>> print('pipeline: the output image path is {}'.format(output_image_path))
|
||||
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
self._device = getattr(
|
||||
kwargs, 'device',
|
||||
torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
||||
self._dtype = torch.float16
|
||||
logger.info('load depth estimation marigold pipeline done')
|
||||
|
||||
self.checkpoint_path = os.path.join(model, 'Marigold_v1_merged_2')
|
||||
self.pipeline = _MarigoldPipeline.from_pretrained(
|
||||
self.checkpoint_path, torch_dtype=self._dtype)
|
||||
self.pipeline.to(self._device)
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
# print('pipeline preprocess')
|
||||
# TODO: input type: Image
|
||||
return input
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.input_image = Image.open(input)
|
||||
# print('load', input, self.input_image.size)
|
||||
|
||||
results = self.pipeline(self.input_image)
|
||||
return results
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
depths: np.ndarray = inputs.depth_np
|
||||
depths_color: Image.Image = inputs.depth_colored
|
||||
outputs = {
|
||||
OutputKeys.DEPTHS: depths,
|
||||
OutputKeys.DEPTHS_COLOR: depths_color
|
||||
}
|
||||
return outputs
|
||||
|
||||
|
||||
class _MarigoldPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`].
|
||||
Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
unet (`UNet2DConditionModel`):
|
||||
Conditional U-Net to denoise the depth latent, conditioned on image latent.
|
||||
vae (`AutoencoderKL`):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
|
||||
to and from latent representations.
|
||||
scheduler (`DDIMScheduler`):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
||||
text_encoder (`CLIPTextModel`):
|
||||
Text-encoder, for empty text embedding.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
CLIP tokenizer.
|
||||
"""
|
||||
rgb_latent_scale_factor = 0.18215
|
||||
depth_latent_scale_factor = 0.18215
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
vae: AutoencoderKL,
|
||||
scheduler: DDIMScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
unet=unet,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
self.empty_text_embed = None
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
input_image: Image,
|
||||
denoising_steps: int = 10,
|
||||
ensemble_size: int = 10,
|
||||
processing_res: int = 768,
|
||||
match_input_res: bool = True,
|
||||
batch_size: int = 0,
|
||||
color_map: str = 'Spectral',
|
||||
show_progress_bar: bool = True,
|
||||
ensemble_kwargs: Dict = None,
|
||||
) -> MarigoldDepthOutput:
|
||||
r"""
|
||||
Function invoked when calling the pipeline.
|
||||
|
||||
Args:
|
||||
input_image (`Image`):
|
||||
Input RGB (or gray-scale) image.
|
||||
processing_res (`int`, *optional*, defaults to `768`):
|
||||
Maximum resolution of processing.
|
||||
If set to 0: will not resize at all.
|
||||
match_input_res (`bool`, *optional*, defaults to `True`):
|
||||
Resize depth prediction to match input resolution.
|
||||
Only valid if `limit_input_res` is not None.
|
||||
denoising_steps (`int`, *optional*, defaults to `10`):
|
||||
Number of diffusion denoising steps (DDIM) during inference.
|
||||
ensemble_size (`int`, *optional*, defaults to `10`):
|
||||
Number of predictions to be ensembled.
|
||||
batch_size (`int`, *optional*, defaults to `0`):
|
||||
Inference batch size, no bigger than `num_ensemble`.
|
||||
If set to 0, the script will automatically decide the proper batch size.
|
||||
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
color_map (`str`, *optional*, defaults to `"Spectral"`):
|
||||
Colormap used to colorize the depth map.
|
||||
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
||||
Arguments for detailed ensembling settings.
|
||||
Returns:
|
||||
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
||||
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
||||
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W]
|
||||
and values in [0, 1]
|
||||
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
||||
coming from ensembling. None if `ensemble_size = 1`
|
||||
"""
|
||||
|
||||
device = self.device
|
||||
input_size = input_image.size
|
||||
|
||||
if not match_input_res:
|
||||
assert (processing_res is not None
|
||||
), 'Value error: `resize_output_back` is only valid with '
|
||||
assert processing_res >= 0
|
||||
assert denoising_steps >= 1
|
||||
assert ensemble_size >= 1
|
||||
|
||||
# ----------------- Image Preprocess -----------------
|
||||
# Resize image
|
||||
if processing_res > 0:
|
||||
input_image = resize_max_res(
|
||||
input_image, max_edge_resolution=processing_res)
|
||||
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
|
||||
input_image = input_image.convert('RGB')
|
||||
image = np.asarray(input_image)
|
||||
|
||||
# Normalize rgb values
|
||||
rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
|
||||
rgb_norm = rgb / 255.0
|
||||
rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
|
||||
rgb_norm = rgb_norm.to(device)
|
||||
assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0
|
||||
|
||||
# ----------------- Predicting depth -----------------
|
||||
# Batch repeated input image
|
||||
duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
|
||||
single_rgb_dataset = TensorDataset(duplicated_rgb)
|
||||
if batch_size > 0:
|
||||
_bs = batch_size
|
||||
else:
|
||||
_bs = find_batch_size(
|
||||
ensemble_size=ensemble_size,
|
||||
input_res=max(rgb_norm.shape[1:]),
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
single_rgb_loader = DataLoader(
|
||||
single_rgb_dataset, batch_size=_bs, shuffle=False)
|
||||
|
||||
# Predict depth maps (batched)
|
||||
depth_pred_ls = []
|
||||
if show_progress_bar:
|
||||
iterable = tqdm(
|
||||
single_rgb_loader,
|
||||
desc=' ' * 2 + 'Inference batches',
|
||||
leave=False)
|
||||
else:
|
||||
iterable = single_rgb_loader
|
||||
for batch in iterable:
|
||||
(batched_img, ) = batch
|
||||
depth_pred_raw = self.single_infer(
|
||||
rgb_in=batched_img,
|
||||
num_inference_steps=denoising_steps,
|
||||
show_pbar=show_progress_bar,
|
||||
)
|
||||
depth_pred_ls.append(depth_pred_raw.detach().clone())
|
||||
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
|
||||
torch.cuda.empty_cache() # clear vram cache for ensembling
|
||||
|
||||
# ----------------- Test-time ensembling -----------------
|
||||
if ensemble_size > 1:
|
||||
depth_pred, pred_uncert = ensemble_depths(
|
||||
depth_preds, **(ensemble_kwargs or {}))
|
||||
else:
|
||||
depth_pred = depth_preds
|
||||
pred_uncert = None
|
||||
|
||||
# ----------------- Post processing -----------------
|
||||
# Scale prediction to [0, 1]
|
||||
min_d = torch.min(depth_pred)
|
||||
max_d = torch.max(depth_pred)
|
||||
depth_pred = (depth_pred - min_d) / (max_d - min_d)
|
||||
|
||||
# Convert to numpy
|
||||
depth_pred = depth_pred.cpu().numpy().astype(np.float32)
|
||||
|
||||
# Resize back to original resolution
|
||||
if match_input_res:
|
||||
pred_img = Image.fromarray(depth_pred)
|
||||
pred_img = pred_img.resize(input_size)
|
||||
depth_pred = np.asarray(pred_img)
|
||||
|
||||
# Clip output range
|
||||
depth_pred = depth_pred.clip(0, 1)
|
||||
|
||||
# Colorize
|
||||
depth_colored = colorize_depth_maps(
|
||||
depth_pred, 0, 1,
|
||||
cmap=color_map).squeeze() # [3, H, W], value in (0, 1)
|
||||
depth_colored = (depth_colored * 255).astype(np.uint8)
|
||||
depth_colored_hwc = chw2hwc(depth_colored)
|
||||
depth_colored_img = Image.fromarray(depth_colored_hwc)
|
||||
return MarigoldDepthOutput(
|
||||
depth_np=depth_pred,
|
||||
depth_colored=depth_colored_img,
|
||||
uncertainty=pred_uncert,
|
||||
)
|
||||
|
||||
def __encode_empty_text(self):
|
||||
"""
|
||||
Encode text embedding for empty prompt
|
||||
"""
|
||||
prompt = ''
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding='do_not_pad',
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors='pt',
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
|
||||
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(
|
||||
self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int,
|
||||
show_pbar: bool) -> torch.Tensor:
|
||||
r"""
|
||||
Perform an individual depth prediction without ensembling.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image.
|
||||
num_inference_steps (`int`):
|
||||
Number of diffusion denoisign steps (DDIM) during inference.
|
||||
show_pbar (`bool`):
|
||||
Display a progress bar of diffusion denoising.
|
||||
Returns:
|
||||
`torch.Tensor`: Predicted depth map.
|
||||
"""
|
||||
device = rgb_in.device
|
||||
|
||||
# Set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps # [T]
|
||||
|
||||
# Encode image
|
||||
rgb_latent = self.encode_rgb(rgb_in)
|
||||
|
||||
# Initial depth map (noise)
|
||||
depth_latent = torch.randn(
|
||||
rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
|
||||
|
||||
# Batched empty text embedding
|
||||
if self.empty_text_embed is None:
|
||||
self.__encode_empty_text()
|
||||
batch_empty_text_embed = self.empty_text_embed.repeat(
|
||||
(rgb_latent.shape[0], 1, 1)) # [B, 2, 1024]
|
||||
|
||||
# Denoising loop
|
||||
if show_pbar:
|
||||
iterable = tqdm(
|
||||
enumerate(timesteps),
|
||||
total=len(timesteps),
|
||||
leave=False,
|
||||
desc=' ' * 4 + 'Diffusion denoising',
|
||||
)
|
||||
else:
|
||||
iterable = enumerate(timesteps)
|
||||
|
||||
for i, t in iterable:
|
||||
unet_input = torch.cat([rgb_latent, depth_latent],
|
||||
dim=1) # this order is important
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
unet_input, t, encoder_hidden_states=batch_empty_text_embed
|
||||
).sample # [B, 4, h, w]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
depth_latent = self.scheduler.step(noise_pred, t,
|
||||
depth_latent).prev_sample
|
||||
torch.cuda.empty_cache()
|
||||
depth = self.decode_depth(depth_latent)
|
||||
|
||||
# clip prediction
|
||||
depth = torch.clip(depth, -1.0, 1.0)
|
||||
# shift to [0, 1]
|
||||
depth = (depth + 1.0) / 2.0
|
||||
|
||||
return depth
|
||||
|
||||
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Encode RGB image into latent.
|
||||
|
||||
Args:
|
||||
rgb_in (`torch.Tensor`):
|
||||
Input RGB image to be encoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Image latent.
|
||||
"""
|
||||
# encode
|
||||
h = self.vae.encoder(rgb_in)
|
||||
moments = self.vae.quant_conv(h)
|
||||
mean, logvar = torch.chunk(moments, 2, dim=1)
|
||||
# scale latent
|
||||
rgb_latent = mean * self.rgb_latent_scale_factor
|
||||
return rgb_latent
|
||||
|
||||
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Decode depth latent into depth map.
|
||||
|
||||
Args:
|
||||
depth_latent (`torch.Tensor`):
|
||||
Depth latent to be decoded.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`: Decoded depth map.
|
||||
"""
|
||||
# scale latent
|
||||
depth_latent = depth_latent / self.depth_latent_scale_factor
|
||||
# decode
|
||||
z = self.vae.post_quant_conv(depth_latent)
|
||||
stacked = self.vae.decoder(z)
|
||||
# mean of output channels
|
||||
depth_mean = stacked.mean(dim=1, keepdim=True)
|
||||
return depth_mean
|
||||
|
||||
def forward(self, x):
|
||||
out = self.__call__(x)
|
||||
return out
|
||||
42
tests/pipelines/test_image_depth_estimation_marigold.py
Normal file
42
tests/pipelines/test_image_depth_estimation_marigold.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.cv import ImageDepthEstimationMarigoldPipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class ImageDepthEstimationMarigoldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_depth_estimation
|
||||
self.model_id = 'Damo_XR_Lab/cv_marigold_monocular-depth-estimation'
|
||||
self.image = 'data/in-the-wild_example/example_0.jpg'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_with_model_name(self):
|
||||
marigold = pipeline(task=self.task, model=self.model_id)
|
||||
input_path = os.path.join(marigold.model, self.image)
|
||||
result = marigold(input=input_path)
|
||||
depth_vis = result[OutputKeys.DEPTHS_COLOR]
|
||||
depth_vis.save('result_modelname.jpg')
|
||||
print('Test run with model name ok.')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_by_direct_model_download(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
marigold_pipe = ImageDepthEstimationMarigoldPipeline(cache_path)
|
||||
marigold_pipe.group_key = self.task
|
||||
input_path = os.path.join(cache_path, self.image)
|
||||
result = marigold_pipe(input=input_path)
|
||||
depth_vis = result[OutputKeys.DEPTHS_COLOR]
|
||||
depth_vis.save('result_snapshot.jpg')
|
||||
print('Test run with snapshot ok.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user