diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 6e82ec43..772dbb28 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -462,6 +462,7 @@ class Pipelines(object): image_quality_assessment_degradation = 'image-quality-assessment-degradation' vision_efficient_tuning = 'vision-efficient-tuning' image_bts_depth_estimation = 'image-bts-depth-estimation' + image_depth_estimation_marigold = 'image-depth-estimation-marigold' pedestrian_attribute_recognition = 'resnet50_pedestrian-attribute-recognition_image' text_to_360panorama_image = 'text-to-360panorama-image' image_try_on = 'image-try-on' diff --git a/modelscope/models/cv/image_depth_estimation_marigold/__init__.py b/modelscope/models/cv/image_depth_estimation_marigold/__init__.py new file mode 100644 index 00000000..15e4c01e --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_marigold/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .marigold import MarigoldDepthOutput + from .marigold_utils import (chw2hwc, colorize_depth_maps, ensemble_depths, + find_batch_size, inter_distances, + resize_max_res) +else: + _import_structure = { + 'marigold': ['MarigoldDepthOutput'], + 'marigold_utils': [ + 'find_batch_size', 'inter_distances', 'ensemble_depths', + 'colorize_depth_maps', 'chw2hwc', 'resize_max_res' + ] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_depth_estimation_marigold/marigold.py b/modelscope/models/cv/image_depth_estimation_marigold/marigold.py new file mode 100644 index 00000000..a597b68c --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_marigold/marigold.py @@ -0,0 +1,42 @@ +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io +# -------------------------------------------------------------------------- + +from typing import Dict, Union + +import numpy as np +from diffusers.utils import BaseOutput +from PIL import Image + + +class MarigoldDepthOutput(BaseOutput): + """ + Output class for Marigold monocular depth prediction pipeline. + + Args: + depth_np (`np.ndarray`): + Predicted depth map, with depth values in the range of [0, 1]. + depth_colored (`PIL.Image.Image`): + Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. + uncertainty (`None` or `np.ndarray`): + Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. + """ + + depth_np: np.ndarray + depth_colored: Image.Image + uncertainty: Union[None, np.ndarray] diff --git a/modelscope/models/cv/image_depth_estimation_marigold/marigold_utils.py b/modelscope/models/cv/image_depth_estimation_marigold/marigold_utils.py new file mode 100644 index 00000000..00bceafe --- /dev/null +++ b/modelscope/models/cv/image_depth_estimation_marigold/marigold_utils.py @@ -0,0 +1,364 @@ +# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- +# If you find this code useful, we kindly ask you to cite our paper in your work. +# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation +# More information about the method can be found at https://marigoldmonodepth.github.io + +import math + +import matplotlib +import numpy as np +import torch +from PIL import Image +from scipy.optimize import minimize + +# Search table for suggested max. inference batch size +bs_search_table = [ + # tested on A100-PCIE-80GB + { + 'res': 768, + 'total_vram': 79, + 'bs': 35, + 'dtype': torch.float32 + }, + { + 'res': 1024, + 'total_vram': 79, + 'bs': 20, + 'dtype': torch.float32 + }, + # tested on A100-PCIE-40GB + { + 'res': 768, + 'total_vram': 39, + 'bs': 15, + 'dtype': torch.float32 + }, + { + 'res': 1024, + 'total_vram': 39, + 'bs': 8, + 'dtype': torch.float32 + }, + { + 'res': 768, + 'total_vram': 39, + 'bs': 30, + 'dtype': torch.float16 + }, + { + 'res': 1024, + 'total_vram': 39, + 'bs': 15, + 'dtype': torch.float16 + }, + # tested on RTX3090, RTX4090 + { + 'res': 512, + 'total_vram': 23, + 'bs': 20, + 'dtype': torch.float32 + }, + { + 'res': 768, + 'total_vram': 23, + 'bs': 7, + 'dtype': torch.float32 + }, + { + 'res': 1024, + 'total_vram': 23, + 'bs': 3, + 'dtype': torch.float32 + }, + { + 'res': 512, + 'total_vram': 23, + 'bs': 40, + 'dtype': torch.float16 + }, + { + 'res': 768, + 'total_vram': 23, + 'bs': 18, + 'dtype': torch.float16 + }, + { + 'res': 1024, + 'total_vram': 23, + 'bs': 10, + 'dtype': torch.float16 + }, + # tested on GTX1080Ti + { + 'res': 512, + 'total_vram': 10, + 'bs': 5, + 'dtype': torch.float32 + }, + { + 'res': 768, + 'total_vram': 10, + 'bs': 2, + 'dtype': torch.float32 + }, + { + 'res': 512, + 'total_vram': 10, + 'bs': 10, + 'dtype': torch.float16 + }, + { + 'res': 768, + 'total_vram': 10, + 'bs': 5, + 'dtype': torch.float16 + }, + { + 'res': 1024, + 'total_vram': 10, + 'bs': 3, + 'dtype': torch.float16 + }, +] + + +def find_batch_size(ensemble_size: int, input_res: int, + dtype: torch.dtype) -> int: + """ + Automatically search for suitable operating batch size. + + Args: + ensemble_size (`int`): + Number of predictions to be ensembled. + input_res (`int`): + Operating resolution of the input image. + + Returns: + `int`: Operating batch size. + """ + if not torch.cuda.is_available(): + return 1 + + total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 + filtered_bs_search_table = [ + s for s in bs_search_table if s['dtype'] == dtype + ] + for settings in sorted( + filtered_bs_search_table, + key=lambda k: (k['res'], -k['total_vram']), + ): + if input_res <= settings['res'] and total_vram >= settings[ + 'total_vram']: + bs = settings['bs'] + if bs > ensemble_size: + bs = ensemble_size + elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: + bs = math.ceil(ensemble_size / 2) + return bs + + return 1 + + +def inter_distances(tensors: torch.Tensor): + """ + To calculate the distance between each two depth maps. + """ + distances = [] + for i, j in torch.combinations(torch.arange(tensors.shape[0])): + arr1 = tensors[i:i + 1] + arr2 = tensors[j:j + 1] + distances.append(arr1 - arr2) + dist = torch.concatenate(distances, dim=0) + return dist + + +def ensemble_depths( + input_images: torch.Tensor, + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + reduction: str = 'median', + max_res: int = None, +): + """ + To ensemble multiple affine-invariant depth images (up to scale and shift), + by aligning estimating the scale and shift + """ + device = input_images.device + dtype = input_images.dtype + np_dtype = np.float32 + + original_input = input_images.clone() + n_img = input_images.shape[0] + ori_shape = input_images.shape + + if max_res is not None: + scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample( + scale_factor=scale_factor, mode='nearest') + input_images = downscaler(torch.from_numpy(input_images)).numpy() + + # init guess + _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) + t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) + x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) + + input_images = input_images.to(device) + + # objective function + def closure(x): + length = len(x) + s = x[:int(length / 2)] + t = x[int(length / 2):] + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + + transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view( + (-1, 1, 1)) + dists = inter_distances(transformed_arrays) + sqrt_dist = torch.sqrt(torch.mean(dists**2)) + + if 'mean' == reduction: + pred = torch.mean(transformed_arrays, dim=0) + elif 'median' == reduction: + pred = torch.median(transformed_arrays, dim=0).values + else: + raise ValueError + + near_err = torch.sqrt((0 - torch.min(pred))**2) + far_err = torch.sqrt((1 - torch.max(pred))**2) + + err = sqrt_dist + (near_err + far_err) * regularizer_strength + err = err.detach().cpu().numpy().astype(np_dtype) + return err + + res = minimize( + closure, + x, + method='BFGS', + tol=tol, + options={ + 'maxiter': max_iter, + 'disp': False + }) + x = res.x + length = len(x) + s = x[:int(length / 2)] + t = x[int(length / 2):] + + # Prediction + s = torch.from_numpy(s).to(dtype=dtype).to(device) + t = torch.from_numpy(t).to(dtype=dtype).to(device) + transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) + if 'mean' == reduction: + aligned_images = torch.mean(transformed_arrays, dim=0) + std = torch.std(transformed_arrays, dim=0) + uncertainty = std + elif 'median' == reduction: + aligned_images = torch.median(transformed_arrays, dim=0).values + # MAD (median absolute deviation) as uncertainty indicator + abs_dev = torch.abs(transformed_arrays - aligned_images) + mad = torch.median(abs_dev, dim=0).values + uncertainty = mad + else: + raise ValueError(f'Unknown reduction method: {reduction}') + + # Scale and shift to [0, 1] + _min = torch.min(aligned_images) + _max = torch.max(aligned_images) + aligned_images = (aligned_images - _min) / (_max - _min) + uncertainty /= _max - _min + + return aligned_images, uncertainty + + +def colorize_depth_maps(depth_map, + min_depth, + max_depth, + cmap='Spectral', + valid_mask=None): + """ + Colorize depth maps. + """ + assert len(depth_map.shape) >= 2, 'Invalid dimension' + + if isinstance(depth_map, torch.Tensor): + depth = depth_map.detach().clone().squeeze().numpy() + elif isinstance(depth_map, np.ndarray): + depth = depth_map.copy().squeeze() + # reshape to [ (B,) H, W ] + if depth.ndim < 3: + depth = depth[np.newaxis, :, :] + + # colorize + cm = matplotlib.colormaps[cmap] + depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1) + img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # value from 0 to 1 + img_colored_np = np.rollaxis(img_colored_np, 3, 1) + + if valid_mask is not None: + if isinstance(depth_map, torch.Tensor): + valid_mask = valid_mask.detach().numpy() + valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W] + if valid_mask.ndim < 3: + valid_mask = valid_mask[np.newaxis, np.newaxis, :, :] + else: + valid_mask = valid_mask[:, np.newaxis, :, :] + valid_mask = np.repeat(valid_mask, 3, axis=1) + img_colored_np[~valid_mask] = 0 + + if isinstance(depth_map, torch.Tensor): + img_colored = torch.from_numpy(img_colored_np).float() + elif isinstance(depth_map, np.ndarray): + img_colored = img_colored_np + + return img_colored + + +def chw2hwc(chw): + assert 3 == len(chw.shape) + if isinstance(chw, torch.Tensor): + hwc = torch.permute(chw, (1, 2, 0)) + elif isinstance(chw, np.ndarray): + hwc = np.moveaxis(chw, 0, -1) + return hwc + + +def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image: + """ + Resize image to limit maximum edge length while keeping aspect ratio. + + Args: + img (`Image.Image`): + Image to be resized. + max_edge_resolution (`int`): + Maximum edge length (pixel). + + Returns: + `Image.Image`: Resized image. + """ + original_width, original_height = img.size + downscale_factor = min(max_edge_resolution / original_width, + max_edge_resolution / original_height) + + new_width = int(original_width * downscale_factor) + new_height = int(original_height * downscale_factor) + + resized_img = img.resize((new_width, new_height)) + return resized_img diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index f2ca09bf..d987e989 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -121,6 +121,7 @@ if TYPE_CHECKING: from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline from .anydoor_pipeline import AnydoorPipeline + from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline else: @@ -305,6 +306,9 @@ else: 'RIFEVideoFrameInterpolationPipeline' ], 'anydoor_pipeline': ['AnydoorPipeline'], + 'image_depth_estimation_marigold_pipeline': [ + 'ImageDepthEstimationMarigoldPipeline' + ], 'self_supervised_depth_completion_pipeline': [ 'SelfSupervisedDepthCompletionPipeline' ], diff --git a/modelscope/pipelines/cv/image_depth_estimation_marigold_pipeline.py b/modelscope/pipelines/cv/image_depth_estimation_marigold_pipeline.py new file mode 100644 index 00000000..e5cdd7e7 --- /dev/null +++ b/modelscope/pipelines/cv/image_depth_estimation_marigold_pipeline.py @@ -0,0 +1,409 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +from typing import Any, Dict, Union + +import numpy as np +import torch +from diffusers import (AutoencoderKL, DDIMScheduler, DiffusionPipeline, + UNet2DConditionModel) +from PIL import Image +from torch.utils.data import DataLoader, TensorDataset +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_depth_estimation_marigold import ( + MarigoldDepthOutput, chw2hwc, colorize_depth_maps, ensemble_depths, + find_batch_size, inter_distances, resize_max_res) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_depth_estimation, + module_name=Pipelines.image_depth_estimation_marigold) +class ImageDepthEstimationMarigoldPipeline(Pipeline): + + def __init__(self, model=str, **kwargs): + r""" + use `model` to create a image depth estimation pipeline for prediction + Args: + >>> model: modelscope model_id "Damo_XR_Lab/cv_marigold_monocular-depth-estimation" + + Examples: + + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + >>> from modelscope.outputs import OutputKeys + >>> + >>> output_image_path = './result.png' + >>> img = './test.jpg' + >>> + >>> pipe = pipeline( + >>> Tasks.image_depth_estimation, + >>> model='Damo_XR_Lab/cv_marigold_monocular-depth-estimation') + >>> + >>> depth_vis = pipe(input)[OutputKeys.DEPTHS_COLOR] + >>> depth_vis.save(output_image_path) + >>> print('pipeline: the output image path is {}'.format(output_image_path)) + + """ + super().__init__(model=model, **kwargs) + + self._device = getattr( + kwargs, 'device', + torch.device('cuda' if torch.cuda.is_available() else 'cpu')) + self._dtype = torch.float16 + logger.info('load depth estimation marigold pipeline done') + + self.checkpoint_path = os.path.join(model, 'Marigold_v1_merged_2') + self.pipeline = _MarigoldPipeline.from_pretrained( + self.checkpoint_path, torch_dtype=self._dtype) + self.pipeline.to(self._device) + + def preprocess(self, input: Input) -> Dict[str, Any]: + # print('pipeline preprocess') + # TODO: input type: Image + return input + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + self.input_image = Image.open(input) + # print('load', input, self.input_image.size) + + results = self.pipeline(self.input_image) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + depths: np.ndarray = inputs.depth_np + depths_color: Image.Image = inputs.depth_colored + outputs = { + OutputKeys.DEPTHS: depths, + OutputKeys.DEPTHS_COLOR: depths_color + } + return outputs + + +class _MarigoldPipeline(DiffusionPipeline): + """ + Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io. + + This model inherits from [`DiffusionPipeline`]. + Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + unet (`UNet2DConditionModel`): + Conditional U-Net to denoise the depth latent, conditioned on image latent. + vae (`AutoencoderKL`): + Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps + to and from latent representations. + scheduler (`DDIMScheduler`): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + text_encoder (`CLIPTextModel`): + Text-encoder, for empty text embedding. + tokenizer (`CLIPTokenizer`): + CLIP tokenizer. + """ + rgb_latent_scale_factor = 0.18215 + depth_latent_scale_factor = 0.18215 + + def __init__( + self, + unet: UNet2DConditionModel, + vae: AutoencoderKL, + scheduler: DDIMScheduler, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + ): + super().__init__() + + self.register_modules( + unet=unet, + vae=vae, + scheduler=scheduler, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + self.empty_text_embed = None + + @torch.no_grad() + def __call__( + self, + input_image: Image, + denoising_steps: int = 10, + ensemble_size: int = 10, + processing_res: int = 768, + match_input_res: bool = True, + batch_size: int = 0, + color_map: str = 'Spectral', + show_progress_bar: bool = True, + ensemble_kwargs: Dict = None, + ) -> MarigoldDepthOutput: + r""" + Function invoked when calling the pipeline. + + Args: + input_image (`Image`): + Input RGB (or gray-scale) image. + processing_res (`int`, *optional*, defaults to `768`): + Maximum resolution of processing. + If set to 0: will not resize at all. + match_input_res (`bool`, *optional*, defaults to `True`): + Resize depth prediction to match input resolution. + Only valid if `limit_input_res` is not None. + denoising_steps (`int`, *optional*, defaults to `10`): + Number of diffusion denoising steps (DDIM) during inference. + ensemble_size (`int`, *optional*, defaults to `10`): + Number of predictions to be ensembled. + batch_size (`int`, *optional*, defaults to `0`): + Inference batch size, no bigger than `num_ensemble`. + If set to 0, the script will automatically decide the proper batch size. + show_progress_bar (`bool`, *optional*, defaults to `True`): + Display a progress bar of diffusion denoising. + color_map (`str`, *optional*, defaults to `"Spectral"`): + Colormap used to colorize the depth map. + ensemble_kwargs (`dict`, *optional*, defaults to `None`): + Arguments for detailed ensembling settings. + Returns: + `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including: + - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1] + - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] + and values in [0, 1] + - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) + coming from ensembling. None if `ensemble_size = 1` + """ + + device = self.device + input_size = input_image.size + + if not match_input_res: + assert (processing_res is not None + ), 'Value error: `resize_output_back` is only valid with ' + assert processing_res >= 0 + assert denoising_steps >= 1 + assert ensemble_size >= 1 + + # ----------------- Image Preprocess ----------------- + # Resize image + if processing_res > 0: + input_image = resize_max_res( + input_image, max_edge_resolution=processing_res) + # Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel + input_image = input_image.convert('RGB') + image = np.asarray(input_image) + + # Normalize rgb values + rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W] + rgb_norm = rgb / 255.0 + rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype) + rgb_norm = rgb_norm.to(device) + assert rgb_norm.min() >= 0.0 and rgb_norm.max() <= 1.0 + + # ----------------- Predicting depth ----------------- + # Batch repeated input image + duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) + single_rgb_dataset = TensorDataset(duplicated_rgb) + if batch_size > 0: + _bs = batch_size + else: + _bs = find_batch_size( + ensemble_size=ensemble_size, + input_res=max(rgb_norm.shape[1:]), + dtype=self.dtype, + ) + + single_rgb_loader = DataLoader( + single_rgb_dataset, batch_size=_bs, shuffle=False) + + # Predict depth maps (batched) + depth_pred_ls = [] + if show_progress_bar: + iterable = tqdm( + single_rgb_loader, + desc=' ' * 2 + 'Inference batches', + leave=False) + else: + iterable = single_rgb_loader + for batch in iterable: + (batched_img, ) = batch + depth_pred_raw = self.single_infer( + rgb_in=batched_img, + num_inference_steps=denoising_steps, + show_pbar=show_progress_bar, + ) + depth_pred_ls.append(depth_pred_raw.detach().clone()) + depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze() + torch.cuda.empty_cache() # clear vram cache for ensembling + + # ----------------- Test-time ensembling ----------------- + if ensemble_size > 1: + depth_pred, pred_uncert = ensemble_depths( + depth_preds, **(ensemble_kwargs or {})) + else: + depth_pred = depth_preds + pred_uncert = None + + # ----------------- Post processing ----------------- + # Scale prediction to [0, 1] + min_d = torch.min(depth_pred) + max_d = torch.max(depth_pred) + depth_pred = (depth_pred - min_d) / (max_d - min_d) + + # Convert to numpy + depth_pred = depth_pred.cpu().numpy().astype(np.float32) + + # Resize back to original resolution + if match_input_res: + pred_img = Image.fromarray(depth_pred) + pred_img = pred_img.resize(input_size) + depth_pred = np.asarray(pred_img) + + # Clip output range + depth_pred = depth_pred.clip(0, 1) + + # Colorize + depth_colored = colorize_depth_maps( + depth_pred, 0, 1, + cmap=color_map).squeeze() # [3, H, W], value in (0, 1) + depth_colored = (depth_colored * 255).astype(np.uint8) + depth_colored_hwc = chw2hwc(depth_colored) + depth_colored_img = Image.fromarray(depth_colored_hwc) + return MarigoldDepthOutput( + depth_np=depth_pred, + depth_colored=depth_colored_img, + uncertainty=pred_uncert, + ) + + def __encode_empty_text(self): + """ + Encode text embedding for empty prompt + """ + prompt = '' + text_inputs = self.tokenizer( + prompt, + padding='do_not_pad', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + ) + text_input_ids = text_inputs.input_ids.to(self.text_encoder.device) + self.empty_text_embed = self.text_encoder(text_input_ids)[0].to( + self.dtype) + + @torch.no_grad() + def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, + show_pbar: bool) -> torch.Tensor: + r""" + Perform an individual depth prediction without ensembling. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image. + num_inference_steps (`int`): + Number of diffusion denoisign steps (DDIM) during inference. + show_pbar (`bool`): + Display a progress bar of diffusion denoising. + Returns: + `torch.Tensor`: Predicted depth map. + """ + device = rgb_in.device + + # Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps # [T] + + # Encode image + rgb_latent = self.encode_rgb(rgb_in) + + # Initial depth map (noise) + depth_latent = torch.randn( + rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w] + + # Batched empty text embedding + if self.empty_text_embed is None: + self.__encode_empty_text() + batch_empty_text_embed = self.empty_text_embed.repeat( + (rgb_latent.shape[0], 1, 1)) # [B, 2, 1024] + + # Denoising loop + if show_pbar: + iterable = tqdm( + enumerate(timesteps), + total=len(timesteps), + leave=False, + desc=' ' * 4 + 'Diffusion denoising', + ) + else: + iterable = enumerate(timesteps) + + for i, t in iterable: + unet_input = torch.cat([rgb_latent, depth_latent], + dim=1) # this order is important + + # predict the noise residual + noise_pred = self.unet( + unet_input, t, encoder_hidden_states=batch_empty_text_embed + ).sample # [B, 4, h, w] + + # compute the previous noisy sample x_t -> x_t-1 + depth_latent = self.scheduler.step(noise_pred, t, + depth_latent).prev_sample + torch.cuda.empty_cache() + depth = self.decode_depth(depth_latent) + + # clip prediction + depth = torch.clip(depth, -1.0, 1.0) + # shift to [0, 1] + depth = (depth + 1.0) / 2.0 + + return depth + + def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor: + """ + Encode RGB image into latent. + + Args: + rgb_in (`torch.Tensor`): + Input RGB image to be encoded. + + Returns: + `torch.Tensor`: Image latent. + """ + # encode + h = self.vae.encoder(rgb_in) + moments = self.vae.quant_conv(h) + mean, logvar = torch.chunk(moments, 2, dim=1) + # scale latent + rgb_latent = mean * self.rgb_latent_scale_factor + return rgb_latent + + def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor: + """ + Decode depth latent into depth map. + + Args: + depth_latent (`torch.Tensor`): + Depth latent to be decoded. + + Returns: + `torch.Tensor`: Decoded depth map. + """ + # scale latent + depth_latent = depth_latent / self.depth_latent_scale_factor + # decode + z = self.vae.post_quant_conv(depth_latent) + stacked = self.vae.decoder(z) + # mean of output channels + depth_mean = stacked.mean(dim=1, keepdim=True) + return depth_mean + + def forward(self, x): + out = self.__call__(x) + return out diff --git a/tests/pipelines/test_image_depth_estimation_marigold.py b/tests/pipelines/test_image_depth_estimation_marigold.py new file mode 100644 index 00000000..ae33c138 --- /dev/null +++ b/tests/pipelines/test_image_depth_estimation_marigold.py @@ -0,0 +1,42 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.cv import ImageDepthEstimationMarigoldPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class ImageDepthEstimationMarigoldTest(unittest.TestCase): + + def setUp(self) -> None: + self.task = Tasks.image_depth_estimation + self.model_id = 'Damo_XR_Lab/cv_marigold_monocular-depth-estimation' + self.image = 'data/in-the-wild_example/example_0.jpg' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name(self): + marigold = pipeline(task=self.task, model=self.model_id) + input_path = os.path.join(marigold.model, self.image) + result = marigold(input=input_path) + depth_vis = result[OutputKeys.DEPTHS_COLOR] + depth_vis.save('result_modelname.jpg') + print('Test run with model name ok.') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + marigold_pipe = ImageDepthEstimationMarigoldPipeline(cache_path) + marigold_pipe.group_key = self.task + input_path = os.path.join(cache_path, self.image) + result = marigold_pipe(input=input_path) + depth_vis = result[OutputKeys.DEPTHS_COLOR] + depth_vis.save('result_snapshot.jpg') + print('Test run with snapshot ok.') + + +if __name__ == '__main__': + unittest.main()