mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)
Merge pull request #574 from modelscope/master-merge-internal20231007
Master merge internal20231007
@@ -1,2 +1,3 @@
recursive-include modelscope/configs *.py *.cu *.h *.cpp
recursive-include modelscope/cli/template *.tpl
recursive-include modelscope/utils *.json

Submodule data/test updated: b648024203...77a9ad7fb3
@@ -29,9 +29,10 @@ RUN pip install --no-cache-dir text2sql_lgesql==1.3.0 \
    detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html --force --no-deps

RUN pip install --no-cache-dir mpi4py paint_ldm \
    mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 pai-easycv \
    mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 pai-easycv ms_swift \
    ipykernel fasttext fairseq deepspeed -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html

ARG USE_GPU
# For CPU we install the CPU version of faiss; faiss depends on a BLAS lib, so we install libopenblas. TODO: rename the gpu/cpu faiss variants
RUN if [ "$USE_GPU" = "True" ] ; then \
    pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \

@@ -45,10 +46,14 @@ COPY examples /modelscope/examples
# for pai-easycv setup compatibility issue
ENV SETUPTOOLS_USE_DISTUTILS=stdlib

RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'
RUN if [ "$USE_GPU" = "True" ] ; then \
    CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'; \
else \
    echo 'detectron2 is not supported on cpu'; \
fi

# torchmetrics==0.11.4 for ofa
RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/v/ms_swift-1.1.0-py3-none-any.whl transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
RUN pip install --no-cache-dir jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
    bash /tmp/install_flash_attension.sh; \

@@ -4,36 +4,39 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .version import __release_datetime__, __version__
    from .trainers import EpochBasedTrainer, TrainingArgs, build_dataset_from_file
    from .trainers import Hook, Priority
    from .exporters import Exporter
    from .exporters import TfModelExporter
    from .exporters import TorchModelExporter
    from .exporters import Exporter, TfModelExporter, TorchModelExporter
    from .hub.api import HubApi
    from .hub.snapshot_download import snapshot_download
    from .hub.check_model import check_local_model_is_latest, check_model_is_id
    from .hub.push_to_hub import push_to_hub, push_to_hub_async
    from .hub.check_model import check_model_is_id, check_local_model_is_latest
    from .metrics import AudioNoiseMetric, Metric, task_default_metrics, ImageColorEnhanceMetric, ImageDenoiseMetric, \
        ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric, SequenceClassificationMetric, \
        TextGenerationMetric, TokenClassificationMetric, VideoSummarizationMetric, MovieSceneSegmentationMetric, \
        AccuracyMetric, BleuMetric, ImageInpaintingMetric, ReferringVideoObjectSegmentationMetric, \
        VideoFrameInterpolationMetric, VideoStabilizationMetric, VideoSuperResolutionMetric, PplMetric, \
        ImageQualityAssessmentDegradationMetric, ImageQualityAssessmentMosMetric, TextRankingMetric, \
        LossMetric, ImageColorizationMetric, OCRRecognitionMetric
    from .hub.snapshot_download import snapshot_download
    from .metrics import (
        AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
        ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
        ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric,
        ImageQualityAssessmentDegradationMetric,
        ImageQualityAssessmentMosMetric, LossMetric, Metric,
        MovieSceneSegmentationMetric, OCRRecognitionMetric, PplMetric,
        ReferringVideoObjectSegmentationMetric, SequenceClassificationMetric,
        TextGenerationMetric, TextRankingMetric, TokenClassificationMetric,
        VideoFrameInterpolationMetric, VideoStabilizationMetric,
        VideoSummarizationMetric, VideoSuperResolutionMetric,
        task_default_metrics)
    from .models import Model, TorchModel
    from .preprocessors import Preprocessor
    from .msdatasets import MsDataset
    from .pipelines import Pipeline, pipeline
    from .utils.hub import read_config, create_model_if_not_exist
    from .utils.logger import get_logger
    from .preprocessors import Preprocessor
    from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
                           build_dataset_from_file)
    from .utils.constant import Tasks
    from .utils.hf_util import AutoConfig, GenerationConfig, GPTQConfig, BitsAndBytesConfig
    from .utils.hf_util import AutoConfig, GPTQConfig, BitsAndBytesConfig
    from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
                                AutoModelForSeq2SeqLM,
                                AutoModelForSequenceClassification,
                                AutoModelForTokenClassification)
    from .utils.hf_util import AutoTokenizer
    from .msdatasets import MsDataset
                                AutoModelForTokenClassification, AutoTokenizer,
                                GenerationConfig)
    from .utils.hub import create_model_if_not_exist, read_config
    from .utils.logger import get_logger
    from .version import __release_datetime__, __version__

else:
    _import_structure = {
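
Both branches above feed LazyImportModule: under TYPE_CHECKING the symbols are imported eagerly so type checkers can see them, while at runtime only the _import_structure mapping is registered and each symbol is resolved on first attribute access. A minimal sketch of that pattern, as an illustration only (this stand-in is not modelscope's actual LazyImportModule implementation):

import importlib
import types


class _LazyModule(types.ModuleType):
    """Resolve public names from submodules on first access, then cache."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {'submodule': ['Name', ...]} into {'Name': 'submodule'}.
        self._attr_to_submodule = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._attr_to_submodule:
            raise AttributeError(attr)
        module = importlib.import_module(
            f'{self.__name__}.{self._attr_to_submodule[attr]}')
        value = getattr(module, attr)
        setattr(self, attr, value)  # later lookups skip __getattr__
        return value
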
@@ -7,13 +7,13 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .base import Exporter
    from .builder import build_exporter
    from .cv import CartoonTranslationExporter
    from .nlp import CsanmtForTranslationExporter
    from .tf_model_exporter import TfModelExporter
    from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
    from .torch_model_exporter import TorchModelExporter
    from .cv import FaceDetectionSCRFDExporter
    from .cv import CartoonTranslationExporter, FaceDetectionSCRFDExporter
    from .multi_modal import StableDiffuisonExporter
    from .nlp import (CsanmtForTranslationExporter,
                      SbertForSequenceClassificationExporter,
                      SbertForZeroShotClassificationExporter)
    from .tf_model_exporter import TfModelExporter
    from .torch_model_exporter import TorchModelExporter
else:
    _import_structure = {
        'base': ['Exporter'],

@@ -6,8 +6,9 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .cartoon_translation_exporter import CartoonTranslationExporter
    from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
    from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
    from .object_detection_damoyolo_exporter import \
        ObjectDetectionDamoyoloExporter
else:
    _import_structure = {
        'cartoon_translation_exporter': ['CartoonTranslationExporter'],

modelscope/exporters/cv/ocr_detection_db_exporter.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping

import numpy as np
import onnx
import torch

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks


@EXPORTERS.register_module(
    Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetectionDBExporter(TorchModelExporter):

    def export_onnx(self,
                    output_dir: str,
                    opset=11,
                    input_shape=(1, 3, 800, 800)):
        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
        dummy_input = torch.randn(*input_shape)
        self.model.onnx_export = True
        self.model.eval()
        _ = self.model(dummy_input)
        torch.onnx._export(
            self.model,
            dummy_input,
            onnx_file,
            input_names=[
                'images',
            ],
            output_names=[
                'pred',
            ],
            opset_version=opset)

        # Return a dict (the original `{'model', onnx_file}` was a set literal).
        return {'model': onnx_file}
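
The exporter registers under Tasks.ocr_detection, so it can be built from a loaded model. A hedged usage sketch, assuming the usual Exporter.from_model construction path and a hypothetical hub model id:

from modelscope.exporters import Exporter
from modelscope.models import Model

# Hypothetical model id; any DB-based ocr-detection model on the hub would do.
model = Model.from_pretrained('damo/cv_resnet18_ocr-detection-db-line-level_damo')
exporter = Exporter.from_model(model)
# With the dict fix above, this returns {'model': <path to the exported ONNX file>}.
print(exporter.export_onnx(output_dir='/tmp/ocr_det_onnx'))
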
modelscope/exporters/cv/ocr_recognition_exporter.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping

import numpy as np
import onnx
import torch

from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks


@EXPORTERS.register_module(
    Tasks.ocr_recognition, module_name=Models.ocr_recognition)
class OCRRecognitionExporter(TorchModelExporter):

    def export_onnx(self,
                    output_dir: str,
                    opset=11,
                    input_shape=(1, 3, 32, 640)):
        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
        dummy_input = torch.randn(*input_shape)
        self.model.onnx_export = True
        self.model.eval()
        _ = self.model(dummy_input)
        torch.onnx._export(
            self.model,
            dummy_input,
            onnx_file,
            input_names=[
                'images',
            ],
            output_names=[
                'pred',
            ],
            opset_version=opset)
        # Return a dict (the original `{'model', onnx_file}` was a set literal).
        return {'model': onnx_file}
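
OCRRecognitionExporter mirrors the DB exporter above; only the registered task/model keys and the dummy input size differ, (1, 3, 32, 640) being a 32-pixel-high text-line crop. The same usage sketch applies with an ocr-recognition model id.
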
@@ -6,7 +6,8 @@ from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
    from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
    from .model_for_token_classification_exporter import \
        ModelForSequenceClassificationExporter
    from .sbert_for_sequence_classification_exporter import \
        SbertForSequenceClassificationExporter
    from .sbert_for_zero_shot_classification_exporter import \

@@ -28,7 +28,8 @@ class CsanmtForTranslationExporter(TfModelExporter):
        tf.disable_eager_execution()
        super().__init__(model)

        from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
        from modelscope.pipelines.nlp.translation_pipeline import \
            TranslationPipeline
        self.pipeline = TranslationPipeline(self.model)

    def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:

@@ -77,7 +77,9 @@ class ModelForSequenceClassificationExporter(TorchModelExporter):
            return
        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model)
        ort_session = ort.InferenceSession(output)
        ort_session = ort.InferenceSession(
            output,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        with torch.no_grad():
            model.eval()
            outputs_origin = model.forward(

@@ -102,7 +102,9 @@ class TfModelExporter(Exporter):

        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model, full_check=True)
        ort_session = ort.InferenceSession(output)
        ort_session = ort.InferenceSession(
            output,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        outputs_origin = call_func(
            dummy_inputs) if call_func is not None else model(dummy_inputs)
        if isinstance(outputs_origin, (Mapping, ModelOutputBase)):

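
Both hunks make the same change: recent onnxruntime releases (1.9+) expect the providers list to be passed explicitly when a GPU build is installed, and the list is tried in order, so CUDA is preferred with CPU as the fallback. A minimal standalone sketch of the equivalent call (file path illustrative):

import onnxruntime as ort

session = ort.InferenceSession(
    'model.onnx',  # illustrative path
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
print(session.get_providers())  # shows which providers were actually applied
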
@@ -30,7 +30,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
                                      DEFAULT_CREDENTIALS_PATH,
                                      MODELSCOPE_CLOUD_ENVIRONMENT,
                                      MODELSCOPE_CLOUD_USERNAME,
                                      ONE_YEAR_SECONDS,
                                      MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
                                      REQUESTS_API_HTTP_METHOD, Licenses,
                                      ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
@@ -105,7 +105,9 @@ class HubApi:
        """
        path = f'{self.endpoint}/api/v1/login'
        r = self.session.post(
            path, json={'AccessToken': access_token}, headers=self.headers)
            path,
            json={'AccessToken': access_token},
            headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        d = r.json()
        raise_on_error(d)
@@ -166,7 +168,10 @@ class HubApi:
            'TrainId': os.environ.get('MODELSCOPE_TRAIN_ID', ''),
        }
        r = self.session.post(
            path, json=body, cookies=cookies, headers=self.headers)
            path,
            json=body,
            cookies=cookies,
            headers=self.builder_headers(self.headers))
        handle_http_post_error(r, path, body)
        raise_on_error(r.json())
        model_repo_url = f'{get_endpoint()}/{model_id}'
@@ -189,7 +194,9 @@ class HubApi:
            raise ValueError('Token does not exist, please login first.')
        path = f'{self.endpoint}/api/v1/models/{model_id}'

        r = self.session.delete(path, cookies=cookies, headers=self.headers)
        r = self.session.delete(path,
                                cookies=cookies,
                                headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        raise_on_error(r.json())

@@ -223,7 +230,8 @@ class HubApi:
        else:
            path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'

        r = self.session.get(path, cookies=cookies, headers=self.headers)
        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
@@ -390,7 +398,7 @@ class HubApi:
            data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
            (owner_or_group, page_number, page_size),
            cookies=cookies,
            headers=self.headers)
            headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, 'list_model')
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
@@ -435,7 +443,8 @@ class HubApi:
        if cutoff_timestamp is None:
            cutoff_timestamp = get_release_datetime()
        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
        r = self.session.get(path, cookies=cookies, headers=self.headers)
        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        d = r.json()
        raise_on_error(d)
@@ -472,13 +481,15 @@ class HubApi:
                cutoff_timestamp=release_timestamp,
                use_cookies=False if cookies is None else cookies)
            if len(revisions) == 0:
                raise NoValidRevisionError(
                    'The model: %s has no valid revision!' % model_id)
                # tags (revisions) returned from backend are guaranteed to be ordered by create-time
                # we shall obtain the latest revision created earlier than release version of this branch
                revision = revisions[0]
                logger.warning(('There is no version specified and there is no version in the model repository, '
                                'use the master branch, which is fragile, please use it with caution!'))
                revision = MASTER_MODEL_BRANCH
            else:
                # tags (revisions) returned from backend are guaranteed to be ordered by create-time
                # we shall obtain the latest revision created earlier than release version of this branch
                revision = revisions[0]
            logger.info(
                'Model revision not specified, use the latest revision: %s'
                'Model revision not specified, use revision: %s'
                % revision)
        else:
            # use user-specified revision
@@ -487,8 +498,11 @@ class HubApi:
                cutoff_timestamp=current_timestamp,
                use_cookies=False if cookies is None else cookies)
            if revision not in revisions:
                raise NotExistError('The model: %s has no revision: %s !' %
                                    (model_id, revision))
            if revision == MASTER_MODEL_BRANCH:
                logger.warning('Using the master branch is fragile, please use it with caution!')
            else:
                raise NotExistError('The model: %s has no revision: %s !' %
                                    (model_id, revision))
            logger.info('Use user-specified model revision: %s' % revision)
        return revision

@@ -510,7 +524,8 @@ class HubApi:
        cookies = self._check_cookie(use_cookies)

        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
        r = self.session.get(path, cookies=cookies, headers=self.headers)
        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        d = r.json()
        raise_on_error(d)
@@ -552,6 +567,7 @@ class HubApi:
        if root is not None:
            path = path + f'&Root={root}'
        headers = self.headers if headers is None else headers
        headers['X-Request-ID'] = str(uuid.uuid4().hex)
        r = self.session.get(
            path, cookies=cookies, headers=headers)

@@ -570,7 +586,8 @@ class HubApi:
    def list_datasets(self):
        path = f'{self.endpoint}/api/v1/datasets'
        params = {}
        r = self.session.get(path, params=params, headers=self.headers)
        r = self.session.get(path, params=params,
                             headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
        return [x['Name'] for x in dataset_list]
@@ -590,7 +607,9 @@ class HubApi:
        """ Get the meta file-list of the dataset. """
        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
        cookies = ModelScopeConfig.get_cookies()
        r = self.session.get(datahub_url, cookies=cookies, headers=self.headers)
        r = self.session.get(datahub_url,
                             cookies=cookies,
                             headers=self.builder_headers(self.headers))
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp)
        file_list = resp['Data']
@@ -736,7 +755,9 @@ class HubApi:
        cookies = ModelScopeConfig.get_cookies()

        r = self.session.get(
            url=datahub_url, cookies=cookies, headers=self.headers)
            url=datahub_url,
            cookies=cookies,
            headers=self.builder_headers(self.headers))
        resp = r.json()
        raise_on_error(resp)
        return resp['Data']
@@ -759,7 +780,11 @@ class HubApi:
        data = dict(
            data=dataset_info,
        )
        r = self.session.post(url=virgo_dataset_url, json=data, cookies=cookies, headers=self.headers, timeout=900)
        r = self.session.post(url=virgo_dataset_url,
                              json=data,
                              cookies=cookies,
                              headers=self.builder_headers(self.headers),
                              timeout=900)
        resp = r.json()
        if resp['code'] != 0:
            raise RuntimeError(f'Failed to get virgo dataset: {resp}')
@@ -773,7 +798,8 @@ class HubApi:
                             zip_file_name: str):
        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
        cookies = ModelScopeConfig.get_cookies()
        r = self.session.get(url=datahub_url, cookies=cookies, headers=self.headers)
        r = self.session.get(url=datahub_url, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        resp = r.json()
        # get visibility of the dataset
        raise_on_error(resp)
@@ -781,7 +807,8 @@ class HubApi:
        visibility = DatasetVisibilityMap.get(data['Visibility'])

        datahub_sts_url = f'{datahub_url}/ststoken?Revision={revision}'
        r_sts = self.session.get(url=datahub_sts_url, cookies=cookies, headers=self.headers)
        r_sts = self.session.get(url=datahub_sts_url, cookies=cookies,
                                 headers=self.builder_headers(self.headers))
        resp_sts = r_sts.json()
        raise_on_error(resp_sts)
        data_sts = resp_sts['Data']
@@ -848,7 +875,8 @@ class HubApi:

        # Download count
        download_count_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
        download_count_resp = self.session.post(download_count_url, cookies=cookies, headers=self.headers)
        download_count_resp = self.session.post(download_count_url, cookies=cookies,
                                                headers=self.builder_headers(self.headers))
        raise_for_http_status(download_count_resp)

        # Download uv
@@ -860,13 +888,18 @@ class HubApi:
                user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
            download_uv_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
                              f'{channel}?user={user_name}'
            download_uv_resp = self.session.post(download_uv_url, cookies=cookies, headers=self.headers)
            download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
                                                 headers=self.builder_headers(self.headers))
            download_uv_resp = download_uv_resp.json()
            raise_on_error(download_uv_resp)

        except Exception as e:
            logger.error(e)

    def builder_headers(self, headers):
        return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
                **headers}
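
With builder_headers in place, every hub request carries a fresh X-Request-ID (the MODELSCOPE_REQUEST_ID constant) that the error helpers in errors.py echo back. A quick sketch of the resulting header dict (uuid value illustrative):

import uuid

MODELSCOPE_REQUEST_ID = 'X-Request-ID'

def builder_headers(headers):
    return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex), **headers}

print(builder_headers({'user-agent': 'modelscope'}))
# {'X-Request-ID': '9bf7...e21', 'user-agent': 'modelscope'}

Note the **headers spread comes last, so a caller-supplied X-Request-ID would override the generated one.
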

class ModelScopeConfig:
    path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)

@@ -3,7 +3,7 @@
import os
from pathlib import Path

MODELSCOPE_URL_SCHEME = 'http://'
MODELSCOPE_URL_SCHEME = 'https://'
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB = int(

@@ -31,6 +31,7 @@ MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODEL_META_FILE_NAME = '.mdl'
MODEL_META_MODEL_ID = 'id'
MODELSCOPE_REQUEST_ID = 'X-Request-ID'


class Licenses(object):
@@ -5,6 +5,7 @@ from http import HTTPStatus
import requests
from requests.exceptions import HTTPError

from modelscope.hub.constants import MODELSCOPE_REQUEST_ID
from modelscope.utils.logger import get_logger

logger = get_logger()

@@ -46,6 +47,13 @@ class FileDownloadError(Exception):
    pass


def get_request_id(response: requests.Response):
    if MODELSCOPE_REQUEST_ID in response.request.headers:
        return response.request.headers[MODELSCOPE_REQUEST_ID]
    else:
        return ''


def is_ok(rsp):
    """ Check the request is ok
@@ -71,12 +79,14 @@ def handle_http_post_error(response, url, request_body):
        response.raise_for_status()
    except HTTPError as error:
        message = _decode_response_error(response)
        raise HTTPError('Request %s with body: %s exception, '
                        'Response details: %s' %
                        (url, request_body, message)) from error
        raise HTTPError(
            'Request %s with body: %s exception, '
            'Response details: %s, request id: %s' %
            (url, request_body, message, get_request_id(response))) from error


def handle_http_response(response, logger, cookies, model_id):
def handle_http_response(response: requests.Response, logger, cookies,
                         model_id):
    try:
        response.raise_for_status()
    except HTTPError as error:
@@ -85,7 +95,8 @@ def handle_http_response(response, logger, cookies, model_id):
            f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
            private. Please login first.')
        message = _decode_response_error(response)
        raise HTTPError('Response details: %s' % message) from error
        raise HTTPError('Response details: %s, Request id: %s' %
                        (message, get_request_id(response))) from error


def raise_on_error(rsp):
@@ -122,9 +133,10 @@ def datahub_raise_on_error(url, rsp):
    if rsp.get('Code') == HTTPStatus.OK:
        return True
    else:
        request_id = get_request_id(rsp)
        raise RequestError(
            f"Url = {url}, Message = {rsp.get('Message')}, Please specify correct dataset_name and namespace."
        )
            f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
            Please specify correct dataset_name and namespace.")


def raise_for_http_status(rsp):
@@ -146,14 +158,14 @@ def raise_for_http_status(rsp):
        reason = rsp.reason.decode('iso-8859-1')
    else:
        reason = rsp.reason

    request_id = get_request_id(rsp)
    if 400 <= rsp.status_code < 500:
        http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code,
                                                               reason, rsp.url)
        http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
            rsp.status_code, reason, request_id, rsp.url)

    elif 500 <= rsp.status_code < 600:
        http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code,
                                                               reason, rsp.url)
        http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
            rsp.status_code, reason, request_id, rsp.url)

    if http_error_msg:
        req = rsp.request

@@ -4,6 +4,7 @@ import copy
import os
import tempfile
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
@@ -192,6 +193,7 @@ def download_part_with_retry(params):
    progress, start, end, url, file_name, cookies, headers = params
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['Range'] = 'bytes=%s-%s' % (start, end)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    retry = Retry(
        total=API_FILE_DOWNLOAD_RETRY_TIMES,
        backoff_factor=1,
@@ -289,6 +291,7 @@ def http_get_file(
    temp_file_manager = partial(
        tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
    with temp_file_manager() as temp_file:
        logger.debug('downloading %s to %s', url, temp_file.name)
        # retry sleep 0.5s, 1s, 2s, 4s
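
The Retry objects in this file use backoff_factor=1; per the comment above, the sleeps grow as 0.5s, 1s, 2s, 4s, i.e. roughly the factor scaled by powers of two. A quick check of that schedule:

# Exponential backoff schedule implied by the '0.5s, 1s, 2s, 4s' comment above.
for attempt in range(1, 5):
    print(0.5 * 2 ** (attempt - 1))  # 0.5, 1.0, 2.0, 4.0
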
@@ -82,6 +82,7 @@ class Models(object):
    image_skychange = 'image-skychange'
    video_human_matting = 'video-human-matting'
    human_reconstruction = 'human-reconstruction'
    text_texture_generation = 'text-texture-generation'
    video_frame_interpolation = 'video-frame-interpolation'
    video_object_segmentation = 'video-object-segmentation'
    video_deinterlace = 'video-deinterlace'
@@ -124,6 +125,8 @@ class Models(object):
    pedestrian_attribute_recognition = 'pedestrian-attribute-recognition'
    image_try_on = 'image-try-on'
    human_image_generation = 'human-image-generation'
    image_view_transform = 'image-view-transform'
    image_control_3d_portrait = 'image-control-3d-portrait'

    # nlp models
    bert = 'bert'
@@ -293,6 +296,7 @@ class Pipelines(object):
    table_recognition = 'dla34-table-recognition'
    lineless_table_recognition = 'lore-lineless-table-recognition'
    license_plate_detection = 'resnet18-license-plate-detection'
    card_detection_correction = 'resnet18-card-detection-correction'
    action_recognition = 'TAdaConv_action-recognition'
    animal_recognition = 'resnet101-animal-recognition'
    general_recognition = 'resnet101-general-recognition'
@@ -365,6 +369,8 @@ class Pipelines(object):
    hand_detection = 'yolox-pai_hand-detection'
    skin_retouching = 'unet-skin-retouching'
    face_reconstruction = 'resnet50-face-reconstruction'
    head_reconstruction = 'HRN-head-reconstruction'
    text_to_head = 'HRN-text-to-head'
    tinynas_classification = 'tinynas-classification'
    easyrobust_classification = 'easyrobust-classification'
    tinynas_detection = 'tinynas-detection'
@@ -401,6 +407,7 @@ class Pipelines(object):
    image_skychange = 'image-skychange'
    video_human_matting = 'video-human-matting'
    human_reconstruction = 'human-reconstruction'
    text_texture_generation = 'text-texture-generation'
    vision_middleware_multi_task = 'vision-middleware-multi-task'
    vidt = 'vidt'
    video_frame_interpolation = 'video-frame-interpolation'
@@ -442,6 +449,10 @@ class Pipelines(object):
    text_to_360panorama_image = 'text-to-360panorama-image'
    image_try_on = 'image-try-on'
    human_image_generation = 'human-image-generation'
    human3d_render = 'human3d-render'
    human3d_animation = 'human3d-animation'
    image_view_transform = 'image-view-transform'
    image_control_3d_portrait = 'image-control-3d-portrait'

    # nlp tasks
    automatic_post_editing = 'automatic-post-editing'
@@ -677,6 +688,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.license_plate_detection:
    (Pipelines.license_plate_detection,
     'damo/cv_resnet18_license-plate-detection_damo'),
    Tasks.card_detection_correction: (Pipelines.card_detection_correction,
                                      'damo/cv_resnet18_card_correction'),
    Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
    Tasks.feature_extraction: (Pipelines.feature_extraction,
                               'damo/pert_feature-extraction_base-test'),
@@ -830,6 +843,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     'damo/cv_effnetv2_video-human-matting'),
    Tasks.human_reconstruction: (Pipelines.human_reconstruction,
                                 'damo/cv_hrnet_image-human-reconstruction'),
    Tasks.text_texture_generation: (
        Pipelines.text_texture_generation,
        'damo/cv_diffuser_text-texture-generation'),
    Tasks.video_frame_interpolation: (
        Pipelines.video_frame_interpolation,
        'damo/cv_raft_video-frame-interpolation'),
@@ -908,7 +924,16 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.image_try_on: (Pipelines.image_try_on,
                         'damo/cv_SAL-VTON_virtual-try-on'),
    Tasks.human_image_generation: (Pipelines.human_image_generation,
                                   'damo/cv_FreqHPT_human-image-generation')
                                   'damo/cv_FreqHPT_human-image-generation'),
    Tasks.human3d_render: (Pipelines.human3d_render,
                           'damo/cv_3d-human-synthesis-library'),
    Tasks.human3d_animation: (Pipelines.human3d_animation,
                              'damo/cv_3d-human-animation'),
    Tasks.image_view_transform: (Pipelines.image_view_transform,
                                 'damo/cv_image-view-transform'),
    Tasks.image_control_3d_portrait: (
        Pipelines.image_control_3d_portrait,
        'damo/cv_vit_image-control-3d-portrait-synthesis')
}

@@ -4,34 +4,39 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .accuracy_metric import AccuracyMetric
    from .audio_noise_metric import AudioNoiseMetric
    from .base import Metric
    from .bleu_metric import BleuMetric
    from .builder import METRICS, build_metric, task_default_metrics
    from .image_color_enhance_metric import ImageColorEnhanceMetric
    from .image_colorization_metric import ImageColorizationMetric
    from .image_denoise_metric import ImageDenoiseMetric
    from .image_inpainting_metric import ImageInpaintingMetric
    from .image_instance_segmentation_metric import \
        ImageInstanceSegmentationCOCOMetric
    from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric
    from .image_portrait_enhancement_metric import \
        ImagePortraitEnhancementMetric
    from .image_quality_assessment_degradation_metric import \
        ImageQualityAssessmentDegradationMetric
    from .image_quality_assessment_mos_metric import \
        ImageQualityAssessmentMosMetric
    from .loss_metric import LossMetric
    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
    from .ocr_recognition_metric import OCRRecognitionMetric
    from .ppl_metric import PplMetric
    from .referring_video_object_segmentation_metric import \
        ReferringVideoObjectSegmentationMetric
    from .sequence_classification_metric import SequenceClassificationMetric
    from .text_generation_metric import TextGenerationMetric
    from .text_ranking_metric import TextRankingMetric
    from .token_classification_metric import TokenClassificationMetric
    from .video_summarization_metric import VideoSummarizationMetric
    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
    from .accuracy_metric import AccuracyMetric
    from .bleu_metric import BleuMetric
    from .image_inpainting_metric import ImageInpaintingMetric
    from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
    from .translation_evaluation_metric import TranslationEvaluationMetric
    from .video_frame_interpolation_metric import VideoFrameInterpolationMetric
    from .video_stabilization_metric import VideoStabilizationMetric
    from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
    from .ppl_metric import PplMetric
    from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric
    from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
    from .text_ranking_metric import TextRankingMetric
    from .loss_metric import LossMetric
    from .image_colorization_metric import ImageColorizationMetric
    from .ocr_recognition_metric import OCRRecognitionMetric
    from .translation_evaluation_metric import TranslationEvaluationMetric
    from .video_summarization_metric import VideoSummarizationMetric
    from .video_super_resolution_metric.video_super_resolution_metric import \
        VideoSuperResolutionMetric
else:
    _import_structure = {
        'audio_noise_metric': ['AudioNoiseMetric'],

@@ -143,10 +143,15 @@ class Model(ABC):
        task_name = getattr(cfg, 'task', None)
        if 'task' in kwargs:
            task_name = kwargs.pop('task')
        model_cfg = getattr(cfg, 'model', None)
        if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
            model_cfg.type = model_cfg.model_type
        model_type = getattr(model_cfg, 'type', None)
        try:
            model_cfg = cfg.model
            if hasattr(model_cfg,
                       'model_type') and not hasattr(model_cfg, 'type'):
                model_cfg.type = model_cfg.model_type
            model_type = model_cfg.type
        except Exception:
            model_cfg = {}
            model_type = ''
        if isinstance(device, str) and device.startswith('gpu'):
            device = 'cuda' + device[3:]
        use_hf = kwargs.pop('use_hf', None)

@@ -5,10 +5,10 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
               body_2d_keypoints, body_3d_keypoints, cartoon,
               cmdssl_video_embedding, controllable_image_generation,
               crowd_counting, face_detection, face_generation,
               face_reconstruction, human_reconstruction, image_classification,
               image_color_enhance, image_colorization, image_defrcn_fewshot,
               image_denoise, image_editing, image_inpainting,
               image_instance_segmentation, image_matching,
               face_reconstruction, human3d_animation, human_reconstruction,
               image_classification, image_color_enhance, image_colorization,
               image_defrcn_fewshot, image_denoise, image_editing,
               image_inpainting, image_instance_segmentation, image_matching,
               image_mvs_depth_estimation, image_panoptic_segmentation,
               image_portrait_enhancement, image_probing_model,
               image_quality_assessment_degradation,

@@ -23,7 +23,8 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
               referring_video_object_segmentation,
               robust_image_classification, salient_detection,
               shop_segmentation, stream_yolo, super_resolution,
               surface_recon_common, table_recognition, video_deinterlace,
               surface_recon_common, table_recognition,
               text_texture_generation, video_deinterlace,
               video_frame_interpolation, video_object_segmentation,
               video_panoptic_segmentation, video_single_object_tracking,
               video_stabilization, video_summarization,

modelscope/models/cv/head_reconstruction/models/bfm.py (new file, 673 lines)
@@ -0,0 +1,673 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch

import os

import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import loadmat

from modelscope.models.cv.face_reconstruction.utils import read_obj


def perspective_projection(focal, center):
    # return p.T (N, 3) @ (3, 3)
    return np.array([focal, 0, center, 0, focal, center, 0, 0,
                     1]).reshape([3, 3]).astype(np.float32).transpose()
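# (Added note, not in the original file) Written out, this is the transposed
# pinhole intrinsic matrix, so row-vector points project as p @ K.T:
#     K = [[focal, 0,     center],
#          [0,     focal, center],
#          [0,     0,     1     ]]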


class SH:

    def __init__(self):
        self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
        self.c = [
            1 / np.sqrt(4 * np.pi),
            np.sqrt(3.) / np.sqrt(4 * np.pi),
            3 * np.sqrt(5.) / np.sqrt(12 * np.pi)
        ]

class ParametricFaceModel:

    def __init__(self,
                 assets_root='assets',
                 recenter=True,
                 camera_distance=10.,
                 init_lit=np.array([0.8, 0, 0, 0, 0, 0, 0, 0, 0]),
                 focal=1015.,
                 center=112.,
                 is_train=True,
                 default_name='BFM_model_front.mat'):

        model = loadmat(os.path.join(assets_root, '3dmm/BFM', default_name))
        model_bfm_front = loadmat(
            os.path.join(assets_root, '3dmm/BFM/BFM_model_front.mat'))
        self.mean_shape_ori = model_bfm_front['meanshape'].astype(np.float32)
        # mean face shape. [3*N,1]
        self.mean_shape = model['meanshape'].astype(np.float32)  # (1, 107127)

        # identity basis. [3*N,80]
        self.id_base = model['idBase'].astype(np.float32)  # (107127, 80)

        # expression basis. [3*N,64]
        self.exp_base = model['exBase'].astype(np.float32)  # (107127, 64)

        # mean face texture. [3*N,1] (0-255)
        self.mean_tex = model['meantex'].astype(np.float32)  # (1, 107127)

        # texture basis. [3*N,80]
        self.tex_base = model['texBase'].astype(np.float32)  # (107127, 80)

        self.bfm_keep_inds = np.load(
            os.path.join(assets_root, '3dmm/inds/bfm_keep_inds.npy'))

        self.ours_hair_area_inds = np.load(
            os.path.join(assets_root, '3dmm/inds/ours_hair_area_inds.npy'))

        if default_name == 'ourRefineFull_model.mat':
            self.mean_tex = self.mean_tex.reshape(1, -1, 3)
            mean_tex_keep = self.mean_tex[:, self.bfm_keep_inds]
            self.mean_tex[:, :len(self.bfm_keep_inds)] = mean_tex_keep
            self.mean_tex[:,
                          len(self.bfm_keep_inds):] = np.array([200, 146,
                                                                118])[None,
                                                                      None]
            self.mean_tex[:, self.ours_hair_area_inds] = 40.0
            self.mean_tex = self.mean_tex.reshape(1, -1)
            self.mean_tex = np.ascontiguousarray(self.mean_tex)

            self.tex_base = self.tex_base.reshape(-1, 3, 80)
            tex_base_keep = self.tex_base[self.bfm_keep_inds]
            self.tex_base[:len(self.bfm_keep_inds)] = tex_base_keep
            self.tex_base[len(self.bfm_keep_inds):] = 0.0
            self.tex_base = self.tex_base.reshape(-1, 80)
            self.tex_base = np.ascontiguousarray(self.tex_base)

        # face indices for each vertex that lies in. starts from 0. [N,8]
        self.point_buf = model['point_buf'].astype(np.int64) - 1  # (35709, 8)

        # vertex indices for each face. starts from 0. [F,3]
        self.face_buf = model['tri'].astype(np.int64) - 1  # (70789, 3)

        # vertex indices for 68 landmarks. starts from 0. [68,1]
        self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1

        if default_name == 'ourRefineFull_model.mat':
            self.keypoints = np.load(
                os.path.join(
                    assets_root,
                    '3dmm/inds/our_refine0223_basis_withoutEyes_withUV_keypoints_inds.npy'
                )).astype(np.int64)
            self.point_buf = self.point_buf[:, :8] + 1

        if is_train:
            # vertex indices for small face region to compute photometric error. starts from 0.
            self.front_mask = np.squeeze(model['frontmask2_idx']).astype(
                np.int64) - 1
            # vertex indices for each face from small face region. starts from 0. [f,3]
            self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
            # vertex indices for pre-defined skin region to compute reflectance loss
            self.skin_mask = np.squeeze(model['skinmask'])

        if default_name == 'ourRefineFull_model.mat':
            nose_reduced_mesh = read_obj(
                os.path.join(assets_root,
                             '3dmm/adjust_part/our_full/145_nose.obj'))
            self.nose_reduced_part = nose_reduced_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            neck_mesh = read_obj(
                os.path.join(assets_root,
                             '3dmm/adjust_part/our_full/154_neck.obj'))
            self.neck_adjust_part = neck_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            eyes_mesh = read_obj(
                os.path.join(
                    assets_root,
                    '3dmm/adjust_part/our_full/our_mean_adjust_eyes.obj'))
            self.eyes_adjust_part = eyes_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            self.neck_slim_part = None
            self.neck_stretch_part = None
        elif default_name == 'ourRefineBFMEye0504_model.mat':
            nose_reduced_mesh = read_obj(
                os.path.join(assets_root,
                             '3dmm/adjust_part/our_full_bfmEyes/145_nose.obj'))
            self.nose_reduced_part = nose_reduced_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            neck_mesh = read_obj(
                os.path.join(assets_root,
                             '3dmm/adjust_part/our_full_bfmEyes/146_neck.obj'))
            self.neck_adjust_part = neck_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            self.eyes_adjust_part = None

            neck_slim_mesh = read_obj(
                os.path.join(
                    assets_root,
                    '3dmm/adjust_part/our_full_bfmEyes/147_neckSlim2.obj'))
            self.neck_slim_part = neck_slim_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape

            neck_stretch_mesh = read_obj(
                os.path.join(
                    assets_root,
                    '3dmm/adjust_part/our_full_bfmEyes/148_neckLength.obj'))
            self.neck_stretch_part = neck_stretch_mesh['vertices'].reshape(
                (1, -1)) - self.mean_shape
        else:
            self.nose_reduced_part = None

            self.neck_adjust_part = None
            self.eyes_adjust_part = None
            self.neck_slim_part = None
            self.neck_stretch_part = None

        if recenter:
            mean_shape = self.mean_shape.reshape([-1, 3])
            mean_shape_ori = self.mean_shape_ori.reshape([-1, 3])
            mean_shape = mean_shape - np.mean(
                mean_shape_ori[:35709, ...], axis=0, keepdims=True)
            self.mean_shape = mean_shape.reshape([-1, 1])

        eye_corner_inds = np.load(
            os.path.join(assets_root, '3dmm/inds/eye_corner_inds.npy'))
        self.eye_corner_inds = torch.from_numpy(eye_corner_inds).long()
        eye_lines = np.load(
            os.path.join(assets_root, '3dmm/inds/eye_corner_lines.npy'))
        self.eye_lines = torch.from_numpy(eye_lines).long()

        self.center = center
        self.persc_proj = perspective_projection(focal, self.center)
        self.camera_distance = camera_distance
        self.SH = SH()
        self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)

    def to(self, device):
        self.device = device
        for key, value in self.__dict__.items():
            if type(value).__module__ == np.__name__:
                setattr(self, key, torch.tensor(value).to(device))

    def compute_shape(self,
                      id_coeff,
                      exp_coeff,
                      nose_coeff=0.0,
                      neck_coeff=0.0,
                      eyes_coeff=0.0,
                      neckSlim_coeff=0.0,
                      neckStretch_coeff=0.0):
        """
        Return:
            face_shape -- torch.tensor, size (B, N, 3)

        Parameters:
            id_coeff -- torch.tensor, size (B, 80), identity coeffs
            exp_coeff -- torch.tensor, size (B, 64), expression coeffs
        """
        batch_size = id_coeff.shape[0]
        id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
        exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
        face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])

        if nose_coeff != 0:
            face_shape = face_shape + nose_coeff * self.nose_reduced_part
        if neck_coeff != 0:
            face_shape = face_shape + neck_coeff * self.neck_adjust_part
        if eyes_coeff != 0 and self.eyes_adjust_part is not None:
            face_shape = face_shape + eyes_coeff * self.eyes_adjust_part
        if neckSlim_coeff != 0 and self.neck_slim_part is not None:
            face_shape = face_shape + neckSlim_coeff * self.neck_slim_part
        if neckStretch_coeff != 0 and self.neck_stretch_part is not None:

            neck_stretch_part = self.neck_stretch_part.reshape(1, -1, 3)
            neck_stretch_part_top = neck_stretch_part[0, 37476, 1]
            neck_stretch_part_bottom = neck_stretch_part[0, 37357, 1]
            neck_stretch_height = neck_stretch_part_top - neck_stretch_part_bottom

            face_shape_ = face_shape.reshape(1, -1, 3)
            face_shape_top = face_shape_[0, 37476, 1]
            face_shape_bottom = face_shape_[0, 37357, 1]
            face_shape_height = face_shape_top - face_shape_bottom

            target_neck_height = 0.72  # top ind 37476, bottom ind 37357
            neckStretch_coeff = (target_neck_height
                                 - face_shape_height) / neck_stretch_height

            face_shape = face_shape + neckStretch_coeff * self.neck_stretch_part

        return face_shape.reshape([batch_size, -1, 3])

    def compute_texture(self, tex_coeff, normalize=True):
        """
        Return:
            face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)

        Parameters:
            tex_coeff -- torch.tensor, size (B, 80)
        """
        batch_size = tex_coeff.shape[0]
        face_texture = torch.einsum('ij,aj->ai', self.tex_base,
                                    tex_coeff) + self.mean_tex
        if normalize:
            face_texture = face_texture / 255.
        return face_texture.reshape([batch_size, -1, 3])

    def compute_norm(self, face_shape):
        """
        Return:
            vertex_norm -- torch.tensor, size (B, N, 3)

        Parameters:
            face_shape -- torch.tensor, size (B, N, 3)
        """

        v1 = face_shape[:, self.face_buf[:, 0]]
        v2 = face_shape[:, self.face_buf[:, 1]]
        v3 = face_shape[:, self.face_buf[:, 2]]
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = torch.cross(e1, e2, dim=-1)
        face_norm = F.normalize(face_norm, dim=-1, p=2)
        face_norm = torch.cat(
            [face_norm,
             torch.zeros(face_norm.shape[0], 1, 3).to(self.device)],
            dim=1)

        vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
        vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
        return vertex_norm
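    # (Added note, not in the original file) point_buf lists up to 8 adjacent
    # triangles per vertex; summing their unit face normals and re-normalizing
    # yields the vertex normal. The zero row concatenated above serves as
    # padding for vertices with fewer than 8 neighboring faces, contributing
    # nothing to the sum.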

    def compute_color(self, face_texture, face_norm, gamma):
        """
        Return:
            face_color -- torch.tensor, size (B, N, 3), range (0, 1.)

        Parameters:
            face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
            face_norm -- torch.tensor, size (B, N, 3), rotated face normal
            gamma -- torch.tensor, size (B, 27), SH coeffs
        """
        batch_size = gamma.shape[0]
        a, c = self.SH.a, self.SH.c
        gamma = gamma.reshape([batch_size, 3, 9])
        gamma = gamma + self.init_lit
        gamma = gamma.permute(0, 2, 1)

        y1 = a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device)
        y2 = -a[1] * c[1] * face_norm[..., 1:2]
        y3 = a[1] * c[1] * face_norm[..., 2:]
        y4 = -a[1] * c[1] * face_norm[..., :1]
        y5 = a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2]
        y6 = -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:]
        y7 = 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:]**2 - 1)
        y8 = -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:]
        y9 = 0.5 * a[2] * c[2] * (
            face_norm[..., :1]**2 - face_norm[..., 1:2]**2)
        Y = torch.cat([y1, y2, y3, y4, y5, y6, y7, y8, y9], dim=-1)
        r = Y @ gamma[..., :1]
        g = Y @ gamma[..., 1:2]
        b = Y @ gamma[..., 2:]
        face_color = torch.cat([r, g, b], dim=-1) * face_texture
        return face_color
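    # (Added note, not in the original file) Worked example: gamma packs 9 SH
    # coefficients per color channel (3 * 9 = 27). For a normal facing the
    # camera, face_norm = (0, 0, 1), the basis above reduces to
    #     Y = [a0*c0, 0, a1*c1, 0, 0, 0, a2*c2/sqrt(3), 0, 0]
    # so only the constant band and the z-dependent bands contribute to that
    # vertex's shading.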

    def compute_rotation(self, angles):
        """
        Return:
            rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat

        Parameters:
            angles -- torch.tensor, size (B, 3), radian
        """

        batch_size = angles.shape[0]
        ones = torch.ones([batch_size, 1]).to(self.device)
        zeros = torch.zeros([batch_size, 1]).to(self.device)
        x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],

        value_list = [
            ones, zeros, zeros, zeros,
            torch.cos(x), -torch.sin(x), zeros,
            torch.sin(x),
            torch.cos(x)
        ]
        rot_x = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])

        value_list = [
            torch.cos(y), zeros,
            torch.sin(y), zeros, ones, zeros, -torch.sin(y), zeros,
            torch.cos(y)
        ]
        rot_y = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])

        value_list = [
            torch.cos(z), -torch.sin(z), zeros,
            torch.sin(z),
            torch.cos(z), zeros, zeros, zeros, ones
        ]
        rot_z = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])

        rot = rot_z @ rot_y @ rot_x
        return rot.permute(0, 2, 1)

    def to_camera(self, face_shape):
        face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
        return face_shape

    def to_image(self, face_shape):
        """
        Return:
            face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction

        Parameters:
            face_shape -- torch.tensor, size (B, N, 3)
        """
        # to image_plane
        face_proj = face_shape @ self.persc_proj
        face_proj = face_proj[..., :2] / face_proj[..., 2:]

        return face_proj

    def transform(self, face_shape, rot, trans):
        """
        Return:
            face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans

        Parameters:
            face_shape -- torch.tensor, size (B, N, 3)
            rot -- torch.tensor, size (B, 3, 3)
            trans -- torch.tensor, size (B, 3)
        """
        return face_shape @ rot + trans.unsqueeze(1)

    def get_landmarks(self, face_proj):
        """
        Return:
            face_lms -- torch.tensor, size (B, 68, 2)

        Parameters:
            face_proj -- torch.tensor, size (B, N, 2)
        """
        return face_proj[:, self.keypoints]

    def split_coeff(self, coeffs):
        """
        Return:
            coeffs_dict -- a dict of torch.tensors

        Parameters:
            coeffs -- torch.tensor, size (B, 257)
        """
        if type(coeffs) == dict and 'id' in coeffs:
            return coeffs

        id_coeffs = coeffs[:, :80]
        exp_coeffs = coeffs[:, 80:144]
        tex_coeffs = coeffs[:, 144:224]
        angles = coeffs[:, 224:227]
        gammas = coeffs[:, 227:254]
        translations = coeffs[:, 254:]
        return {
            'id': id_coeffs,
            'exp': exp_coeffs,
            'tex': tex_coeffs,
            'angle': angles,
            'gamma': gammas,
            'trans': translations
        }
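        # (Added note, not in the original file) The slices above partition
        # the 257 coefficients: 80 id + 64 exp + 80 tex + 3 angle + 27 gamma
        # + 3 trans = 257, matching the Deep3DFaceRecon layout.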

    def merge_coeff(self, coeffs):
        """
        Return:
            coeffs -- torch.tensor, size (B, 257)

        Parameters:
            coeffs -- a dict of torch.tensors
        """
        names = ['id', 'exp', 'tex', 'angle', 'gamma', 'trans']
        coeffs_merge = []
        for name in names:
            coeffs_merge.append(coeffs[name].detach())
        coeffs_merge = torch.cat(coeffs_merge, dim=1)

        return coeffs_merge

    def reverse_recenter(self, face_shape):
        batch_size = face_shape.shape[0]
        face_shape = face_shape.reshape([-1, 3])
        mean_shape_ori = self.mean_shape_ori.reshape([-1, 3])
        face_shape = face_shape + torch.mean(
            mean_shape_ori[:35709, ...], dim=0, keepdim=True)
        face_shape = face_shape.reshape([batch_size, -1, 3])
        return face_shape

    def add_nonlinear_offset_eyes(self, face_shape, shape_offset):
        assert face_shape.shape[0] == 1 and shape_offset.shape[0] == 1
        face_shape = face_shape[0]
        shape_offset = shape_offset[0]

        corner_shape = face_shape[-625:, :]
        corner_offset = shape_offset[self.eye_corner_inds]
        for i in range(len(self.eye_lines)):
            corner_shape[self.eye_lines[i]] += corner_offset[i][None, ...]
        face_shape[-625:, :] = corner_shape

        l_eye_landmarks = [11540, 11541]
        r_eye_landmarks = [4271, 4272]

        l_eye_offset = torch.mean(
            shape_offset[l_eye_landmarks], dim=0, keepdim=True)
        face_shape[37082:37082 + 609] += l_eye_offset

        r_eye_offset = torch.mean(
            shape_offset[r_eye_landmarks], dim=0, keepdim=True)
        face_shape[37082 + 609:37082 + 609 + 608] += r_eye_offset

        face_shape = face_shape[None, ...]

        return face_shape

    def add_nonlinear_offset(self, face_shape, shape_offset_uv, UVs):
        """

        Args:
            face_shape: torch.tensor, size (1, N, 3)
            shape_offset_uv: torch.tensor, size (1, h, w, 3)
            UVs: torch.tensor, size (N, 2)

        Returns:

        """
        assert face_shape.shape[0] == 1 and shape_offset_uv.shape[0] == 1
        face_shape = face_shape[0]
        shape_offset_uv = shape_offset_uv[0]

        h, w = shape_offset_uv.shape[:2]
        UVs_coords = UVs.clone()
        UVs_coords[:, 0] *= w
        UVs_coords[:, 1] *= h
        UVs_coords_int = torch.floor(UVs_coords)
        UVs_coords_float = UVs_coords - UVs_coords_int
        UVs_coords_int = UVs_coords_int.long()

        shape_lt = shape_offset_uv[(h - 1 - UVs_coords_int[:, 1]).clamp(
            0, h - 1), UVs_coords_int[:, 0].clamp(0, w - 1)]  # (N, 3)
        shape_lb = shape_offset_uv[(h - UVs_coords_int[:, 1]).clamp(0, h - 1),
                                   UVs_coords_int[:, 0].clamp(0, w - 1)]
        shape_rt = shape_offset_uv[(h - 1
                                    - UVs_coords_int[:, 1]).clamp(0, h - 1),
                                   (UVs_coords_int[:, 0] + 1).clamp(0, w - 1)]
        shape_rb = shape_offset_uv[(h - UVs_coords_int[:, 1]).clamp(0, h - 1),
                                   (UVs_coords_int[:, 0] + 1).clamp(0, w - 1)]

        value_1 = shape_lt * (
            1 - UVs_coords_float[:, :1]) * UVs_coords_float[:, 1:]
        value_2 = shape_lb * (1 - UVs_coords_float[:, :1]) * (
            1 - UVs_coords_float[:, 1:])
        value_3 = shape_rt * UVs_coords_float[:, :1] * UVs_coords_float[:, 1:]
        value_4 = shape_rb * UVs_coords_float[:, :1] * (
            1 - UVs_coords_float[:, 1:])

        offset_shape = value_1 + value_2 + value_3 + value_4  # (B, N, 3)

        face_shape = (face_shape + offset_shape)[None, ...]

        return face_shape, offset_shape[None, ...]
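    # (Added note, not in the original file) add_nonlinear_offset is plain
    # bilinear sampling: each UV is split into an integer texel cell and a
    # fractional remainder, the four neighboring texels (lt, lb, rt, rb) are
    # fetched with clamped indices, and value_1..value_4 blend them with the
    # standard (1-u or u) * (1-v or v) weights.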
|
||||
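
The lookup above is plain bilinear interpolation with a flipped vertical axis (UV v grows upward, image rows grow downward); a self-contained sketch on hypothetical shapes (a 4x4 offset map and 5 UVs, not part of the diff) reproduces the same weighting:

    import torch

    offset_uv = torch.randn(4, 4, 3)   # (h, w, 3) offset map
    uvs = torch.rand(5, 2)             # (N, 2), u rightward, v upward, in [0, 1]

    h, w = offset_uv.shape[:2]
    coords = uvs.clone()
    coords[:, 0] *= w
    coords[:, 1] *= h
    ci = torch.floor(coords)
    frac = coords - ci                 # fractional part = bilinear weights
    ci = ci.long()

    rows_t = (h - 1 - ci[:, 1]).clamp(0, h - 1)   # v axis up, rows down
    rows_b = (h - ci[:, 1]).clamp(0, h - 1)
    cols_l = ci[:, 0].clamp(0, w - 1)
    cols_r = (ci[:, 0] + 1).clamp(0, w - 1)

    lt, lb = offset_uv[rows_t, cols_l], offset_uv[rows_b, cols_l]
    rt, rb = offset_uv[rows_t, cols_r], offset_uv[rows_b, cols_r]
    u, v = frac[:, :1], frac[:, 1:]
    sampled = (lt * (1 - u) * v + lb * (1 - u) * (1 - v)
               + rt * u * v + rb * u * (1 - v))   # (N, 3), matches offset_shape
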
    def compute_for_render_head_fitting(self,
                                        coeffs,
                                        shape_offset_uv,
                                        texture_offset_uv,
                                        shape_offset_uv_head,
                                        texture_offset_uv_head,
                                        UVs,
                                        reverse_recenter=True,
                                        get_eyes=False,
                                        get_neck=False,
                                        nose_coeff=0.0,
                                        neck_coeff=0.0,
                                        eyes_coeff=0.0):
        if isinstance(coeffs, dict):
            coef_dict = coeffs
        elif isinstance(coeffs, torch.Tensor):
            coef_dict = self.split_coeff(coeffs)

        face_shape = self.compute_shape(
            coef_dict['id'],
            coef_dict['exp'],
            nose_coeff=nose_coeff,
            neck_coeff=neck_coeff,
            eyes_coeff=eyes_coeff)  # (1, n, 3)
        if reverse_recenter:
            face_shape_ori_noRecenter = self.reverse_recenter(
                face_shape.clone())
        else:
            face_shape_ori_noRecenter = face_shape.clone()
        face_vertex_ori = self.to_camera(face_shape_ori_noRecenter)

        face_shape[:, :35241, :], shape_offset = self.add_nonlinear_offset(
            face_shape[:, :35241, :], shape_offset_uv,
            UVs[:35709, ...][self.bfm_keep_inds])  # (1, n, 3)
        if get_eyes:
            face_shape = self.add_nonlinear_offset_eyes(
                face_shape, shape_offset)
        if get_neck:
            face_shape[:, 35241:37082, ...], _ = self.add_nonlinear_offset(
                face_shape[:, 35241:37082, ...], shape_offset_uv_head,
                UVs[35709:, ...])  # (1, n, 3)
        else:
            face_shape[:, self.ours_hair_area_inds,
                       ...], _ = self.add_nonlinear_offset(
                           face_shape[:, self.ours_hair_area_inds, ...],
                           shape_offset_uv_head,
                           UVs[self.ours_hair_area_inds + (35709 - 35241),
                               ...])  # (1, n, 3)

        if reverse_recenter:
            face_shape_offset_noRecenter = self.reverse_recenter(
                face_shape.clone())
        else:
            face_shape_offset_noRecenter = face_shape.clone()
        face_vertex_offset = self.to_camera(face_shape_offset_noRecenter)

        rotation = self.compute_rotation(coef_dict['angle'])

        face_shape_transformed = self.transform(face_shape, rotation,
                                                coef_dict['trans'])
        face_vertex = self.to_camera(face_shape_transformed)

        face_proj = self.to_image(face_vertex)
        landmark = self.get_landmarks(face_proj)

        face_texture = self.compute_texture(coef_dict['tex'])  # (1, n, 3)
        face_texture[:, :35241, :], texture_offset = self.add_nonlinear_offset(
            face_texture[:, :35241, :], texture_offset_uv,
            UVs[:35709, ...][self.bfm_keep_inds])
        face_texture[:, 35241:37082, :], _ = self.add_nonlinear_offset(
            face_texture[:, 35241:37082, :], texture_offset_uv_head,
            UVs[35709:, ...])

        face_norm = self.compute_norm(face_shape)
        face_norm_roted = face_norm @ rotation
        face_color = self.compute_color(face_texture, face_norm_roted,
                                        coef_dict['gamma'])

        return face_vertex, face_texture, face_color, landmark, face_vertex_ori, face_vertex_offset, face_proj

    def compute_for_render_head(self,
                                coeffs,
                                shape_offset_uv,
                                texture_offset_uv,
                                shape_offset_uv_head,
                                texture_offset_uv_head,
                                UVs,
                                reverse_recenter=True,
                                nose_coeff=0.0,
                                neck_coeff=0.0,
                                eyes_coeff=0.0,
                                neckSlim_coeff=0.0,
                                neckStretch_coeff=0.0):
        if isinstance(coeffs, dict):
            coef_dict = coeffs
        elif isinstance(coeffs, torch.Tensor):
            coef_dict = self.split_coeff(coeffs)

        face_shape = self.compute_shape(
            coef_dict['id'],
            coef_dict['exp'],
            nose_coeff=nose_coeff,
            neck_coeff=neck_coeff,
            eyes_coeff=eyes_coeff,
            neckSlim_coeff=neckSlim_coeff,
            neckStretch_coeff=neckStretch_coeff)  # (1, n, 3)
        if reverse_recenter:
            face_shape_ori_noRecenter = self.reverse_recenter(
                face_shape.clone())
        else:
            face_shape_ori_noRecenter = face_shape.clone()
        face_vertex_ori = self.to_camera(face_shape_ori_noRecenter)

        face_shape[:, :35709, :], shape_offset = self.add_nonlinear_offset(
            face_shape[:, :35709, :], shape_offset_uv,
            UVs[:35709, ...])  # (1, n, 3)
        face_shape[:, 35709:, ...], _ = self.add_nonlinear_offset(
            face_shape[:, 35709:, ...], shape_offset_uv_head,
            UVs[35709:, ...])  # (1, n, 3)

        if reverse_recenter:
            face_shape_offset_noRecenter = self.reverse_recenter(
                face_shape.clone())
        else:
            face_shape_offset_noRecenter = face_shape.clone()
        face_vertex_offset = self.to_camera(face_shape_offset_noRecenter)

        rotation = self.compute_rotation(coef_dict['angle'])

        face_shape_transformed = self.transform(face_shape, rotation,
                                                coef_dict['trans'])
        face_vertex = self.to_camera(face_shape_transformed)

        face_proj = self.to_image(face_vertex)
        landmark = self.get_landmarks(face_proj)

        face_texture = self.compute_texture(coef_dict['tex'])  # (1, n, 3)
        face_texture[:, :35709, :], texture_offset = self.add_nonlinear_offset(
            face_texture[:, :35709, :], texture_offset_uv, UVs[:35709, ...])
        face_texture[:, 35709:, :], _ = self.add_nonlinear_offset(
            face_texture[:, 35709:, :], texture_offset_uv_head,
            UVs[35709:, ...])

        face_norm = self.compute_norm(face_shape)
        face_norm_roted = face_norm @ rotation
        face_color = self.compute_color(face_texture, face_norm_roted,
                                        coef_dict['gamma'])

        return face_vertex, face_texture, face_color, landmark, face_vertex_ori, face_vertex_offset, face_proj

@@ -0,0 +1,196 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import cv2
import json
import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
    tf = tf.compat.v1
    tf.disable_eager_execution()


class HeadSegmentor:

    def __init__(self, model_root):
        """The HeadSegmentor is implemented based on https://arxiv.org/abs/2004.04955

        Args:
            model_root: the root directory of the model files
        """
        self.sess = self.load_sess(
            os.path.join(model_root, 'head_segmentation',
                         'Matting_headparser_6_18.pb'))
        self.sess_detect = self.load_sess(
            os.path.join(model_root, 'head_segmentation', 'face_detect.pb'))
        self.sess_face = self.load_sess(
            os.path.join(model_root, 'head_segmentation', 'segment_face.pb'))

    def load_sess(self, model_path):
        config = tf.ConfigProto(allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        sess = tf.Session(config=config)
        with tf.gfile.FastGFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            sess.graph.as_default()
            tf.import_graph_def(graph_def, name='')
            sess.run(tf.global_variables_initializer())
        return sess

    def process(self, image):
        """Segment each detected head.

        Args:
            image: a numpy array in BGR channel order.
        """
        h, w, c = image.shape
        faceRects = self.detect_face(image)
        face_num = len(faceRects)
        all_head_alpha = []
        all_face_mask = []
        for i in range(face_num):
            y1 = faceRects[i][0]
            y2 = faceRects[i][1]
            x1 = faceRects[i][2]
            x2 = faceRects[i][3]
            pad_y1, pad_y2, pad_x1, pad_x2 = self.pad_box(
                y1, y2, x1, x2, 0.15, 0.15, 0.15, 0.15, h, w)
            temp_img = image.copy()
            roi_img = temp_img[pad_y1:pad_y2, pad_x1:pad_x2]
            output_alpha = self.sess_face.run(
                self.sess_face.graph.get_tensor_by_name('output_alpha_face:0'),
                feed_dict={'input_image_face:0': roi_img[:, :, ::-1]})
            face_mask = np.zeros((h, w, 3))
            face_mask[pad_y1:pad_y2, pad_x1:pad_x2] = output_alpha
            all_face_mask.append(face_mask)
            # debug dumps of the per-face mask and crop
            cv2.imwrite(str(i) + 'face.jpg', face_mask)
            cv2.imwrite(str(i) + 'face_roi.jpg', roi_img)

        for i in range(face_num):
            y1 = faceRects[i][0]
            y2 = faceRects[i][1]
            x1 = faceRects[i][2]
            x2 = faceRects[i][3]
            pad_y1, pad_y2, pad_x1, pad_x2 = self.pad_box(
                y1, y2, x1, x2, 1.47, 1.47, 1.3, 2.0, h, w)
            temp_img = image.copy()
            for j in range(face_num):
                y1 = faceRects[j][0]
                y2 = faceRects[j][1]
                x1 = faceRects[j][2]
                x2 = faceRects[j][3]
                small_y1, small_y2, small_x1, small_x2 = self.pad_box(
                    y1, y2, x1, x2, -0.1, -0.1, -0.1, -0.1, h, w)
                small_width = small_x2 - small_x1
                small_height = small_y2 - small_y1
                if (small_x1 < 0 or small_y1 < 0 or small_width < 3
                        or small_height < 3 or small_x2 > w or small_y2 > h):
                    continue
                # if(i!=j):
                #     temp_img[small_y1:small_y2,small_x1:small_x2]=0
                if (i != j):
                    # suppress the other faces before matting the current head
                    temp_img = temp_img * (1.0 - all_face_mask[j] / 255.0)

            roi_img = temp_img[pad_y1:pad_y2, pad_x1:pad_x2]
            output_alpha = self.sess.run(
                self.sess.graph.get_tensor_by_name('output_alpha:0'),
                feed_dict={'input_image:0': roi_img[:, :, ::-1]})
            head_alpha = np.zeros((h, w))
            head_alpha[pad_y1:pad_y2, pad_x1:pad_x2] = output_alpha[:, :, 0]
            if np.sum(head_alpha) > 255 * w * h * 0.01 * 0.01:
                all_head_alpha.append(head_alpha)

        head_num = len(all_head_alpha)
        head_elements = []
        if head_num == 0:
            return head_elements

        for i in range(head_num):
            head_alpha = all_head_alpha[i]
            head_elements.append(head_alpha)

        return head_elements

    def pad_box(self, y1, y2, x1, x2, left_ratio, right_ratio, top_ratio,
                bottom_ratio, h, w):
        box_w = x2 - x1
        box_h = y2 - y1
        pad_y1 = np.maximum(np.int32(y1 - top_ratio * box_h), 0)
        pad_y2 = np.minimum(np.int32(y2 + bottom_ratio * box_h), h - 1)
        pad_x1 = np.maximum(np.int32(x1 - left_ratio * box_w), 0)
        pad_x2 = np.minimum(np.int32(x2 + right_ratio * box_w), w - 1)
        return pad_y1, pad_y2, pad_x1, pad_x2
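
pad_box grows (or, with negative ratios, shrinks) a box by per-side fractions of its own size, then clips to the image; one worked case with hypothetical numbers, as plain arithmetic:

    # a 100x100 box in a 480x640 image, padded 15% per side
    y1, y2, x1, x2 = 100, 200, 50, 150
    ratio = 0.15
    box_h, box_w = y2 - y1, x2 - x1
    pad = (max(int(y1 - ratio * box_h), 0), min(int(y2 + ratio * box_h), 479),
           max(int(x1 - ratio * box_w), 0), min(int(x2 + ratio * box_w), 639))
    print(pad)  # (85, 215, 35, 165): each edge moves 15 px, nothing is clipped
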
    def detect_face(self, img):
        h, w, c = img.shape
        input_img = cv2.resize(img[:, :, ::-1], (512, 512))
        boxes, scores, num_detections = self.sess_detect.run(
            [
                self.sess_detect.graph.get_tensor_by_name('tower_0/boxes:0'),
                self.sess_detect.graph.get_tensor_by_name('tower_0/scores:0'),
                self.sess_detect.graph.get_tensor_by_name(
                    'tower_0/num_detections:0')
            ],
            feed_dict={
                'tower_0/images:0': input_img[np.newaxis],
                'training_flag:0': False
            })
        faceRects = []
        for i in range(num_detections[0]):
            if scores[0, i] < 0.5:
                continue
            y1 = np.int32(boxes[0, i, 0] * h)
            x1 = np.int32(boxes[0, i, 1] * w)
            y2 = np.int32(boxes[0, i, 2] * h)
            x2 = np.int32(boxes[0, i, 3] * w)
            if x2 <= x1 + 3 or y2 <= y1 + 3:
                continue
            faceRects.append((y1, y2, x1, x2, y2 - y1, x2 - x1))
        # sort in place, largest box first (plain sorted() would discard the result)
        faceRects.sort(key=lambda x: x[4] * x[5], reverse=True)
        return faceRects

    def generate_json(self, status_code, status_msg, ori_url, result_element,
                      track_id):
        data = {}
        data['originUri'] = ori_url
        data['elements'] = result_element
        data['statusCode'] = status_code
        data['statusMessage'] = status_msg
        data['requestId'] = track_id
        return json.dumps(data)

    def get_box(self, alpha):
        h, w = alpha.shape
        start_h = 0
        end_h = 0
        start_w = 0
        end_w = 0
        for i in range(0, h, 3):
            line = alpha[i, :]
            if np.max(line) >= 1:
                start_h = i
                break

        for i in range(0, w, 3):
            line = alpha[:, i]
            if np.max(line) >= 1:
                start_w = i
                break

        for i in range(0, h, 3):
            i = h - 1 - i
            line = alpha[i, :]
            if np.max(line) >= 1:
                end_h = i
                if end_h < h - 1:
                    end_h = end_h + 1
                break
        for i in range(0, w, 3):
            i = w - 1 - i
            line = alpha[:, i]
            if np.max(line) >= 1:
                end_w = i
                if end_w < w - 1:
                    end_w = end_w + 1
                break

        return start_h, start_w, end_h, end_w
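
End to end, the segmentor takes one BGR frame and returns one (h, w) alpha matte per detected head; a hedged usage sketch (the model_root path is hypothetical and must contain the three .pb files named in __init__):

    import cv2

    segmentor = HeadSegmentor(model_root='/path/to/assets')
    image = cv2.imread('portrait.jpg')            # BGR, as process() expects
    head_alphas = segmentor.process(image)
    for idx, alpha in enumerate(head_alphas):
        cv2.imwrite('head_%d.png' % idx, alpha)   # one matte per head
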
@@ -0,0 +1,564 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict

import cv2
import numpy as np
import torch

from modelscope.models import MODELS, TorchModel
from modelscope.models.cv.face_reconstruction.utils import (estimate_normals,
                                                            read_obj)
from . import networks, opt
from .bfm import ParametricFaceModel
from .losses import (BinaryDiceLoss, TVLoss, TVLoss_std, landmark_loss,
                     perceptual_loss, photo_loss, points_loss_horizontal,
                     reflectance_loss, reg_loss)
from .nv_diffrast import MeshRenderer


@MODELS.register_module('head-reconstruction', 'head_reconstruction')
class HeadReconModel(TorchModel):

    def __init__(self, model_dir, *args, **kwargs):
        """The HeadReconModel is implemented based on HRN, publicly available at
        https://github.com/youngLBW/HRN

        Args:
            model_dir: the root directory of the model files
        """
        super().__init__(model_dir, *args, **kwargs)

        self.model_dir = model_dir
        opt.bfm_folder = os.path.join(model_dir, 'assets')
        self.opt = opt
        self.isTrain = opt.isTrain
        self.visual_names = ['output_vis']
        self.model_names = ['net_recon']
        self.parallel_names = self.model_names + [
            'renderer', 'renderer_fitting'
        ]

        # networks
        self.net_recon = networks.define_net_recon(
            net_recon=opt.net_recon,
            use_last_fc=opt.use_last_fc,
            init_path=None)

        # assets
        self.headmodel = ParametricFaceModel(
            assets_root=opt.bfm_folder,
            camera_distance=opt.camera_d,
            focal=opt.focal,
            center=opt.center,
            is_train=self.isTrain,
            default_name='ourRefineBFMEye0504_model.mat')

        self.headmodel_for_fitting = ParametricFaceModel(
            assets_root=opt.bfm_folder,
            camera_distance=opt.camera_d,
            focal=opt.focal,
            center=opt.center,
            is_train=self.isTrain,
            default_name='ourRefineFull_model.mat')

        # renderer
        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
        self.renderer = MeshRenderer(
            rasterize_fov=fov,
            znear=opt.z_near,
            zfar=opt.z_far,
            rasterize_size=int(2 * opt.center))

        self.renderer_fitting = MeshRenderer(
            rasterize_fov=fov,
            znear=opt.z_near,
            zfar=opt.z_far,
            rasterize_size=int(2 * opt.center))

        template_obj_path = os.path.join(
            model_dir,
            'assets/3dmm/template_mesh/template_ourFull_bfmEyes.obj')
        self.template_output_mesh = read_obj(template_obj_path)

        self.nonlinear_UVs = self.template_output_mesh['uvs']
        self.nonlinear_UVs = torch.from_numpy(self.nonlinear_UVs)

        self.jaw_edge_mask = cv2.imread(
            os.path.join(model_dir,
                         'assets/texture/jaw_edge_mask2.png'))[..., 0].astype(
                             np.float32) / 255.0
        self.jaw_edge_mask = cv2.resize(self.jaw_edge_mask,
                                        (300, 300))[..., None]

        self.input_imgs = []
        self.input_img_hds = []
        self.input_fat_img_hds = []
        self.atten_masks = []
        self.gt_lms = []
        self.gt_lm_hds = []
        self.trans_ms = []
        self.img_names = []
        self.face_masks = []
        self.head_masks = []
        self.input_imgs_coeff = []
        self.gt_lms_coeff = []

        self.loss_names = [
            'all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'
        ]

        self.compute_feat_loss = perceptual_loss
        self.compute_color_loss = photo_loss
        self.compute_lm_loss = landmark_loss
        self.compute_reg_loss = reg_loss
        self.compute_reflc_loss = reflectance_loss

        if opt.isTrain:
            self.optimizer = torch.optim.Adam(
                self.net_recon.parameters(), lr=opt.lr)
            self.optimizers = [self.optimizer]
            self.parallel_names += ['net_recog']

    def set_device(self, device):
        self.device = device
        self.net_recon = self.net_recon.to(self.device)
        self.headmodel.to(self.device)
        self.headmodel_for_fitting.to(self.device)
        self.nonlinear_UVs = self.nonlinear_UVs.to(self.device)

    def load_networks(self, load_path):
        state_dict = torch.load(load_path, map_location=self.device)
        print('loading the model from %s' % load_path)

        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                net.load_state_dict(state_dict[name], strict=False)

    def setup(self, checkpoint_path):
        """Load pretrained network weights.

        Parameters:
            checkpoint_path -- path of the checkpoint to load
        """
        self.load_networks(checkpoint_path)

    def parallelize(self, convert_sync_batchnorm=True):
        if not self.opt.use_ddp:
            for name in self.parallel_names:
                if isinstance(name, str):
                    module = getattr(self, name)
                    setattr(self, name, module.to(self.device))
        else:
            for name in self.model_names:
                if isinstance(name, str):
                    module = getattr(self, name)
                    if convert_sync_batchnorm:
                        module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                            module)
                    setattr(
                        self, name,
                        torch.nn.parallel.DistributedDataParallel(
                            module.to(self.device),
                            device_ids=[self.device.index],
                            find_unused_parameters=True,
                            broadcast_buffers=True))

            # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
            for name in self.parallel_names:
                if isinstance(name, str) and name not in self.model_names:
                    module = getattr(self, name)
                    setattr(self, name, module.to(self.device))

        # put state_dict of optimizer to gpu device
        if self.opt.phase != 'test':
            if self.opt.continue_train:
                for optim in self.optimizers:
                    for state in optim.state.values():
                        for k, v in state.items():
                            if isinstance(v, torch.Tensor):
                                state[k] = v.to(self.device)

    def eval(self):
        """Make models eval mode"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, name)
                net.eval()

    def set_render(self, image_res=1024):
        fov = 2 * np.arctan(self.opt.center / self.opt.focal) * 180 / np.pi
        if image_res is None:
            image_res = int(2 * self.opt.center)

        self.renderer = MeshRenderer(
            rasterize_fov=fov,
            znear=self.opt.z_near,
            zfar=self.opt.z_far,
            rasterize_size=image_res)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
        """
        self.input_img = input['imgs'].to(self.device)
        self.input_img_hd = input['imgs_hd'].to(
            self.device) if 'imgs_hd' in input else None

        if 'imgs_fat_hd' not in input or input['imgs_fat_hd'] is None:
            self.input_fat_img_hd = self.input_img_hd
        else:
            self.input_fat_img_hd = input['imgs_fat_hd'].to(self.device)

        self.atten_mask = input['msks'].to(
            self.device) if 'msks' in input else None
        self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
        self.gt_lm_hd = input['lms_hd'].to(
            self.device) if 'lms_hd' in input else None
        self.trans_m = input['M'].to(self.device) if 'M' in input else None
        self.image_paths = input['im_paths'] if 'im_paths' in input else None
        self.img_name = input['img_name'] if 'img_name' in input else None
        self.face_mask = input['face_mask'].to(
            self.device) if 'face_mask' in input else None
        self.head_mask = input['head_mask'].to(
            self.device) if 'head_mask' in input else None
        self.gt_normals = input['normals'].to(
            self.device) if 'normals' in input else None
        self.input_img_coeff = input['imgs_coeff'].to(
            self.device) if 'imgs_coeff' in input else None
        self.gt_lm_coeff = input['lms_coeff'].to(
            self.device) if 'lms_coeff' in input else None

    def check_head_pose(self, coeffs):
        # coeffs[0, 224:227] hold the three head rotation angles in radians;
        # reject inputs rotated by more than ~30 degrees on any axis.
        pi = 3.14
        if coeffs[0, 225] > pi / 6 or coeffs[0, 225] < -pi / 6:
            return False
        elif coeffs[0, 224] > pi / 6 or coeffs[0, 224] < -pi / 6:
            return False
        elif coeffs[0, 226] > pi / 6 or coeffs[0, 226] < -pi / 6:
            return False
        else:
            return True
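
check_head_pose gates on the three rotation coefficients, rejecting any axis beyond pi/6 (about 30 degrees); a quick sketch with made-up values (the 257-dim layout with angles at indices 224:227 is assumed from the Deep3DFaceRecon convention):

    import torch

    coeffs = torch.zeros(1, 257)
    coeffs[0, 224:227] = torch.tensor([0.1, 0.7, 0.0])  # 0.7 rad ~ 40 degrees

    pi = 3.14
    ok = all(-pi / 6 <= float(coeffs[0, k]) <= pi / 6 for k in (224, 225, 226))
    print(ok)  # False: 0.7 exceeds the ~0.52 rad limit
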
    def get_fusion_mask(self, keep_forehead=True):
        self.without_forehead_inds = torch.from_numpy(
            np.load(
                os.path.join(self.model_dir,
                             'assets/3dmm/inds/bfm_withou_forehead_inds.npy'))
        ).long().to(self.device)

        h, w = self.shape_offset_uv.shape[1:3]
        self.fusion_mask = torch.zeros((h, w)).to(self.device).float()
        if keep_forehead:
            UVs_coords = self.nonlinear_UVs.clone()[:35709][
                self.without_forehead_inds]
        else:
            UVs_coords = self.nonlinear_UVs.clone()[:35709]
        UVs_coords[:, 0] *= w
        UVs_coords[:, 1] *= h
        UVs_coords_int = torch.floor(UVs_coords)
        UVs_coords_int = UVs_coords_int.long()

        self.fusion_mask[h - 1 - UVs_coords_int[:, 1],
                         UVs_coords_int[:, 0]] = 1

        # blur mask
        self.fusion_mask = self.fusion_mask.cpu().numpy()
        new_kernel1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        new_kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
        self.fusion_mask = cv2.dilate(self.fusion_mask, new_kernel1, 1)
        self.fusion_mask = cv2.erode(self.fusion_mask, new_kernel2, 1)
        self.fusion_mask = cv2.blur(self.fusion_mask, (17, 17))
        self.fusion_mask = torch.from_numpy(self.fusion_mask).float().to(
            self.device)

    def get_edge_mask(self):
        h, w = self.shape_offset_uv.shape[1:3]
        self.edge_mask = torch.zeros((h, w)).to(self.device).float()
        UVs_coords = self.nonlinear_UVs.clone()[self.edge_points_inds]
        UVs_coords[:, 0] *= w
        UVs_coords[:, 1] *= h
        UVs_coords_int = torch.floor(UVs_coords)
        UVs_coords_int = UVs_coords_int.long()

        self.edge_mask[h - 1 - UVs_coords_int[:, 1], UVs_coords_int[:, 0]] = 1

        # blur mask
        self.edge_mask = self.edge_mask.cpu().numpy()
        new_kernel1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
        self.edge_mask = cv2.dilate(self.edge_mask, new_kernel1, 1)
        self.edge_mask = cv2.blur(self.edge_mask, (5, 5))
        self.edge_mask = torch.from_numpy(self.edge_mask).float().to(
            self.device)

    def blur_shape_offset_uv(self, global_blur=False, blur_size=3):
        if self.edge_mask is not None:
            shape_offset_uv_blur = self.shape_offset_uv[0].detach().cpu(
            ).numpy()
            shape_offset_uv_blur = cv2.blur(shape_offset_uv_blur, (15, 15))
            shape_offset_uv_blur = torch.from_numpy(
                shape_offset_uv_blur).float().to(self.device)[None, ...]
            self.shape_offset_uv = shape_offset_uv_blur * self.edge_mask[
                None, ..., None] + self.shape_offset_uv * (
                    1 - self.edge_mask[None, ..., None])

        self.shape_offset_uv = self.shape_offset_uv * self.fusion_mask[
            None, ..., None]

        if global_blur and blur_size > 0:
            shape_offset_uv_blur = self.shape_offset_uv[0].detach().cpu(
            ).numpy()
            shape_offset_uv_blur = cv2.blur(shape_offset_uv_blur,
                                            (blur_size, blur_size))
            shape_offset_uv_blur = torch.from_numpy(
                shape_offset_uv_blur).float().to(self.device)[None, ...]
            self.shape_offset_uv = shape_offset_uv_blur

    def blur_offset_edge(self):
        shape_offset_uv = self.shape_offset_uv[0].detach().cpu().numpy()
        shape_offset_uv_head = self.shape_offset_uv_head[0].detach().cpu(
        ).numpy()
        shape_offset_uv_head = cv2.resize(shape_offset_uv_head, (300, 300))
        shape_offset_uv_head = shape_offset_uv_head * (
            1 - self.jaw_edge_mask) + shape_offset_uv * self.jaw_edge_mask
        shape_offset_uv_head = cv2.resize(shape_offset_uv_head, (100, 100))

        self.shape_offset_uv_head = torch.from_numpy(
            shape_offset_uv_head).float().to(self.device)[None, ...]

    def fitting_nonlinear(self, coeff, n_iters=250):
        output_coeff = coeff.detach().clone()

        output_coeff = self.headmodel_for_fitting.split_coeff(output_coeff)
        output_coeff['id'].requires_grad = True
        output_coeff['exp'].requires_grad = True
        output_coeff['tex'].requires_grad = True
        output_coeff['angle'].requires_grad = True
        output_coeff['gamma'].requires_grad = True
        output_coeff['trans'].requires_grad = True

        self.shape_offset_uv = torch.zeros(
            (1, 300, 300, 3), dtype=torch.float32).to(self.device)
        self.shape_offset_uv.requires_grad = True

        self.texture_offset_uv = torch.zeros(
            (1, 300, 300, 3), dtype=torch.float32).to(self.device)
        self.texture_offset_uv.requires_grad = True

        self.shape_offset_uv_head = torch.zeros(
            (1, 100, 100, 3), dtype=torch.float32).to(self.device)
        self.shape_offset_uv_head.requires_grad = True

        self.texture_offset_uv_head = torch.zeros(
            (1, 100, 100, 3), dtype=torch.float32).to(self.device)
        self.texture_offset_uv_head.requires_grad = True

        head_face_inds = np.load(
            os.path.join(self.model_dir,
                         'assets/3dmm/inds/ours_head_face_inds.npy'))
        head_face_inds = torch.from_numpy(head_face_inds).to(self.device)
        head_faces = self.headmodel_for_fitting.face_buf[head_face_inds]

        # print('before fitting', output_coeff)

        opt_parameters = [
            self.shape_offset_uv, self.texture_offset_uv,
            self.shape_offset_uv_head, self.texture_offset_uv_head,
            output_coeff['id'], output_coeff['exp'], output_coeff['tex'],
            output_coeff['gamma']
        ]
        optim = torch.optim.Adam(opt_parameters, lr=1e-3)

        optim_pose = torch.optim.Adam([output_coeff['trans']], lr=1e-1)

        self.get_edge_points_horizontal()

        for i in range(n_iters):  # 500
            self.pred_vertex_head, self.pred_tex, self.pred_color_head, self.pred_lm, face_shape, \
                face_shape_offset, self.verts_proj_head = \
                self.headmodel_for_fitting.compute_for_render_head_fitting(
                    output_coeff, self.shape_offset_uv,
                    self.texture_offset_uv, self.shape_offset_uv_head,
                    self.texture_offset_uv_head, self.nonlinear_UVs)
            self.pred_vertex = self.pred_vertex_head[:, :35241]
            self.pred_color = self.pred_color_head[:, :35241]
            self.verts_proj = self.verts_proj_head[:, :35241]
            self.pred_mask_head, _, self.pred_head, self.occ_head = self.renderer_fitting(
                self.pred_vertex_head, head_faces, feat=self.pred_color_head)
            self.pred_mask, _, self.pred_face, self.occ_face = self.renderer_fitting(
                self.pred_vertex,
                self.headmodel_for_fitting.face_buf[:69732],
                feat=self.pred_color)

            self.pred_coeffs_dict = self.headmodel_for_fitting.split_coeff(
                output_coeff)
            self.compute_losses_fitting()

            # warm up the pose first, then optimize all parameters jointly
            if i < 150:
                optim_pose.zero_grad()
                (self.loss_lm + self.loss_color * 0.1).backward()
                optim_pose.step()
            else:
                optim.zero_grad()
                self.loss_all.backward()
                optim.step()

        output_coeff = self.headmodel_for_fitting.merge_coeff(output_coeff)

        self.get_edge_mask()
        self.get_fusion_mask(keep_forehead=False)
        self.blur_shape_offset_uv(global_blur=True)
        self.blur_offset_edge()
        return output_coeff

    def forward(self):
        with torch.no_grad():
            output_coeff = self.net_recon(self.input_img_coeff)

        if not self.check_head_pose(output_coeff):
            return None

        with torch.enable_grad():
            output_coeff = self.fitting_nonlinear(output_coeff)

        # adjust the eye-related expression coefficients before the final render
        output_coeff = self.headmodel.split_coeff(output_coeff)
        eye_coeffs = output_coeff['exp'][0, 16] + output_coeff['exp'][
            0, 17] + output_coeff['exp'][0, 19]
        if eye_coeffs > 1.0:
            degree = 0.5
        else:
            degree = 1.0
        # degree = 0.5
        output_coeff['exp'][0, 16] += 1 * degree
        output_coeff['exp'][0, 17] += 1 * degree
        output_coeff['exp'][0, 19] += 1.5 * degree
        output_coeff = self.headmodel.merge_coeff(output_coeff)

        self.pred_vertex, _, _, _, face_shape_ori, face_shape, _ = \
            self.headmodel.compute_for_render_head(
                output_coeff,
                self.shape_offset_uv.detach(),
                self.texture_offset_uv.detach(),
                self.shape_offset_uv_head.detach() * 0,
                self.texture_offset_uv_head.detach(),
                self.nonlinear_UVs,
                nose_coeff=0.1,
                neck_coeff=0.3,
                neckSlim_coeff=0.5,
                neckStretch_coeff=0.5)

        UVs = np.array(self.template_output_mesh['uvs'])
        UVs_tensor = torch.tensor(UVs, dtype=torch.float32)
        UVs_tensor = torch.unsqueeze(UVs_tensor, 0).to(self.pred_vertex.device)

        target_img = self.input_fat_img_hd
        target_img = target_img.permute(0, 2, 3, 1)
        face_buf = self.headmodel.face_buf
        # get texture map
        with torch.enable_grad():
            pred_mask, _, pred_face, texture_map, texture_mask = self.renderer.pred_shape_and_texture(
                self.pred_vertex, face_buf, UVs_tensor, target_img, None)
        self.pred_coeffs_dict = self.headmodel.split_coeff(output_coeff)

        recon_shape = face_shape  # get reconstructed shape, [1, 35709, 3]
        recon_shape[
            ...,
            -1] = 10 - recon_shape[..., -1]  # from camera space to world space
        recon_shape = recon_shape.cpu().numpy()[0]
        tri = self.headmodel.face_buf.cpu().numpy()

        output = {}
        output['flag'] = 0

        output['tex_map'] = texture_map
        output['tex_mask'] = texture_mask * 255.0
        # output['coeffs'] is a dict:
        # {'id': id_coeffs, 'exp': exp_coeffs, 'tex': tex_coeffs,
        #  'angle': angles, 'gamma': gammas, 'trans': translations}
        output['coeffs'] = self.pred_coeffs_dict

        normals = estimate_normals(recon_shape, tri)

        output['vertices'] = recon_shape
        output['triangles'] = tri
        output['uvs'] = UVs
        output['faces_uv'] = self.template_output_mesh['faces_uv']
        output['normals'] = normals

        return output
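
forward() returns plain numpy arrays for the geometry, so persisting the mesh needs no modelscope helpers; a minimal hedged sketch (a hypothetical helper that writes an OBJ without UVs or materials):

    def save_mesh_obj(output, path):
        # vertices are (N, 3) floats; triangles are 0-based (M, 3) indices,
        # so the OBJ format needs the +1 shift
        with open(path, 'w') as f:
            for v in output['vertices']:
                f.write('v %f %f %f\n' % (v[0], v[1], v[2]))
            for t in output['triangles'] + 1:
                f.write('f %d %d %d\n' % (t[0], t[1], t[2]))
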
    def get_edge_points_horizontal(self):
        left_points = []
        right_points = []
        for i in range(self.face_mask.shape[2]):
            inds = torch.where(self.face_mask[0, 0, i, :] > 0.5)  # 0.9
            if len(inds[0]) > 0:  # i > 112 and len(inds[0]) > 0
                left_points.append(int(inds[0][0]) + 1)
                right_points.append(int(inds[0][-1]))
            else:
                left_points.append(0)
                right_points.append(self.face_mask.shape[3] - 1)
        self.left_points = torch.tensor(left_points).long().to(self.device)
        self.right_points = torch.tensor(right_points).long().to(self.device)

    def compute_losses_fitting(self):
        face_mask = self.pred_mask
        face_mask = face_mask.detach()
        self.loss_color = self.opt.w_color * self.compute_color_loss(
            self.pred_face, self.input_img, face_mask)  # 1.0

        loss_reg, loss_gamma = self.compute_reg_loss(
            self.pred_coeffs_dict,
            w_id=self.opt.w_id,
            w_exp=self.opt.w_exp,
            w_tex=self.opt.w_tex)
        self.loss_reg = self.opt.w_reg * loss_reg  # 1.0
        self.loss_gamma = self.opt.w_gamma * loss_gamma  # 1.0

        self.loss_lm = self.opt.w_lm * self.compute_lm_loss(
            self.pred_lm, self.gt_lm) * 0.1  # 0.1

        self.loss_smooth_offset = TVLoss()(self.shape_offset_uv.permute(
            0, 3, 1, 2)) * 10000  # 10000

        self.loss_reg_textureOff = torch.mean(
            torch.abs(self.texture_offset_uv)) * 10  # 10

        self.loss_smooth_offset_std = TVLoss_std()(
            self.shape_offset_uv.permute(0, 3, 1, 2)) * 50000  # 50000

        self.loss_points_horizontal, self.edge_points_inds = points_loss_horizontal(
            self.verts_proj, self.left_points, self.right_points)  # 20
        self.loss_points_horizontal *= 20

        self.loss_all = self.loss_color + self.loss_lm + self.loss_reg + self.loss_gamma
        self.loss_all += self.loss_smooth_offset + self.loss_smooth_offset_std + self.loss_reg_textureOff
        self.loss_all += self.loss_points_horizontal

        head_mask = self.pred_mask_head
        head_mask = head_mask.detach()
        self.loss_color_head = self.opt.w_color * self.compute_color_loss(
            self.pred_head, self.input_img, head_mask)  # 1.0
        self.loss_smooth_offset_head = TVLoss()(
            self.shape_offset_uv_head.permute(0, 3, 1, 2)) * 100
        self.loss_smooth_offset_std_head = TVLoss_std()(
            self.shape_offset_uv_head.permute(0, 3, 1, 2)) * 500
        self.loss_mask = BinaryDiceLoss()(self.occ_head, self.head_mask) * 20

        self.loss_all += self.loss_mask + self.loss_color_head
        self.loss_all += self.loss_smooth_offset_head + self.loss_smooth_offset_std_head

modelscope/models/cv/head_reconstruction/models/losses.py (new file, 367 lines)
@@ -0,0 +1,367 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry import warp_affine


def resize_n_crop(image, M, dsize=112):
    # image: (b, c, h, w)
    # M : (b, 2, 3)
    return warp_affine(image, M, dsize=(dsize, dsize))


# perceptual level loss
class PerceptualLoss(nn.Module):

    def __init__(self, recog_net, input_size=112):
        super(PerceptualLoss, self).__init__()
        self.recog_net = recog_net
        self.preprocess = lambda x: 2 * x - 1
        self.input_size = input_size

    def forward(self, imageA, imageB, M):
        """
        1 - cosine distance
        Parameters:
            imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
            imageB --same as imageA
        """
        imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
        imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))

        # freeze bn
        self.recog_net.eval()

        id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
        id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
        cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
        return torch.sum(1 - cosine_d) / cosine_d.shape[0]


def perceptual_loss(id_featureA, id_featureB):
    cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
    return torch.sum(1 - cosine_d) / cosine_d.shape[0]


# image level loss
def photo_loss(imageA, imageB, mask, eps=1e-6):
    """
    l2 norm (with sqrt; eps keeps the backward pass stable, otherwise NaN may occur)
    Parameters:
        imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
        imageB --same as imageA
    """
    loss = torch.sqrt(eps + torch.sum(
        (imageA - imageB)**2, dim=1, keepdims=True)) * mask
    loss = torch.sum(loss) / torch.max(
        torch.sum(mask),
        torch.tensor(1.0).to(mask.device))
    return loss
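
photo_loss is a mask-weighted per-pixel L2 with eps tucked inside the sqrt for gradient stability; a quick self-contained check on random tensors with hypothetical shapes:

    import torch

    imageA = torch.rand(1, 3, 8, 8)
    imageB = torch.rand(1, 3, 8, 8)
    mask = torch.zeros(1, 1, 8, 8)
    mask[..., 2:6, 2:6] = 1.0          # only a 4x4 window contributes

    loss = photo_loss(imageA, imageB, mask)
    zero = photo_loss(imageA, imageA, mask)
    print(loss.item(), zero.item())    # second value is ~sqrt(eps), not exactly 0
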
def landmark_loss(predict_lm, gt_lm, weight=None):
    """
    weighted mse loss
    Parameters:
        predict_lm --torch.tensor (B, 68, 2)
        gt_lm --torch.tensor (B, 68, 2)
        weight --numpy.array (1, 68)
    """
    if weight is None:  # `if not weight` would fail on a numpy array
        weight = np.ones([68])
        weight[28:31] = 20
        weight[-8:] = 20
        weight = np.expand_dims(weight, 0)
        weight = torch.tensor(weight).to(predict_lm.device)
    loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
    loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
    return loss


# regularization
def reg_loss(coeffs_dict, w_id=1, w_exp=1, w_tex=1):
    """
    l2 norm without the sqrt, from yu's implementation (mse)
    tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
    Parameters:
        coeffs_dict -- a dict of torch.tensors, keys: id, exp, tex, angle, gamma, trans
    """
    # coefficient regularization to ensure plausible 3d faces
    value_1 = w_id * torch.sum(coeffs_dict['id']**2)
    value_2 = w_exp * torch.sum(coeffs_dict['exp']**2)
    value_3 = w_tex * torch.sum(coeffs_dict['tex']**2)
    creg_loss = value_1 + value_2 + value_3
    creg_loss = creg_loss / coeffs_dict['id'].shape[0]

    # gamma regularization to ensure a nearly-monochromatic light
    gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
    gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
    gamma_loss = torch.mean((gamma - gamma_mean)**2)

    return creg_loss, gamma_loss
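
The gamma term reshapes the 27 lighting coefficients into 3 RGB copies of 9 spherical-harmonics bands and penalizes their spread, nudging the light toward monochrome; a small check (hypothetical batch of 2, only the keys reg_loss reads):

    import torch

    coeffs = {
        'id': torch.randn(2, 80),
        'exp': torch.randn(2, 64),
        'tex': torch.randn(2, 80),
        'gamma': torch.randn(2, 27),
    }
    creg, gamma_loss = reg_loss(coeffs)

    # identical R/G/B lighting -> zero gamma penalty
    coeffs['gamma'] = torch.randn(2, 1, 9).repeat(1, 3, 1).reshape(2, 27)
    _, gamma_zero = reg_loss(coeffs)
    print(gamma_loss.item(), gamma_zero.item())  # second value is 0
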
def reflectance_loss(texture, mask):
    """
    minimize texture variance (mse), albedo regularization to ensure a uniform skin albedo
    Parameters:
        texture --torch.tensor, (B, N, 3)
        mask --torch.tensor, (N), 1 or 0
    """
    mask = mask.reshape([1, mask.shape[0], 1])
    texture_mean = torch.sum(
        mask * texture, dim=1, keepdims=True) / torch.sum(mask)
    loss = torch.sum(((texture - texture_mean) * mask)**2) / (
        texture.shape[0] * torch.sum(mask))
    return loss


def lm_3d_loss(pred_lm_3d, gt_lm_3d, mask):
    loss = torch.abs(pred_lm_3d - gt_lm_3d)[mask, :]
    loss = torch.mean(loss)
    return loss


class TVLoss(nn.Module):

    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h
                                         + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]
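
TVLoss sums squared differences of neighbouring pixels, so a constant map costs nothing while noise costs a lot; a quick comparison on hypothetical 32x32 maps:

    import torch

    tv = TVLoss()
    flat = torch.ones(1, 3, 32, 32)      # no spatial variation
    noisy = torch.randn(1, 3, 32, 32)    # high-frequency content
    print(tv(flat).item())               # 0.0
    print(tv(noisy).item())              # much larger than 0
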
class TVLoss_std(nn.Module):

    def __init__(self, TVLoss_weight=1):
        super(TVLoss_std, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2)
        h_tv = ((h_tv - torch.mean(h_tv))**2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2)
        w_tv = ((w_tv - torch.mean(w_tv))**2).sum()
        return self.TVLoss_weight * 2 * (h_tv + w_tv) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]


def photo_loss_sum(imageA, imageB, mask, eps=1e-6):
    """
    l2 norm (with sqrt; eps keeps the backward pass stable, otherwise NaN may occur)
    Parameters:
        imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
        imageB --same as imageA
    """
    loss = torch.sqrt(eps + torch.sum(
        (imageA - imageB)**2, dim=1, keepdims=True)) * mask
    loss = torch.sum(loss) / (
        imageA.shape[0] * imageA.shape[2] * imageA.shape[3])
    return loss


def points_loss_horizontal(verts, left_points, right_points, width=224):
    verts_int = torch.ceil(verts[0]).long().clamp(0, width - 1)  # (n, 2)
    verts_left = left_points[width - 1 - verts_int[:, 1]].float()
    verts_right = right_points[width - 1 - verts_int[:, 1]].float()
    verts_x = verts[0, :, 0]
    dist = (verts_left - verts_x) / width * (verts_right - verts_x) / width
    dist /= torch.max(
        torch.abs((verts_left - verts_x) / width),
        torch.abs((verts_right - verts_x) / width))
    edge_inds = torch.where(dist > 0)[0]
    dist += 0.01
    dist = torch.nn.functional.relu(dist).clone()
    dist -= 0.01
    dist = torch.abs(dist)
    loss = torch.mean(dist)
    return loss, edge_inds


class LaplacianLoss(nn.Module):

    def __init__(self):
        super(LaplacianLoss, self).__init__()

    def forward(self, x):
        batch_size, slice_num = x.size()[:2]
        z_x = x.size()[2]
        h_x = x.size()[3]
        w_x = x.size()[4]
        count_z = self._tensor_size(x[:, :, 1:, :, :])
        count_h = self._tensor_size(x[:, :, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, :, 1:])
        z_tv = torch.pow((x[:, :, 1:, :, :] - x[:, :, :z_x - 1, :, :]),
                         2).sum()
        h_tv = torch.pow((x[:, :, :, 1:, :] - x[:, :, :, :h_x - 1, :]),
                         2).sum()
        w_tv = torch.pow((x[:, :, :, :, 1:] - x[:, :, :, :, :w_x - 1]),
                         2).sum()
        return 2 * (z_tv / count_z + h_tv / count_h + w_tv / count_w) / (
            batch_size * slice_num)

    def _tensor_size(self, t):
        return t.size()[2] * t.size()[3] * t.size()[4]


class LaplacianLoss_L1(nn.Module):

    def __init__(self):
        super(LaplacianLoss_L1, self).__init__()

    def forward(self, x):
        batch_size, slice_num = x.size()[:2]
        z_x = x.size()[2]
        h_x = x.size()[3]
        w_x = x.size()[4]
        count_z = self._tensor_size(x[:, :, 1:, :, :])
        count_h = self._tensor_size(x[:, :, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, :, 1:])
        z_tv = torch.abs((x[:, :, 1:, :, :] - x[:, :, :z_x - 1, :, :])).sum()
        h_tv = torch.abs((x[:, :, :, 1:, :] - x[:, :, :, :h_x - 1, :])).sum()
        w_tv = torch.abs((x[:, :, :, :, 1:] - x[:, :, :, :, :w_x - 1])).sum()
        return 2 * (z_tv / count_z + h_tv / count_h + w_tv / count_w) / (
            batch_size * slice_num)

    def _tensor_size(self, t):
        return t.size()[2] * t.size()[3] * t.size()[4]


class GANLoss(nn.Module):

    def __init__(self,
                 gan_mode,
                 target_real_label=1.0,
                 target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_tensor = None
        self.fake_label_tensor = None
        self.zero_tensor = None
        self.Tensor = tensor
        self.gan_mode = gan_mode
        if gan_mode == 'ls':
            pass
        elif gan_mode == 'original':
            pass
        elif gan_mode == 'w':
            pass
        elif gan_mode == 'hinge':
            pass
        else:
            raise ValueError('Unexpected gan_mode {}'.format(gan_mode))

    def get_target_tensor(self, input, target_is_real):
        if target_is_real:
            if self.real_label_tensor is None:
                self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
                self.real_label_tensor.requires_grad_(False)
            return self.real_label_tensor.expand_as(input)
        else:
            if self.fake_label_tensor is None:
                self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
                self.fake_label_tensor.requires_grad_(False)
            return self.fake_label_tensor.expand_as(input)

    def get_zero_tensor(self, input):
        if self.zero_tensor is None:
            self.zero_tensor = self.Tensor(1).fill_(0)
            self.zero_tensor.requires_grad_(False)
        return self.zero_tensor.expand_as(input)

    def loss(self, input, target_is_real, for_discriminator=True):
        if self.gan_mode == 'original':  # cross entropy loss
            target_tensor = self.get_target_tensor(input, target_is_real)
            loss = F.binary_cross_entropy_with_logits(input, target_tensor)
            return loss
        elif self.gan_mode == 'ls':
            target_tensor = self.get_target_tensor(input, target_is_real)
            return F.mse_loss(input, target_tensor)
        elif self.gan_mode == 'hinge':
            if for_discriminator:
                if target_is_real:
                    minval = torch.min(input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
                else:
                    minval = torch.min(-input - 1, self.get_zero_tensor(input))
                    loss = -torch.mean(minval)
            else:
                assert target_is_real, "The generator's hinge loss must be aiming for real"
                loss = -torch.mean(input)
            return loss
        else:
            # wgan
            if target_is_real:
                return -input.mean()
            else:
                return input.mean()

    def __call__(self, input, target_is_real, for_discriminator=True):
        # computing the loss is a bit complicated because |input| may not be
        # a tensor but a list of tensors in case of a multiscale discriminator
        if isinstance(input, list):
            loss = 0
            for pred_i in input:
                if isinstance(pred_i, list):
                    pred_i = pred_i[-1]
                loss_tensor = self.loss(pred_i, target_is_real,
                                        for_discriminator)
                bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
                new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
                loss += new_loss
            return loss / len(input)
        else:
            return self.loss(input, target_is_real, for_discriminator)
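
For gan_mode='hinge' the discriminator branch computes -mean(min(D(real) - 1, 0)) and -mean(min(-D(fake) - 1, 0)), i.e. the usual hinge terms, while the generator simply maximizes D(fake); a short sketch with made-up logits:

    import torch

    gan_loss = GANLoss('hinge')
    real_logits = torch.tensor([2.0, 0.5])    # hypothetical discriminator outputs
    fake_logits = torch.tensor([-1.5, 0.3])

    d_loss = (gan_loss(real_logits, True, for_discriminator=True)
              + gan_loss(fake_logits, False, for_discriminator=True))
    g_loss = gan_loss(fake_logits, True, for_discriminator=False)
    print(d_loss.item(), g_loss.item())       # 0.9 and 0.6 for these logits
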
class BinaryDiceLoss(nn.Module):

    def __init__(self, smooth=1, p=1, reduction='mean'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[
            0], "predict & target batch size don't match"
        predict = predict.contiguous().view(predict.shape[0], -1)
        target = target.contiguous().view(target.shape[0], -1)

        num = torch.sum(torch.mul(predict, target), dim=1)
        den = torch.sum(predict + target, dim=1)

        loss = 1 - (2 * num + self.smooth) / (den + self.smooth)

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        elif self.reduction == 'none':
            return loss
        else:
            raise Exception('Unexpected reduction {}'.format(self.reduction))
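
BinaryDiceLoss is one minus the smoothed Dice coefficient (2|P∩T| + s) / (|P| + |T| + s), so overlapping masks score near 0 and disjoint ones near 1; a tiny check on hand-made masks:

    import torch

    dice = BinaryDiceLoss()
    pred = torch.zeros(1, 1, 4, 4)
    pred[..., :2, :] = 1.0              # top half predicted
    target = pred.clone()               # perfect overlap
    print(dice(pred, target).item())    # 0.0

    disjoint = 1.0 - target             # bottom half
    print(dice(pred, disjoint).item())  # ~0.94, close to 1
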

modelscope/models/cv/head_reconstruction/models/networks.py (new file, 577 lines)
@@ -0,0 +1,577 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch
import os
from typing import Any, Callable, List, Optional, Type, Union

import torch
import torch.nn as nn
from kornia.geometry import warp_affine
from torch import Tensor
from torch.optim import lr_scheduler

try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url


def resize_n_crop(image, M, dsize=112):
    # image: (b, c, h, w)
    # M : (b, 2, 3)
    return warp_affine(image, M, dsize=(dsize, dsize))


def filter_state_dict(state_dict, remove_name='fc'):
    new_state_dict = {}
    for key in state_dict:
        if remove_name in key:
            continue
        new_state_dict[key] = state_dict[key]
    return new_state_dict


def define_net_recon(net_recon, use_last_fc=False, init_path=None):
    return ReconNetWrapper(
        net_recon, use_last_fc=use_last_fc, init_path=init_path)


def define_net_recon2(net_recon, use_last_fc=False, init_path=None):
    return ReconNetWrapper2(
        net_recon, use_last_fc=use_last_fc, init_path=init_path)


class ReconNetWrapper(nn.Module):
    fc_dim = 257

    def __init__(self, net_recon, use_last_fc=False, init_path=None):
        super(ReconNetWrapper, self).__init__()
        self.use_last_fc = use_last_fc
        if net_recon not in func_dict:
            raise NotImplementedError(
                'network [%s] is not implemented' % net_recon)
        func, last_dim = func_dict[net_recon]
        backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
        if init_path and os.path.isfile(init_path):
            state_dict = filter_state_dict(
                torch.load(init_path, map_location='cpu'))
            backbone.load_state_dict(state_dict)
            print('loading init net_recon %s from %s' % (net_recon, init_path))
        self.backbone = backbone
        if not use_last_fc:
            self.final_layers = nn.ModuleList([
                conv1x1(last_dim, 80, bias=True),  # id layer
                conv1x1(last_dim, 64, bias=True),  # exp layer
                conv1x1(last_dim, 80, bias=True),  # tex layer
                conv1x1(last_dim, 3, bias=True),  # angle layer
                conv1x1(last_dim, 27, bias=True),  # gamma layer
                conv1x1(last_dim, 2, bias=True),  # tx, ty
                conv1x1(last_dim, 1, bias=True)  # tz
            ])
            for m in self.final_layers:
                nn.init.constant_(m.weight, 0.)
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        x = self.backbone(x)
        if not self.use_last_fc:
            output = []
            for layer in self.final_layers:
                output.append(layer(x))
            x = torch.flatten(torch.cat(output, dim=1), 1)
        return x
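
The seven 1x1 heads of ReconNetWrapper concatenate into the 257-dim coefficient vector that split_coeff later slices into id/exp/tex/angle/gamma/trans; a quick shape check (assumes func_dict, defined further down in this file, has the usual 'resnet50' entry):

    import torch

    net = ReconNetWrapper('resnet50')   # assumed func_dict key
    x = torch.randn(1, 3, 224, 224)
    coeffs = net(x)
    print(coeffs.shape)                 # torch.Size([1, 257])
    # layout: 80 id + 64 exp + 80 tex + 3 angle + 27 gamma + 2 (tx, ty) + 1 tz
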
class ReconNetWrapper2(nn.Module):
    fc_dim = 264

    def __init__(self, net_recon, use_last_fc=False, init_path=None):
        super(ReconNetWrapper2, self).__init__()
        self.use_last_fc = use_last_fc
        if net_recon not in func_dict:
            raise NotImplementedError(
                'network [%s] is not implemented' % net_recon)
        func, last_dim = func_dict[net_recon]
        backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
        if init_path and os.path.isfile(init_path):
            state_dict = filter_state_dict(
                torch.load(init_path, map_location='cpu'))
            backbone.load_state_dict(state_dict)
            print('loading init net_recon %s from %s' % (net_recon, init_path))
        self.backbone = backbone
        if not use_last_fc:
            self.final_layers2 = nn.ModuleList([
                conv1x1(last_dim, 80, bias=True),  # id layer
                conv1x1(last_dim, 51, bias=True),  # exp layer
                conv1x1(last_dim, 100, bias=True),  # tex layer
                conv1x1(last_dim, 3, bias=True),  # angle layer
                conv1x1(last_dim, 27, bias=True),  # gamma layer
                conv1x1(last_dim, 2, bias=True),  # tx, ty
                conv1x1(last_dim, 1, bias=True)  # tz
            ])
            for m in self.final_layers2:
                nn.init.constant_(m.weight, 0.)
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        x = self.backbone(x)
        if not self.use_last_fc:
            output = []
            for layer in self.final_layers2:
                output.append(layer(x))
            x = torch.flatten(torch.cat(output, dim=1), 1)
        return x


# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
__all__ = [
    'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
    'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2',
    'wide_resnet101_2'
]

model_urls = {
    'resnet18':
    'https://download.pytorch.org/models/resnet18-f37072fd.pth',
    'resnet34':
    'https://download.pytorch.org/models/resnet34-b627a593.pth',
    'resnet50':
    'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101':
    'https://download.pytorch.org/models/resnet101-63fe2227.pth',
    'resnet152':
    'https://download.pytorch.org/models/resnet152-394f9c45.pth',
    'resnext50_32x4d':
    'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
    'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2':
    'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2':
    'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def conv3x3(in_planes: int,
            out_planes: int,
            stride: int = 1,
            groups: int = 1,
            dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation)


def conv1x1(in_planes: int,
            out_planes: int,
            stride: int = 1,
            bias: bool = False) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)


class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                'Dilation > 1 not supported in BasicBlock')
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
            self,
            inplanes: int,
            planes: int,
            stride: int = 1,
            downsample: Optional[nn.Module] = None,
            groups: int = 1,
            base_width: int = 64,
            dilation: int = 1,
            norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(
            self,
            block: Type[Union[BasicBlock, Bottleneck]],
            layers: List[int],
            num_classes: int = 1000,
            zero_init_residual: bool = False,
            use_last_fc: bool = False,
            groups: int = 1,
            width_per_group: int = 64,
            replace_stride_with_dilation: Optional[List[bool]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError('replace_stride_with_dilation should be None '
                             'or a 3-element tuple, got {}'.format(
                                 replace_stride_with_dilation))
        self.use_last_fc = use_last_fc
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block,
            128,
            layers[1],
            stride=2,
            dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block,
            256,
            layers[2],
            stride=2,
            dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(
            block,
            512,
            layers[3],
            stride=2,
            dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if self.use_last_fc:
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight,
|
||||
0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight,
|
||||
0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
planes: int,
|
||||
blocks: int,
|
||||
stride: int = 1,
|
||||
dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(
|
||||
self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
if self.use_last_fc:
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(arch: str, block: Type[Union[BasicBlock,
|
||||
Bottleneck]], layers: List[int],
|
||||
pretrained: bool, progress: bool, **kwargs: Any) -> ResNet:
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
state_dict = load_state_dict_from_url(
|
||||
model_urls[arch], progress=progress)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False,
|
||||
progress: bool = True,
|
||||
**kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
func_dict = {'resnet18': (resnet18, 512), 'resnet50': (resnet50, 2048)}
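

# Illustrative usage sketch (not exercised by the library itself): func_dict
# maps a backbone name to (constructor, feature dim). With use_last_fc=False
# the network returns pooled features rather than class logits.
#
#   net_fn, feat_dim = func_dict['resnet50']
#   backbone = net_fn(pretrained=False, use_last_fc=False)
#   feats = backbone(torch.randn(1, 3, 224, 224))  # shape (1, 2048, 1, 1)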
414 modelscope/models/cv/head_reconstruction/models/nv_diffrast.py Normal file
@@ -0,0 +1,414 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch
import warnings
from typing import List

import numpy as np
import nvdiffrast
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F
from torch import nn

from .losses import TVLoss, TVLoss_std

warnings.filterwarnings('ignore')


def ndc_projection(x=0.1, n=1.0, f=50.0):
    return np.array([[n / x, 0, 0, 0], [0, n / -x, 0, 0],
                     [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
                     [0, 0, -1, 0]]).astype(np.float32)


def to_image(face_shape):
    """
    Return:
        face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction

    Parameters:
        face_shape -- torch.tensor, size (B, N, 3)
    """
    focal = 1015.
    center = 112.
    persc_proj = np.array([focal, 0, center, 0, focal, center, 0, 0,
                           1]).reshape([3, 3]).astype(np.float32).transpose()

    persc_proj = torch.tensor(persc_proj).to(face_shape.device)

    face_proj = face_shape @ persc_proj
    face_proj = face_proj[..., :2] / face_proj[..., 2:]

    return face_proj
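

# Worked example of the projection above (illustrative): with focal=1015 and
# center=112, a camera-space point (0, 0, 10) maps to the image center, since
# [0, 0, 10] @ persc_proj = [1120, 1120, 10] and 1120 / 10 = 112.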


class MeshRenderer(nn.Module):

    def __init__(self, rasterize_fov, znear=0.1, zfar=10, rasterize_size=224):
        super(MeshRenderer, self).__init__()

        x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
        self.ndc_proj = torch.tensor(ndc_projection(
            x=x, n=znear,
            f=zfar)).matmul(torch.diag(torch.tensor([1., -1, -1, 1])))
        self.rasterize_size = rasterize_size
        self.glctx = None

    def forward(self, vertex, tri, feat=None):
        """
        Return:
            mask -- torch.tensor, size (B, 1, H, W)
            depth -- torch.tensor, size (B, 1, H, W)
            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None

        Parameters:
            vertex -- torch.tensor, size (B, N, 3)
            tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
            feat(optional) -- torch.tensor, size (B, C), features
        """
        device = vertex.device
        rsize = int(self.rasterize_size)
        ndc_proj = self.ndc_proj.to(device)
        verts_proj = to_image(vertex)
        # transform to homogeneous coordinates of 3d vertices, the direction of y is the same as v
        if vertex.shape[-1] == 3:
            vertex = torch.cat(
                [vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
                dim=-1)
            vertex[..., 1] = -vertex[..., 1]

        vertex_ndc = vertex @ ndc_proj.t()
        if self.glctx is None:
            if nvdiffrast.__version__ == '0.2.7':
                self.glctx = dr.RasterizeGLContext(device=device)
            else:
                self.glctx = dr.RasterizeCudaContext(device=device)

        ranges = None
        if isinstance(tri, List) or len(tri.shape) == 3:
            vum = vertex_ndc.shape[1]
            fnum = torch.tensor([f.shape[0]
                                 for f in tri]).unsqueeze(1).to(device)

            print('fnum shape:{}'.format(fnum.shape))

            fstartidx = torch.cumsum(fnum, dim=0) - fnum
            ranges = torch.cat([fstartidx, fnum],
                               axis=1).type(torch.int32).cpu()
            for i in range(tri.shape[0]):
                tri[i] = tri[i] + i * vum
            vertex_ndc = torch.cat(vertex_ndc, dim=0)
            tri = torch.cat(tri, dim=0)

        # for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
        tri = tri.type(torch.int32).contiguous()
        rast_out, _ = dr.rasterize(
            self.glctx,
            vertex_ndc.contiguous(),
            tri,
            resolution=[rsize, rsize],
            ranges=ranges)

        depth, _ = dr.interpolate(
            vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
            rast_out, tri)
        depth = depth.permute(0, 3, 1, 2)
        mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
        depth = mask * depth

        image = None

        verts_x = verts_proj[0, :, 0]
        verts_y = 224 - verts_proj[0, :, 1]
        verts_int = torch.ceil(verts_proj[0]).long()  # (n, 2)
        verts_xr_int = verts_int[:, 0].clamp(1, 224 - 1)
        verts_yt_int = 224 - verts_int[:, 1].clamp(2, 224)
        verts_right_float = verts_xr_int - verts_x
        verts_left_float = 1 - verts_right_float
        verts_top_float = verts_y - verts_yt_int
        verts_bottom_float = 1 - verts_top_float

        rast_lt = rast_out[0, verts_yt_int, verts_xr_int - 1, 3]
        rast_lb = rast_out[0, verts_yt_int + 1, verts_xr_int - 1, 3]
        rast_rt = rast_out[0, verts_yt_int, verts_xr_int, 3]
        rast_rb = rast_out[0, verts_yt_int + 1, verts_xr_int, 3]

        occ_feat = (rast_lt > 0) * 1.0 * (verts_left_float + verts_top_float) + \
            (rast_lb > 0) * 1.0 * (verts_left_float + verts_bottom_float) + \
            (rast_rt > 0) * 1.0 * (verts_right_float + verts_top_float) + \
            (rast_rb > 0) * 1.0 * (verts_right_float + verts_bottom_float)
        occ_feat = occ_feat[None, :, None] / 4.0

        occ, _ = dr.interpolate(occ_feat, rast_out, tri)
        occ = occ.permute(0, 3, 1, 2)

        if feat is not None:
            image, _ = dr.interpolate(feat, rast_out, tri)
            image = image.permute(0, 3, 1, 2)
            image = mask * image

        return mask, depth, image, occ
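
    # Usage sketch (illustrative; needs a CUDA device and nvdiffrast; the fov
    # value below is an example, not a library default):
    #
    #   renderer = MeshRenderer(rasterize_fov=12.6, znear=5., zfar=15.,
    #                           rasterize_size=224)
    #   mask, depth, image, occ = renderer(verts, tri, feat=vert_colors)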

    def render_uv_texture(self, vertex, tri, uv, uv_texture):
        """
        Return:
            mask -- torch.tensor, size (B, 1, H, W)
            depth -- torch.tensor, size (B, 1, H, W)
            image -- torch.tensor, size (B, C, H, W), rendered texture

        Parameters:
            vertex -- torch.tensor, size (B, N, 3)
            tri -- torch.tensor, size (M, 3), triangles
            uv -- torch.tensor, size (B, N, 2), uv mapping
            uv_texture -- torch.tensor, size (B, C, H, W), texture to sample
        """
        uv = uv.clone()  # the v-flip below must not mutate the caller's uv

        device = vertex.device
        rsize = int(self.rasterize_size)
        ndc_proj = self.ndc_proj.to(device)
        # transform to homogeneous coordinates of 3d vertices, the direction of y is the same as v
        if vertex.shape[-1] == 3:
            vertex = torch.cat(
                [vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
                dim=-1)
            vertex[..., 1] = -vertex[..., 1]

        vertex_ndc = vertex @ ndc_proj.t()
        if self.glctx is None:
            if nvdiffrast.__version__ == '0.2.7':
                self.glctx = dr.RasterizeGLContext(device=device)
            else:
                self.glctx = dr.RasterizeCudaContext(device=device)

        ranges = None
        if isinstance(tri, List) or len(tri.shape) == 3:
            vum = vertex_ndc.shape[1]
            fnum = torch.tensor([f.shape[0]
                                 for f in tri]).unsqueeze(1).to(device)

            print('fnum shape:{}'.format(fnum.shape))

            fstartidx = torch.cumsum(fnum, dim=0) - fnum
            ranges = torch.cat([fstartidx, fnum],
                               axis=1).type(torch.int32).cpu()
            for i in range(tri.shape[0]):
                tri[i] = tri[i] + i * vum
            vertex_ndc = torch.cat(vertex_ndc, dim=0)
            tri = torch.cat(tri, dim=0)

        # for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
        tri = tri.type(torch.int32).contiguous()
        rast_out, _ = dr.rasterize(
            self.glctx,
            vertex_ndc.contiguous(),
            tri,
            resolution=[rsize, rsize],
            ranges=ranges)

        depth, _ = dr.interpolate(
            vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
            rast_out, tri)
        depth = depth.permute(0, 3, 1, 2)
        mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
        depth = mask * depth
        uv[..., -1] = 1.0 - uv[..., -1]

        rast_out, rast_db = dr.rasterize(
            self.glctx,
            vertex_ndc.contiguous(),
            tri,
            resolution=[rsize, rsize],
            ranges=ranges)

        interp_out, uv_da = dr.interpolate(
            uv, rast_out, tri, rast_db, diff_attrs='all')

        uv_texture = uv_texture.permute(0, 2, 3, 1).contiguous()
        img = dr.texture(uv_texture, interp_out, filter_mode='linear')
        img = img * torch.clamp(rast_out[..., -1:], 0,
                                1)  # Mask out background.

        image = img.permute(0, 3, 1, 2)

        return mask, depth, image
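
    # Note on the v-flip above: OBJ-style uv coordinates grow upwards while
    # the texture image rows are stored top-to-bottom, so
    # uv[..., -1] = 1.0 - uv[..., -1] converts between the two conventions
    # before dr.texture() samples uv_texture.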

    def pred_shape_and_texture(self,
                               vertex,
                               tri,
                               uv,
                               target_img,
                               base_tex=None):
        """
        Return:
            mask -- torch.tensor, size (B, 1, H, W)
            depth -- torch.tensor, size (B, 1, H, W)
            image -- torch.tensor, size (B, C, H, W), final rendered texture
            tex_map -- np.ndarray, optimized texture map (channel-reversed, 0-255)
            tex_mask -- np.ndarray, visibility mask of the optimized texture

        Parameters:
            vertex -- torch.tensor, size (B, N, 3)
            tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
            uv -- torch.tensor, size (B, N, 2), uv mapping
            target_img -- torch.tensor, size (B, H, W, 3), image to fit
            base_tex -- torch.tensor, size (B, H, W, C), optional base texture
        """
        uv = uv.clone()

        device = vertex.device
        rsize = int(self.rasterize_size)
        ndc_proj = self.ndc_proj.to(device)
        # transform to homogeneous coordinates of 3d vertices, the direction of y is the same as v
        if vertex.shape[-1] == 3:
            vertex = torch.cat(
                [vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
                dim=-1)
            vertex[..., 1] = -vertex[..., 1]

        vertex_ndc = vertex @ ndc_proj.t()
        if self.glctx is None:
            if nvdiffrast.__version__ == '0.2.7':
                self.glctx = dr.RasterizeGLContext(device=device)
            else:
                self.glctx = dr.RasterizeCudaContext(device=device)

        ranges = None
        if isinstance(tri, List) or len(tri.shape) == 3:
            vum = vertex_ndc.shape[1]
            fnum = torch.tensor([f.shape[0]
                                 for f in tri]).unsqueeze(1).to(device)

            fstartidx = torch.cumsum(fnum, dim=0) - fnum
            ranges = torch.cat([fstartidx, fnum],
                               axis=1).type(torch.int32).cpu()
            for i in range(tri.shape[0]):
                tri[i] = tri[i] + i * vum
            vertex_ndc = torch.cat(vertex_ndc, dim=0)
            tri = torch.cat(tri, dim=0)

        # for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
        tri = tri.type(torch.int32).contiguous()
        rast_out, _ = dr.rasterize(
            self.glctx,
            vertex_ndc.contiguous(),
            tri,
            resolution=[rsize, rsize],
            ranges=ranges)

        depth, _ = dr.interpolate(
            vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
            rast_out, tri)
        depth = depth.permute(0, 3, 1, 2)
        mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
        depth = mask * depth
        uv[..., -1] = 1.0 - uv[..., -1]

        rast_out, rast_db = dr.rasterize(
            self.glctx,
            vertex_ndc.contiguous(),
            tri,
            resolution=[rsize, rsize],
            ranges=ranges)

        interp_out, uv_da = dr.interpolate(
            uv, rast_out, tri, rast_db, diff_attrs='all')

        mask_3c = mask.permute(0, 2, 3, 1)
        mask_3c = torch.cat((mask_3c, mask_3c, mask_3c), dim=-1)
        maskout_img = mask_3c * target_img
        mean_color = torch.sum(maskout_img, dim=(1, 2))
        valid_pixel_count = torch.sum(mask)

        mean_color = mean_color / valid_pixel_count

        # start from a constant texture filled with the mean face color
        tex = torch.zeros((1, 128, 128, 3), dtype=torch.float32)
        tex[:, :, :, 0] = mean_color[0, 0]
        tex[:, :, :, 1] = mean_color[0, 1]
        tex[:, :, :, 2] = mean_color[0, 2]

        tex = tex.cuda()

        tex_mask = torch.zeros((1, 2048, 2048, 3), dtype=torch.float32)
        tex_mask[:, :, :, 1] = 1.0
        tex_mask = tex_mask.cuda()
        tex_mask.requires_grad = True
        tex_mask = tex_mask.contiguous()

        criterionTV = TVLoss()

        if base_tex is not None:
            base_tex = base_tex.cuda()

        for tex_resolution in [64, 128, 256, 512, 1024, 2048]:
            tex = tex.detach()
            tex = tex.permute(0, 3, 1, 2)
            tex = F.interpolate(tex, (tex_resolution, tex_resolution))
            tex = tex.permute(0, 2, 3, 1).contiguous()

            if base_tex is not None:
                _base_tex = base_tex.permute(0, 3, 1, 2)
                _base_tex = F.interpolate(
                    _base_tex, (tex_resolution, tex_resolution))
                _base_tex = _base_tex.permute(0, 2, 3, 1).contiguous()
                tex += _base_tex

            tex.requires_grad = True

            optim = torch.optim.Adam([tex], lr=1e-2)

            texture_opt_iters = 200

            if tex_resolution == 2048:
                optim_mask = torch.optim.Adam([tex_mask], lr=1e-2)

            for _ in range(texture_opt_iters):

                if tex_resolution == 2048:
                    optim_mask.zero_grad()
                    rendered = dr.texture(
                        tex_mask, interp_out, filter_mode='linear')
                    rendered = rendered * torch.clamp(
                        rast_out[..., -1:], 0, 1)  # Mask out background.
                    tex_loss = torch.mean((target_img - rendered)**2)

                    tex_loss.backward()
                    optim_mask.step()

                optim.zero_grad()

                img = dr.texture(tex, interp_out, filter_mode='linear')
                img = img * torch.clamp(rast_out[..., -1:], 0,
                                        1)  # Mask out background.
                recon_loss = torch.mean((target_img - img)**2)

                if tex_resolution < 2048:
                    tv_loss = criterionTV(tex.permute(0, 3, 1, 2))

                    total_loss = recon_loss + tv_loss * 0.01
                else:
                    total_loss = recon_loss

                total_loss.backward()
                optim.step()

        tex_map = tex[0].detach().cpu().numpy()[..., ::-1] * 255.0

        image = img.permute(0, 3, 1, 2)

        tex_mask = tex_mask[0].detach().cpu().numpy() * 255.0
        tex_mask = np.where(tex_mask[..., 1] > 250, 1.0, 0.0) * np.where(
            tex_mask[..., 0] < 10, 1.0, 0) * np.where(tex_mask[..., 2] < 10,
                                                      1.0, 0)
        tex_mask = 1.0 - tex_mask

        return mask, depth, image, tex_map, tex_mask
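
    # The loop above is a coarse-to-fine fit: the texture is re-optimized at
    # 64 -> 2048 resolution, each stage initialized by upsampling the previous
    # one, with a TV smoothness term below full resolution. The same pattern
    # in isolation (illustrative sketch; render() here is a hypothetical
    # differentiable renderer, not a function of this module):
    #
    #   tex = torch.zeros(1, 3, 8, 8).cuda()
    #   for res in (16, 32):
    #       tex = F.interpolate(tex.detach(), (res, res)).requires_grad_(True)
    #       optim = torch.optim.Adam([tex], lr=1e-2)
    #       for _ in range(200):
    #           optim.zero_grad()
    #           loss = (render(tex) - target).pow(2).mean()
    #           loss.backward()
    #           optim.step()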
21 modelscope/models/cv/head_reconstruction/models/opt.py Normal file
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
bfm_folder = ''
bfm_model = 'head_model_for_maas.mat'
camera_d = 10.0
center = 112.0
focal = 1015.0
isTrain = False
net_recon = 'resnet50'
phase = 'test'
use_ddp = False
use_last_fc = False
z_far = 15.0
z_near = 5.0
lr = 0.001
w_color = 1.92
w_reg = 3.0e-4
w_gamma = 10.0
w_lm = 1.6e-3
w_id = 1.0
w_exp = 0.8
w_tex = 1.7e-2
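
# The w_* values above are loss weights in the style of
# Deep3DFaceRecon_pytorch. A sketch of how such weights typically combine
# (illustrative; the actual combination lives in the fitting code):
#
#   loss = (w_color * color_loss + w_lm * landmark_loss + w_gamma * gamma_reg
#           + w_reg * (w_id * id_reg + w_exp * exp_reg + w_tex * tex_reg))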
145 modelscope/models/cv/head_reconstruction/models/tex_processor.py Normal file
@@ -0,0 +1,145 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os

import cv2
import numpy as np


def get_fade_out_mask(length, start_value, end_value, fade_start_ratio,
                      fade_end_ratio):
    fade_start_ind = int(length * fade_start_ratio)
    fade_end_ind = int(length * fade_end_ratio)

    left_part = np.array([start_value] * fade_start_ind)
    fade_part = np.linspace(start_value, end_value,
                            fade_end_ind - fade_start_ind)
    len_right = length - len(left_part) - len(fade_part)
    right_part = np.array([end_value] * len_right)

    fade_out_mask = np.concatenate([left_part, fade_part, right_part], axis=0)
    return fade_out_mask
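

# Example (illustrative): get_fade_out_mask(10, 1.0, 0.0, 0.2, 0.8) holds 1.0
# for the first 2 samples, ramps linearly to 0.0 over the next 6, and stays
# at 0.0 for the last 2 -- a 1-D alpha ramp suitable for seam blending.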


class TexProcesser:

    def __init__(self, model_root):

        self.tex_size = 4096

        self.bald_tex_bg = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/template_bald_tex_2.jpg')).astype(
                             np.float32)
        self.hair_tex_bg = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/template_withHair_tex.jpg')).astype(
                             np.float32)

        self.hair_mask = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/hair_mask_male.png'))[..., 0].astype(
                             np.float32) / 255.0
        self.hair_mask = cv2.resize(self.hair_mask, (4096, 4096 + 1024))

        front_mask = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/face_mask_singleview.jpg')).astype(
                             np.float32) / 255
        front_mask = cv2.resize(front_mask, (1024, 1024))
        front_mask = cv2.resize(front_mask, (0, 0), fx=0.1, fy=0.1)
        front_mask = cv2.erode(front_mask,
                               np.ones(shape=(7, 7), dtype=np.float32))
        front_mask = cv2.GaussianBlur(front_mask, (13, 13), 0)
        self.front_mask = cv2.resize(front_mask,
                                     (self.tex_size, self.tex_size))
        self.binary_front_mask = self.front_mask.copy()
        self.binary_front_mask[(self.front_mask < 0.3)
                               + (self.front_mask > 0.7)] = 0
        self.binary_front_mask[self.binary_front_mask != 0] = 1.0
        self.binary_front_mask_ = self.binary_front_mask.copy()
        self.binary_front_mask_[:int(4096 * 375 / 950)] = 0
        self.binary_front_mask_[int(4096 * 600 / 950):] = 0
        self.binary_front_mask = np.zeros((4096 + 1024, 4096, 3),
                                          dtype=np.float32)
        self.binary_front_mask[:4096, :] = self.binary_front_mask_
        self.front_mask_ = self.front_mask.copy()
        self.front_mask = np.zeros((4096 + 1024, 4096, 3), dtype=np.float32)
        self.front_mask[:4096, :] = self.front_mask_

        self.fg_mask = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/fg_mask.png'))[..., 0].astype(
                             np.float32) / 255.0
        self.fg_mask = cv2.resize(self.fg_mask, (256, 256))
        self.fg_mask = cv2.dilate(self.fg_mask,
                                  np.ones(shape=(13, 13), dtype=np.float32))
        self.fg_mask = cv2.blur(self.fg_mask, (27, 27), 0)
        self.fg_mask = cv2.resize(self.fg_mask, (4096, 4096 + 1024))
        self.fg_mask = self.fg_mask[..., None]

        self.cheek_mask = cv2.imread(
            os.path.join(model_root,
                         'assets/texture/cheek_area_mask.png'))[..., 0].astype(
                             np.float32) / 255.0
        self.cheek_mask = cv2.resize(self.cheek_mask, (4096, 4096 + 1024))
        self.cheek_mask = self.cheek_mask[..., None]

        self.bald_tex_bg = self.bald_tex_bg[:4096]
        self.hair_tex_bg = self.hair_tex_bg[:4096]
        self.fg_mask = self.fg_mask[:4096]
        self.hair_mask = self.hair_mask[:4096]
        self.front_mask = self.front_mask[:4096]
        self.binary_front_mask = self.binary_front_mask[:4096]
        self.front_mask_ = self.front_mask_[:4096]

        self.cheek_mask_left = self.cheek_mask[:4096]
        self.cheek_mask_right = self.cheek_mask[:4096].copy()[:, ::-1]

    def post_process_texture(self, tex_map, hair_tex=True):
        tex_map = cv2.resize(tex_map, (self.tex_size, self.tex_size))

        # if hair_tex is true and one side is much darker than the other,
        # mirror the brighter side of the texture onto it
        if hair_tex:
            left_cheek_light_mean = np.mean(
                tex_map[self.cheek_mask_left[..., 0] == 1.0])
            right_cheek_light_mean = np.mean(
                tex_map[self.cheek_mask_right[..., 0] == 1.0])

            tex_map_flip = tex_map[:, ::-1, :]
            w = tex_map.shape[1]
            half_w = w // 2
            if left_cheek_light_mean > right_cheek_light_mean * 1.5:
                tex_map[:, half_w:, :] = tex_map_flip[:, half_w:, :]
            elif right_cheek_light_mean > left_cheek_light_mean * 2:
                tex_map[:, :half_w, :] = tex_map_flip[:, :half_w, :]

        # shift the template texture towards the predicted skin color
        bg_mean_rgb = np.mean(
            self.bald_tex_bg[self.binary_front_mask[..., 0] == 1.0],
            axis=0)[None, None]
        pred_tex_mean_rgb = np.mean(
            tex_map[self.binary_front_mask[..., 0] == 1.0], axis=0)[None,
                                                                    None] * 1.1
        _bald_tex_bg = self.bald_tex_bg + (pred_tex_mean_rgb - bg_mean_rgb)

        if hair_tex:
            # inpaint hair
            tex_gray = cv2.cvtColor(
                tex_map.astype(np.uint8),
                cv2.COLOR_BGR2GRAY).astype(np.float32)
            hair_mask = (self.hair_mask == 1.0) * (tex_gray < 120)
            hair_bgr = np.mean(tex_map[hair_mask, :], axis=0) * 0.5
            if np.any(np.isnan(hair_bgr)):  # no dark hair pixel was found
                hair_bgr = 20.0
            _bald_tex_bg[self.hair_mask == 1.0] = hair_bgr

            # fuse
            tex_map = _bald_tex_bg * (1.
                                      - self.fg_mask) + tex_map * self.fg_mask
        else:
            # fuse
            tex_map = _bald_tex_bg * (
                1. - self.front_mask) + tex_map * self.front_mask

        return tex_map
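

# Usage sketch (illustrative; model_root must contain the assets/texture
# files loaded in __init__):
#
#   processor = TexProcesser(model_root)
#   full_tex = processor.post_process_texture(tex_map, hair_tex=True)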
28 modelscope/models/cv/human3d_animation/__init__.py Normal file
@@ -0,0 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .generate_skeleton import gen_skeleton_bvh
    from .utils import (read_obj, write_obj, render, rotate_x, rotate_y,
                        translate, projection)

else:
    _import_structure = {
        'generate_skeleton': ['gen_skeleton_bvh'],
        'utils': [
            'read_obj', 'write_obj', 'render', 'rotate_x', 'rotate_y',
            'translate', 'projection'
        ],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
184 modelscope/models/cv/human3d_animation/bvh_writer.py Normal file
@@ -0,0 +1,184 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch

from .transforms import aa2quat, batch_rodrigues, mat2aa, quat2euler


def write_bvh(parent,
              offset,
              rotation,
              position,
              names,
              frametime,
              order,
              path,
              endsite=None):
    frame = rotation.shape[0]
    joint_num = rotation.shape[1]
    order = order.upper()

    file_string = 'HIERARCHY\n'

    seq = []

    def write_static(idx, prefix):
        nonlocal parent, offset, rotation, names
        nonlocal order, endsite, file_string, seq
        seq.append(idx)
        if idx == 0:
            name_label = 'ROOT ' + names[idx]
            channel_label = ('CHANNELS 6 Xposition Yposition Zposition '
                             '{}rotation {}rotation {}rotation').format(*order)
        else:
            name_label = 'JOINT ' + names[idx]
            channel_label = ('CHANNELS 3 {}rotation {}rotation '
                             '{}rotation').format(*order)
        offset_label = 'OFFSET %.6f %.6f %.6f' % (
            offset[idx][0], offset[idx][1], offset[idx][2])

        file_string += prefix + name_label + '\n'
        file_string += prefix + '{\n'
        file_string += prefix + '\t' + offset_label + '\n'
        file_string += prefix + '\t' + channel_label + '\n'

        has_child = False
        for y in range(idx + 1, rotation.shape[1]):
            if parent[y] == idx:
                has_child = True
                write_static(y, prefix + '\t')
        if not has_child:
            file_string += prefix + '\t' + 'End Site\n'
            file_string += prefix + '\t' + '{\n'
            file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n'
            file_string += prefix + '\t' + '}\n'

        file_string += prefix + '}\n'

    write_static(0, '')

    file_string += 'MOTION\n' + 'Frames: {}\n'.format(
        frame) + 'Frame Time: %.8f\n' % frametime
    for i in range(frame):
        file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1],
                                            position[i][2])

        for j in range(joint_num):
            idx = seq[j]
            file_string += '%.6f %.6f %.6f ' % (
                rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2])

        file_string += '\n'

    with open(path, 'w') as f:
        f.write(file_string)
    return file_string
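

# For reference, the emitted file follows the standard BVH layout
# (illustrative fragment for order='xyz'):
#
#   HIERARCHY
#   ROOT 00
#   {
#       OFFSET 0.000000 0.000000 0.000000
#       CHANNELS 6 Xposition Yposition Zposition Xrotation Yrotation Zrotation
#       JOINT 01
#       ...
#   }
#   MOTION
#   Frames: 45
#   Frame Time: 0.03330000
#   <root position xyz> <per-joint euler xyz> ...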


class WriterWrapper:

    def __init__(self, parents):
        self.parents = parents

    def axis2euler(self, rot):
        rot = rot.reshape(rot.shape[0], -1, 3)  # (45, 24, 3)
        quat = aa2quat(rot)
        euler = quat2euler(quat, order='xyz')
        rot = euler
        return rot

    def mapper_rot_mixamo(self, rot, n_bone):
        rot = rot.reshape(rot.shape[0], -1, 3)

        smpl_mapper = [
            0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 17, 21, 15, 18, 22, 19,
            23, 20, 24
        ]

        if n_bone > 24:
            hand_mapper = list(range(25, 65))
            smpl_mapper += hand_mapper

        new_rot = torch.zeros((rot.shape[0], n_bone, 3))  # (n, 24, 3)
        new_rot[:, :len(smpl_mapper), :] = rot[:, smpl_mapper, :]

        return new_rot

    def transform_rot_with_restpose(self, rot, rest_pose, node_list, n_bone):

        rest_pose = batch_rodrigues(rest_pose.reshape(-1, 3)).reshape(
            1, n_bone, 3, 3)  # N*3 -> N*3*3

        frame_num = rot.shape[0]
        rot = rot.reshape(rot.shape[0], -1, 3)
        new_rot = rot.clone()
        for k in range(frame_num):
            action_rot = batch_rodrigues(rot[k].reshape(-1, 3)).reshape(
                1, n_bone, 3, 3)
            for i in node_list:
                rot1 = rest_pose[0, i, :, :]
                rot2 = action_rot[0, i, :, :]
                nrot = torch.matmul(rot2, torch.inverse(rot1))
                nvec = mat2aa(nrot)
                new_rot[k, i, :] = nvec

        new_rot = self.axis2euler(new_rot)  # (45, 24, 3)
        return new_rot

    def transform_rot_with_stdApose(self, rot, rest_pose):
        print('transform_rot_with_stdApose')
        rot = rot.reshape(rot.shape[0], -1, 3)
        rest_pose = self.axis2euler(rest_pose)
        print(rot.shape)
        print(rest_pose.shape)
        smpl_left_arm_idx = 18
        smpl_right_arm_idx = 19
        std_arm_rot = torch.tensor([[21.7184, -4.8148, 16.3985],
                                    [-20.1108, 10.7190, -8.9279]])
        x = rest_pose[:, smpl_left_arm_idx:smpl_right_arm_idx + 1, :]
        delta = (x - std_arm_rot)
        rot[:, smpl_left_arm_idx:smpl_right_arm_idx + 1, :] -= delta
        return rot

    def write(self,
              filename,
              offset,
              rot=None,
              action_loc=None,
              rest_pose=None,
              correct_arm=0):  # offset: [24, 3], rot: [45, 72]
        if not isinstance(offset, torch.Tensor):
            offset = torch.tensor(offset)
        n_bone = offset.shape[0]  # 24
        pos = offset[0].unsqueeze(0)  # (1, 3)

        if rot is None:
            rot = np.zeros((1, n_bone, 3))
        else:  # rot: (45, 72)
            if rest_pose is None:
                rot = self.mapper_rot_mixamo(rot, n_bone)
            else:
                if correct_arm == 1:
                    rot = self.mapper_rot_mixamo(rot, n_bone)
                    print(rot.shape)
                    node_list_change = [16, 17]
                    n_bone = rot.shape[1]
                    print(rot[0, 19, :])
                else:
                    node_list_change = [1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17]
                rot = self.transform_rot_with_restpose(
                    rot, rest_pose, node_list_change, n_bone)

        rest = torch.zeros((1, n_bone * 3))
        rest = self.axis2euler(rest)
        frames_add = 1
        rest = rest.repeat(frames_add, 1, 1)
        rot = torch.cat((rest, rot), 0)

        pos = pos.repeat(rot.shape[0], 1)
        action_len = action_loc.shape[0]
        pos[-action_len:, :] = action_loc[..., :]

        names = ['%02d' % i for i in range(n_bone)]
        write_bvh(self.parents, offset, rot, pos, names, 0.0333, 'xyz',
                  filename)
167 modelscope/models/cv/human3d_animation/generate_skeleton.py Normal file
@@ -0,0 +1,167 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle

import numpy as np
import torch

from .bvh_writer import WriterWrapper
from .utils import matrix_to_axis_angle, rotation_6d_to_matrix


def load_smpl_params(pose_fname):
    with open(pose_fname, 'rb') as f:
        data = pickle.load(f)
    pose = torch.from_numpy(data['pose'])
    beta = torch.from_numpy(data['betas'])
    trans = torch.from_numpy(data['trans'])
    if 'joints' in data:
        joints = torch.from_numpy(data['joints'])
        joints = joints.reshape(1, -1, 3)
    else:
        joints = None
    trans = trans.reshape(1, 3)
    beta = beta.reshape(1, -1)[:, :10]
    pose = pose.reshape(-1, 24 * 3)
    return pose, beta, trans, joints


def set_pose_param(pose, start, end):
    pose[:, start * 3:(end + 1) * 3] = 0
    return pose


def load_test_anim(filename, device, mode='move'):
    anim = np.load(filename)
    anim = torch.tensor(anim, device=device, dtype=torch.float)
    poses = anim[:, :-3]
    loc = anim[:, -3:]
    if os.path.basename(filename)[:5] == 'comb_':
        loc = loc / 100
        # trim the sequence once 5 identical frames in a row are found
        repeat = 0
        idx = -1
        for i in range(poses.shape[0]):
            if i == 0:
                continue
            if repeat >= 5:
                idx = i
                break
            if poses[i].equal(poses[i - 1]):
                repeat += 1
            else:
                repeat = 0
        poses = poses[:idx - 5, :]
        loc = loc[:idx - 5, :]

    if mode == 'inplace':
        loc[1:, :] = loc[0, :]

    return poses, loc


def load_syn_motion(filename, device, mode='move'):
    data = np.load(filename, allow_pickle=True).item()
    anim = data['thetas']
    n_joint, c, t = anim.shape

    anim = torch.tensor(anim, device=device, dtype=torch.float)
    anim = anim.permute(2, 0, 1)  # (180, 24, 6)
    poses = anim.reshape(-1, 6)
    poses = rotation_6d_to_matrix(poses)
    poses = matrix_to_axis_angle(poses)
    poses = poses.reshape(-1, 24, 3)

    loc = data['root_translation']
    loc = torch.tensor(loc, device=device, dtype=torch.float)
    loc = loc.permute(1, 0)

    if mode == 'inplace':
        loc = torch.zeros((t, 3))

    print('load %s' % filename)

    return poses, loc


def load_action(action_name,
                model_dir,
                action_dir,
                mode='move',
                device=torch.device('cpu')):
    action_path = os.path.join(action_dir, action_name + '.npy')
    if not os.path.exists(action_path):
        print('cannot find action %s, using the default action instead' %
              action_name)
        action_path = os.path.join(model_dir, '3D-assets', 'SwingDancing.npy')
    print('load action %s' % action_path)
    test_pose, test_loc = load_test_anim(
        action_path, device, mode=mode)  # pose: [45, 72], loc: [45, 1, 3]

    return test_pose, test_loc


def load_action_list(action,
                     model_dir,
                     action_dir,
                     mode='move',
                     device=torch.device('cpu')):
    action_list = action.split(',')
    test_pose, test_loc = load_action(
        action_list[0], model_dir, action_dir, mode=mode, device=device)
    final_loc = test_loc[-1, :]
    idx = 0
    if len(action_list) > 1:
        for action in action_list:
            if idx == 0:
                idx += 1
                continue
            print('load action %s' % action)
            pose, loc = load_action(
                action, model_dir, action_dir, mode=mode, device=device)
            delta_loc = final_loc - loc[0, :]
            loc += delta_loc
            final_loc = loc[-1, :]
            test_pose = torch.cat([test_pose, pose], 0)
            test_loc = torch.cat([test_loc, loc], 0)
            idx += 1
    return test_pose, test_loc


def gen_skeleton_bvh(model_dir, action_dir, case_dir, action, mode='move'):
    outpath_a = os.path.join(case_dir, 'skeleton_a.bvh')
    device = torch.device('cpu')
    assets_dir = os.path.join(model_dir, '3D-assets')
    pkl_path = os.path.join(assets_dir, 'smpl.pkl')
    poses, shapes, trans, joints = load_smpl_params(pkl_path)
    if action.endswith('.npy'):
        skeleton_path = os.path.join(assets_dir, 'skeleton_nohand.npy')
    else:
        skeleton_path = os.path.join(assets_dir, 'skeleton.npy')
    data = np.load(skeleton_path, allow_pickle=True).item()
    skeleton = data['skeleton']
    parent = data['parent']
    skeleton = skeleton.squeeze(0)
    bvh_writer = WriterWrapper(parent)

    if action.endswith('.npy'):
        action_path = action
        print('load action %s' % action_path)
        test_pose, test_loc = load_syn_motion(action_path, device, mode=mode)
        bvh_writer.write(
            outpath_a,
            skeleton,
            test_pose,
            action_loc=test_loc,
            rest_pose=poses)

    else:
        print('load action %s' % action)
        test_pose, test_loc = load_action_list(
            action, model_dir, action_dir, mode='move', device=device)
        std_y = torch.tensor(0.99)
        test_loc = test_loc + (skeleton[0, 1] - std_y)
        bvh_writer.write(outpath_a, skeleton, test_pose, action_loc=test_loc)

    print('save %s' % outpath_a)

    return 0
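

# Usage sketch (illustrative; the paths are placeholders, not shipped
# defaults):
#
#   gen_skeleton_bvh(model_dir='/path/to/model', action_dir='/path/to/actions',
#                    case_dir='/tmp/case', action='SwingDancing')
#   # writes /tmp/case/skeleton_a.bvh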
316 modelscope/models/cv/human3d_animation/transforms.py Normal file
@@ -0,0 +1,316 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/facebookresearch/pytorch3d
# All Rights Reserved.
# ------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


def batch_mm(matrix, matrix_batch):
    """
    https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
    :param matrix: Sparse or dense matrix, size (m, n).
    :param matrix_batch: Batched dense matrices, size (b, n, k).
    :return: The batched matrix-matrix product,
        size (m, n) x (b, n, k) = (b, m, k).
    """
    batch_size = matrix_batch.shape[0]
    # Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
    vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)

    # A matrix-matrix product is a batched matrix-vector
    # product of the columns.
    # And then reverse the reshaping.
    # (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
    return matrix.mm(vectors).reshape(matrix.shape[0], batch_size,
                                      -1).transpose(1, 0)
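

# Equivalence check (illustrative): for a dense `matrix` this matches a plain
# broadcasted matmul, while also supporting torch.sparse matrices:
#
#   m = torch.randn(4, 5)
#   b = torch.randn(7, 5, 2)
#   assert torch.allclose(batch_mm(m, b), torch.matmul(m, b), atol=1e-6)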


def aa2quat(rots, form='wxyz', unified_orient=True):
    """
    Convert angle-axis representation to wxyz quaternion
    and to the half plane (w >= 0)
    @param rots: angle-axis rotations, (*, 3)
    @param form: quaternion format, either 'wxyz' or 'xyzw'
    @param unified_orient: Use unified orientation for quaternion
        (quaternion is a dual cover of SO3)
    :return:
    """
    angles = rots.norm(dim=-1, keepdim=True)
    norm = angles.clone()
    norm[norm < 1e-8] = 1
    axis = rots / norm
    quats = torch.empty(
        rots.shape[:-1] + (4, ), device=rots.device, dtype=rots.dtype)
    angles = angles * 0.5
    if form == 'wxyz':
        quats[..., 0] = torch.cos(angles.squeeze(-1))
        quats[..., 1:] = torch.sin(angles) * axis
    elif form == 'xyzw':
        quats[..., :3] = torch.sin(angles) * axis
        quats[..., 3] = torch.cos(angles.squeeze(-1))

    if unified_orient:
        idx = quats[..., 0] < 0
        quats[idx, :] *= -1

    return quats


def quat2aa(quats):
    """
    Convert wxyz quaternions to angle-axis representation
    :param quats:
    :return:
    """
    _cos = quats[..., 0]
    xyz = quats[..., 1:]
    _sin = xyz.norm(dim=-1)
    norm = _sin.clone()
    norm[norm < 1e-7] = 1
    axis = xyz / norm.unsqueeze(-1)
    angle = torch.atan2(_sin, _cos) * 2
    return axis * angle.unsqueeze(-1)


def quat2mat(quats: torch.Tensor):
    """
    Convert (w, x, y, z) quaternions to 3x3 rotation matrix
    :param quats: quaternions of shape (..., 4)
    :return: rotation matrices of shape (..., 3, 3)
    """
    qw = quats[..., 0]
    qx = quats[..., 1]
    qy = quats[..., 2]
    qz = quats[..., 3]

    x2 = qx + qx
    y2 = qy + qy
    z2 = qz + qz
    xx = qx * x2
    yy = qy * y2
    wx = qw * x2
    xy = qx * y2
    yz = qy * z2
    wy = qw * y2
    xz = qx * z2
    zz = qz * z2
    wz = qw * z2

    m = torch.empty(
        quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
    m[..., 0, 0] = 1.0 - (yy + zz)
    m[..., 0, 1] = xy - wz
    m[..., 0, 2] = xz + wy
    m[..., 1, 0] = xy + wz
    m[..., 1, 1] = 1.0 - (xx + zz)
    m[..., 1, 2] = yz - wx
    m[..., 2, 0] = xz - wy
    m[..., 2, 1] = yz + wx
    m[..., 2, 2] = 1.0 - (xx + yy)

    return m


def quat2euler(q, order='xyz', degrees=True):
    """
    Convert (w, x, y, z) quaternions to xyz euler angles.
    This is used for bvh output.
    """
    q0 = q[..., 0]
    q1 = q[..., 1]
    q2 = q[..., 2]
    q3 = q[..., 3]
    es = torch.empty(q0.shape + (3, ), device=q.device, dtype=q.dtype)

    if order == 'xyz':
        es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2),
                                 q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
        es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
        es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3),
                                 q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
    else:
        raise NotImplementedError('Cannot convert to ordering %s' % order)

    if degrees:
        es = es * 180 / np.pi

    return es


def aa2mat(rots):
    """
    Convert angle-axis representation to rotation matrix
    :param rots: angle-axis representation
    :return:
    """
    quat = aa2quat(rots)
    mat = quat2mat(quat)
    return mat


def inv_affine(mat):
    """
    Calculate the inverse of any affine transformation
    """
    affine = torch.zeros((mat.shape[:2] + (1, 4)))
    affine[..., 3] = 1
    vert_mat = torch.cat((mat, affine), dim=2)
    vert_mat_inv = torch.inverse(vert_mat)
    return vert_mat_inv[..., :3, :]


def inv_rigid_affine(mat):
    """
    Calculate the inverse of a rigid affine transformation
    """
    res = mat.clone()
    res[..., :3] = mat[..., :3].transpose(-2, -1)
    res[...,
        3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
    return res


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret


def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9, )), dim=-1)

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        ))

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0]**2, m21 - m12, m02 - m20, m10 - m01],
                        dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1]**2, m10 + m01, m02 + m20],
                        dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]**2, m12 + m21],
                        dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3]**2],
                        dim=-1),
        ],
        dim=-2,
    )

    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    return quat_candidates[F.one_hot(q_abs.argmax(
        dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4, ))


def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles])
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48)
    return quaternions[..., 1:] / sin_half_angles_over_angles


def mat2aa(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))


def batch_rodrigues(rot_vecs: Tensor, epsilon: float = 1e-8) -> Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors
        Parameters
        ----------
        rot_vecs: torch.tensor Nx3
            array of N axis-angle vectors
        Returns
        -------
        R: torch.tensor Nx3x3
            The rotation matrices for the given axis-angle parameters
    '''
    assert len(rot_vecs.shape) == 2, (
        f'Expects an array of size Bx3, but received {rot_vecs.shape}')

    batch_size = rot_vecs.shape[0]
    device = rot_vecs.device
    dtype = rot_vecs.dtype

    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2)
    rot_dir = rot_vecs / angle

    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)

    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
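

# Round-trip sanity check (illustrative): batch_rodrigues is the inverse
# direction of mat2aa above for rotation angles below pi, up to the epsilon
# used for numerical safety:
#
#   aa = torch.randn(8, 3) * 0.5
#   assert torch.allclose(mat2aa(batch_rodrigues(aa)), aa, atol=1e-4)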
375 modelscope/models/cv/human3d_animation/utils.py Normal file
@@ -0,0 +1,375 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os

import cv2
import numpy as np
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F


def read_obj(obj_path, print_shape=False):
    with open(obj_path, 'r') as f:
        bfm_lines = f.readlines()

    vertices = []
    faces = []
    uvs = []
    vns = []
    faces_uv = []
    faces_normal = []
    max_face_length = 0
    for line in bfm_lines:
        if line[:2] == 'v ':
            vertex = [
                float(a) for a in line.strip().split(' ')[1:] if len(a) > 0
            ]
            vertices.append(vertex)

        if line[:2] == 'f ':
            items = line.strip().split(' ')[1:]
            face = [int(a.split('/')[0]) for a in items if len(a) > 0]
            max_face_length = max(max_face_length, len(face))
            faces.append(face)

            if '/' in items[0] and len(items[0].split('/')[1]) > 0:
                face_uv = [int(a.split('/')[1]) for a in items if len(a) > 0]
                faces_uv.append(face_uv)

            if '/' in items[0] and len(items[0].split('/')) >= 3 and len(
                    items[0].split('/')[2]) > 0:
                face_normal = [
                    int(a.split('/')[2]) for a in items if len(a) > 0
                ]
                faces_normal.append(face_normal)

        if line[:3] == 'vt ':
            items = line.strip().split(' ')[1:]
            uv = [float(a) for a in items if len(a) > 0]
            uvs.append(uv)

        if line[:3] == 'vn ':
            items = line.strip().split(' ')[1:]
            vn = [float(a) for a in items if len(a) > 0]
            vns.append(vn)

    vertices = np.array(vertices).astype(np.float32)
    if max_face_length <= 3:
        faces = np.array(faces).astype(np.int32)
    else:
        print('not a triangle face mesh!')

    if vertices.shape[1] == 3:
        mesh = {
            'vertices': vertices,
            'faces': faces,
        }
    else:
        mesh = {
            'vertices': vertices[:, :3],
            'colors': vertices[:, 3:],
            'faces': faces,
        }

    if len(uvs) > 0:
        uvs = np.array(uvs).astype(np.float32)
        mesh['uvs'] = uvs

    if len(vns) > 0:
        vns = np.array(vns).astype(np.float32)
        mesh['normals'] = vns

    if len(faces_uv) > 0:
        if max_face_length <= 3:
            faces_uv = np.array(faces_uv).astype(np.int32)
        mesh['faces_uv'] = faces_uv

    if len(faces_normal) > 0:
        if max_face_length <= 3:
            faces_normal = np.array(faces_normal).astype(np.int32)
        mesh['faces_normal'] = faces_normal

    if print_shape:
        print('num of vertices', len(vertices))
        print('num of faces', len(faces))
    return mesh
|
||||
|
||||
|
||||
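
# Usage sketch (illustrative path; assumes a triangulated OBJ on disk):
#
#   mesh = read_obj('head.obj', print_shape=True)
#   mesh['vertices']  # float32 (N, 3); mesh['faces'] keeps 1-based OBJ indices
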
def write_obj(save_path, mesh):
    save_dir = os.path.dirname(save_path)
    save_name = os.path.splitext(os.path.basename(save_path))[0]

    if 'texture_map' in mesh:
        cv2.imwrite(
            os.path.join(save_dir, save_name + '.png'), mesh['texture_map'])

        with open(os.path.join(save_dir, save_name + '.mtl'), 'w') as wf:
            wf.write('newmtl material_0\n')
            wf.write('Ka 1.000000 0.000000 0.000000\n')
            wf.write('Kd 1.000000 1.000000 1.000000\n')
            wf.write('Ks 0.000000 0.000000 0.000000\n')
            wf.write('Tr 0.000000\n')
            wf.write('illum 0\n')
            wf.write('Ns 0.000000\n')
            wf.write('map_Kd {}\n'.format(save_name + '.png'))

    with open(save_path, 'w') as wf:
        if 'texture_map' in mesh:
            wf.write('# Create by ModelScope\n')
            wf.write('mtllib ./{}.mtl\n'.format(save_name))

        if 'colors' in mesh:
            for i, v in enumerate(mesh['vertices']):
                wf.write('v {} {} {} {} {} {}\n'.format(
                    v[0], v[1], v[2], mesh['colors'][i][0],
                    mesh['colors'][i][1], mesh['colors'][i][2]))
        else:
            for v in mesh['vertices']:
                wf.write('v {} {} {}\n'.format(v[0], v[1], v[2]))

        if 'uvs' in mesh:
            for uv in mesh['uvs']:
                wf.write('vt {} {}\n'.format(uv[0], uv[1]))

        if 'normals' in mesh:
            for vn in mesh['normals']:
                wf.write('vn {} {} {}\n'.format(vn[0], vn[1], vn[2]))

        if 'faces' in mesh:
            for ind, face in enumerate(mesh['faces']):
                if 'faces_uv' in mesh or 'faces_normal' in mesh:
                    if 'faces_uv' in mesh:
                        face_uv = mesh['faces_uv'][ind]
                    else:
                        face_uv = face
                    if 'faces_normal' in mesh:
                        face_normal = mesh['faces_normal'][ind]
                    else:
                        face_normal = face
                    row = 'f ' + ' '.join([
                        '{}/{}/{}'.format(face[i], face_uv[i], face_normal[i])
                        for i in range(len(face))
                    ]) + '\n'
                else:
                    row = 'f ' + ' '.join(
                        ['{}'.format(face[i])
                         for i in range(len(face))]) + '\n'
                wf.write(row)
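
# Round-trip sketch (illustrative paths): write_obj mirrors read_obj's dict
# layout, so a loaded mesh can be written straight back out.
#
#   write_obj('out/head_copy.obj', read_obj('head.obj'))
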
def projection(x=0.1, n=1.0, f=50.0):
    return np.array([[n / x, 0, 0, 0], [0, n / x, 0, 0],
                     [0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
                     [0, 0, -1, 0]]).astype(np.float32)


def translate(x, y, z):
    return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z],
                     [0, 0, 0, 1]]).astype(np.float32)


def rotate_x(a):
    s, c = np.sin(a), np.cos(a)
    return np.array([[1, 0, 0, 0], [0, c, s, 0], [0, -s, c, 0],
                     [0, 0, 0, 1]]).astype(np.float32)


def rotate_y(a):
    s, c = np.sin(a), np.cos(a)
    return np.array([[c, 0, s, 0], [0, 1, 0, 0], [-s, 0, c, 0],
                     [0, 0, 0, 1]]).astype(np.float32)


def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.sum(x * y, -1, keepdim=True)


def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
    return 2 * dot(x, n) * n - x


def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    return torch.sqrt(torch.clamp(
        dot(x, x),
        min=eps))  # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN


def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
    return x / length(x, eps)


def transform_pos(mtx, pos):
    t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx,
                                                       np.ndarray) else mtx
    posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)
    return torch.matmul(posw, t_mtx.t())[None, ...]
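
# Sketch: these helpers compose a clip-space transform for the rasteriser
# below (illustrative values; numpy is imported above as np):
#
#   mvp = projection(x=0.1) @ translate(0, 0, -2.7) @ rotate_y(0.3)
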
def render(glctx, mtx, pos, pos_idx, uv, uv_idx, tex, resolution, enable_mip,
           max_mip_level):
    pos_clip = transform_pos(mtx, pos)
    rast_out, rast_out_db = dr.rasterize(
        glctx, pos_clip, pos_idx, resolution=[resolution, resolution])

    if enable_mip:
        texc, texd = dr.interpolate(
            uv[None, ...],
            rast_out,
            uv_idx,
            rast_db=rast_out_db,
            diff_attrs='all')
        color = dr.texture(
            tex[None, ...],
            texc,
            texd,
            filter_mode='linear-mipmap-linear',
            max_mip_level=max_mip_level)
    else:
        texc, _ = dr.interpolate(uv[None, ...], rast_out, uv_idx)
        color = dr.texture(tex[None, ...], texc, filter_mode='linear')

    pos_idx = pos_idx.type(torch.long)
    v0 = pos[pos_idx[:, 0], :]
    v1 = pos[pos_idx[:, 1], :]
    v2 = pos[pos_idx[:, 2], :]
    face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0))
    face_normal_indices = (torch.arange(
        0, face_normals.shape[0], dtype=torch.int64,
        device='cuda')[:, None]).repeat(1, 3)
    gb_geometric_normal, _ = dr.interpolate(face_normals[None, ...], rast_out,
                                            face_normal_indices.int())
    normal = (gb_geometric_normal + 1) * 0.5
    mask = torch.clamp(rast_out[..., -1:], 0, 1)
    color = color * mask + (1 - mask) * torch.ones_like(color)
    normal = normal * mask + (1 - mask) * torch.ones_like(normal)

    return color, mask, normal
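
# Usage sketch (illustrative; `dr` is nvdiffrast.torch as imported above,
# and RasterizeCudaContext assumes a recent nvdiffrast release):
#
#   glctx = dr.RasterizeCudaContext()
#   color, mask, normal = render(glctx, mvp, pos, pos_idx, uv, uv_idx, tex,
#                                resolution=512, enable_mip=True,
#                                max_mip_level=4)
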
# The following code is based on https://github.com/Mathux/ACTOR.git
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# Check PYTORCH3D_LICENCE before use


def _copysign(a, b):
    """
    Return a tensor where each element has the absolute value taken from the
    corresponding element of a, with sign taken from the corresponding
    element of b. This is like the standard copysign floating-point operation,
    but is not careful about negative 0 and NaN.

    Args:
        a: source tensor.
        b: tensor whose signs will be used, of the same shape as a.

    Returns:
        Tensor of the same shape as a with the signs of b.
    """
    signs_differ = (a < 0) != (b < 0)
    return torch.where(signs_differ, -a, a)


def _sqrt_positive_part(x):
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    ret[positive_mask] = torch.sqrt(x[positive_mask])
    return ret
def matrix_to_quaternion(matrix):
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
    m00 = matrix[..., 0, 0]
    m11 = matrix[..., 1, 1]
    m22 = matrix[..., 2, 2]
    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
    o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
    o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
    o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack((o0, o1, o2, o3), -1)
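
# Usage sketch: the identity matrix maps to the identity quaternion.
#
#   matrix_to_quaternion(torch.eye(3))  # -> tensor([1., 0., 0., 0.])
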
def quaternion_to_axis_angle(quaternions):
    """
    Convert rotations given as quaternions to axis/angle.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
    half_angles = torch.atan2(norms, quaternions[..., :1])
    angles = 2 * half_angles
    eps = 1e-6
    small_angles = angles.abs() < eps
    sin_half_angles_over_angles = torch.empty_like(angles)
    sin_half_angles_over_angles[~small_angles] = (
        torch.sin(half_angles[~small_angles]) / angles[~small_angles])
    # for x small, sin(x/2) is about x/2 - (x/2)^3/6
    # so sin(x/2)/x is about 1/2 - (x*x)/48
    sin_half_angles_over_angles[small_angles] = (
        0.5 - (angles[small_angles] * angles[small_angles]) / 48)
    return quaternions[..., 1:] / sin_half_angles_over_angles

def matrix_to_axis_angle(matrix):
    """
    Convert rotations given as rotation matrices to axis/angle.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotations given as a vector in axis angle form, as a tensor
        of shape (..., 3), where the magnitude is the angle
        turned anticlockwise in radians around the vector's
        direction.
    """
    return quaternion_to_axis_angle(matrix_to_quaternion(matrix))

def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
    using Gram--Schmidt orthogonalisation per Section B of [1].

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """

    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
    b2 = F.normalize(b2, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack((b1, b2, b3), dim=-2)
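
# Sketch: the first two rows of any rotation matrix form a valid 6D input;
# stacking the first two identity rows recovers the identity rotation.
#
#   d6 = torch.tensor([1., 0., 0., 0., 1., 0.])
#   rotation_6d_to_matrix(d6)  # -> 3x3 identity
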
22
modelscope/models/cv/image_control_3d_portrait/__init__.py
Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .image_control_3d_portrait import ImageControl3dPortrait

else:
    _import_structure = {
        'image_control_3d_portrait': ['ImageControl3dPortrait']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
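
# Sketch: with this registration, importing ImageControl3dPortrait from this
# package only loads the heavy implementation module on first attribute access,
# via LazyImportModule.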
@@ -0,0 +1,468 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict
from typing import Any, Dict

import cv2
import json
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from scipy.io import loadmat

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.face_detection.peppa_pig_face.facer import FaceAna
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from .network.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
from .network.shape_utils import convert_sdf_samples_to_ply
from .network.triplane import TriPlaneGenerator
from .network.triplane_encoder import TriplaneEncoder

logger = get_logger()

__all__ = ['ImageControl3dPortrait']


@MODELS.register_module(
    Tasks.image_control_3d_portrait,
    module_name=Models.image_control_3d_portrait)
class ImageControl3dPortrait(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the image control 3d portrait model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)

        logger.info('model params:{}'.format(kwargs))
        self.neural_rendering_resolution = kwargs[
            'neural_rendering_resolution']
        self.cam_radius = kwargs['cam_radius']
        self.fov_deg = kwargs['fov_deg']
        self.truncation_psi = kwargs['truncation_psi']
        self.truncation_cutoff = kwargs['truncation_cutoff']
        self.z_dim = kwargs['z_dim']
        self.image_size = kwargs['image_size']
        self.shape_res = kwargs['shape_res']
        self.pitch_range = kwargs['pitch_range']
        self.yaw_range = kwargs['yaw_range']
        self.max_batch = kwargs['max_batch']
        self.num_frames = kwargs['num_frames']
        self.box_warp = kwargs['box_warp']
        self.save_shape = kwargs['save_shape']
        self.save_images = kwargs['save_images']

        device = kwargs['device']
        self.device = create_device(device)

        self.facer = FaceAna(model_dir)

        similarity_mat_path = os.path.join(model_dir, 'BFM',
                                           'similarity_Lm3D_all.mat')
        self.lm3d_std = self.load_lm3d(similarity_mat_path)

        init_model_json = os.path.join(model_dir, 'configs',
                                       'init_encoder.json')
        with open(init_model_json, 'r') as fr:
            init_kwargs_encoder = json.load(fr)
        encoder_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        self.model = TriplaneEncoder(**init_kwargs_encoder)
        ckpt_encoder = torch.load(encoder_path, map_location='cpu')
        model_state = self.convert_state_dict(ckpt_encoder['state_dict'])
        self.model.load_state_dict(model_state)
        self.model = self.model.to(self.device)
        self.model.eval()

        init_args_G = ()
        init_netG_json = os.path.join(model_dir, 'configs', 'init_G.json')
        with open(init_netG_json, 'r') as fr:
            init_kwargs_G = json.load(fr)
        self.netG = TriPlaneGenerator(*init_args_G, **init_kwargs_G)
        netG_path = os.path.join(model_dir, 'ffhqrebalanced512-128.pth')
        ckpt_G = torch.load(netG_path)
        self.netG.load_state_dict(ckpt_G['G_ema'], strict=False)
        self.netG.neural_rendering_resolution = self.neural_rendering_resolution
        self.netG = self.netG.to(self.device)
        self.netG.eval()

        self.intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
        col, row = np.meshgrid(
            np.arange(self.image_size), np.arange(self.image_size))
        np_coord = np.stack((col, row), axis=2) / self.image_size  # [0,1]
        self.coord = torch.from_numpy(np_coord.astype(
            np.float32)).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)

        self.image_transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        logger.info('init done')
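
    # Construction sketch (illustrative kwargs only; the real values come from
    # the model's configuration, not from this snippet):
    #
    #   model = ImageControl3dPortrait(
    #       model_dir, neural_rendering_resolution=128, cam_radius=2.7,
    #       fov_deg=18.837, truncation_psi=0.7, truncation_cutoff=14,
    #       z_dim=512, image_size=512, shape_res=256, pitch_range=0.25,
    #       yaw_range=0.35, max_batch=1000000, num_frames=120, box_warp=1,
    #       save_shape=False, save_images=False, device='gpu')
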
    def convert_state_dict(self, state_dict):
        if not next(iter(state_dict)).startswith('module.'):
            return state_dict
        new_state_dict = OrderedDict()

        split_index = 0
        for cur_key, cur_value in state_dict.items():
            if cur_key.startswith('module.model'):
                split_index = 13
            elif cur_key.startswith('module'):
                split_index = 7

            break

        for k, v in state_dict.items():
            name = k[split_index:]
            new_state_dict[name] = v
        return new_state_dict
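
    # Sketch: strips the DataParallel prefix so checkpoints saved from a
    # wrapped model load into a bare one, e.g.
    # {'module.conv.weight': w} -> {'conv.weight': w}.
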
    def detect_face(self, img):
        src_h, src_w, _ = img.shape
        boxes, landmarks, _ = self.facer.run(img)
        if boxes.shape[0] == 0:
            return None
        elif boxes.shape[0] > 1:
            max_area = 0
            max_index = 0
            for i in range(boxes.shape[0]):
                bbox_width = boxes[i][2] - boxes[i][0]
                bbox_height = boxes[i][3] - boxes[i][1]
                area = int(bbox_width) * int(bbox_height)
                if area > max_area:
                    max_index = i
                    max_area = area

            return landmarks[max_index]
        else:
            return landmarks[0]
    def get_f5p(self, landmarks, np_img):
        eye_left = self.find_pupil(landmarks[36:41], np_img)
        eye_right = self.find_pupil(landmarks[42:47], np_img)
        if eye_left is None or eye_right is None:
            logger.warning(
                'cannot find 5 points with find_pupil, using the mean instead.')
            eye_left = landmarks[36:41].mean(axis=0)
            eye_right = landmarks[42:47].mean(axis=0)
        nose = landmarks[30]
        mouth_left = landmarks[48]
        mouth_right = landmarks[54]
        f5p = [[eye_left[0], eye_left[1]], [eye_right[0], eye_right[1]],
               [nose[0], nose[1]], [mouth_left[0], mouth_left[1]],
               [mouth_right[0], mouth_right[1]]]
        return np.array(f5p)
    def find_pupil(self, landmarks, np_img):
        h, w, _ = np_img.shape
        xmax = int(landmarks[:, 0].max())
        xmin = int(landmarks[:, 0].min())
        ymax = int(landmarks[:, 1].max())
        ymin = int(landmarks[:, 1].min())

        if ymin >= ymax or xmin >= xmax or ymin < 0 or xmin < 0 or ymax > h or xmax > w:
            return None
        eye_img_bgr = np_img[ymin:ymax, xmin:xmax, :]
        eye_img = cv2.cvtColor(eye_img_bgr, cv2.COLOR_BGR2GRAY)
        eye_img = cv2.equalizeHist(eye_img)
        n_marks = landmarks - np.array([xmin, ymin]).reshape([1, 2])
        eye_mask = cv2.fillConvexPoly(
            np.zeros_like(eye_img), n_marks.astype(np.int32), 1)
        ret, thresh = cv2.threshold(eye_img, 100, 255,
                                    cv2.THRESH_BINARY | cv2.THRESH_OTSU)
        thresh = (1 - thresh / 255.) * eye_mask
        cnt = 0
        xm = []
        ym = []
        for i in range(thresh.shape[0]):
            for j in range(thresh.shape[1]):
                if thresh[i, j] > 0.5:
                    xm.append(j)
                    ym.append(i)
                    cnt += 1
        if cnt != 0:
            xm.sort()
            ym.sort()
            xm = xm[cnt // 2]
            ym = ym[cnt // 2]
        else:
            xm = thresh.shape[1] / 2
            ym = thresh.shape[0] / 2

        return xm + xmin, ym + ymin
    def load_lm3d(self, similarity_mat_path):

        Lm3D = loadmat(similarity_mat_path)
        Lm3D = Lm3D['lm']

        lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
        lm_data1 = Lm3D[lm_idx[0], :]
        lm_data2 = np.mean(Lm3D[lm_idx[[1, 2]], :], 0)
        lm_data3 = np.mean(Lm3D[lm_idx[[3, 4]], :], 0)
        lm_data4 = Lm3D[lm_idx[5], :]
        lm_data5 = Lm3D[lm_idx[6], :]

        Lm3D = np.stack([lm_data1, lm_data2, lm_data3, lm_data4, lm_data5],
                        axis=0)

        Lm3D = Lm3D[[1, 2, 0, 3, 4], :]

        return Lm3D
    def POS(self, xp, x):
        npts = xp.shape[1]

        A = np.zeros([2 * npts, 8])

        A[0:2 * npts - 1:2, 0:3] = x.transpose()
        A[0:2 * npts - 1:2, 3] = 1

        A[1:2 * npts:2, 4:7] = x.transpose()
        A[1:2 * npts:2, 7] = 1

        b = np.reshape(xp.transpose(), [2 * npts, 1])

        k, _, _, _ = np.linalg.lstsq(A, b)

        R1 = k[0:3]
        R2 = k[4:7]
        sTx = k[3]
        sTy = k[7]
        s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2
        t = np.stack([sTx, sTy], axis=0)

        return t, s
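
    # Sketch: POS solves the linear least-squares system A k = b for a
    # weak-perspective camera, returning the 2D translation t and scale s that
    # map the 3D template landmarks x onto the detected 2D landmarks xp.
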
    def resize_n_crop_img(self, img, lm, t, s, target_size=224., mask=None):
        w0, h0 = img.size
        w = (w0 * s).astype(np.int32)
        h = (h0 * s).astype(np.int32)
        left = (w / 2 - target_size / 2 + float(
            (t[0] - w0 / 2) * s)).astype(np.int32)
        right = left + target_size
        up = (h / 2 - target_size / 2 + float(
            (h0 / 2 - t[1]) * s)).astype(np.int32)
        below = up + target_size

        img = img.resize((w, h), resample=Image.BICUBIC)
        img = img.crop((left, up, right, below))

        if mask is not None:
            mask = mask.resize((w, h), resample=Image.BICUBIC)
            mask = mask.crop((left, up, right, below))

        lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2],
                      axis=1) * s
        lm = lm - np.reshape(
            np.array([(w / 2 - target_size / 2),
                      (h / 2 - target_size / 2)]), [1, 2])

        return img, lm, mask
    def align_img(self,
                  img,
                  lm,
                  lm3D,
                  mask=None,
                  target_size=224.,
                  rescale_factor=102.):
        w0, h0 = img.size
        lm5p = lm
        t, s = self.POS(lm5p.transpose(), lm3D.transpose())
        s = rescale_factor / s

        img_new, lm_new, mask_new = self.resize_n_crop_img(
            img, lm, t, s, target_size=target_size, mask=mask)
        trans_params = np.array([w0, h0, s, t[0], t[1]], dtype=object)

        return trans_params, img_new, lm_new, mask_new
    def crop_image(self, img, lm):
        _, H = img.size
        lm[:, -1] = H - 1 - lm[:, -1]

        target_size = 1024.
        rescale_factor = 300
        center_crop_size = 700
        output_size = 512

        _, im_high, _, _ = self.align_img(
            img,
            lm,
            self.lm3d_std,
            target_size=target_size,
            rescale_factor=rescale_factor)

        left = int(im_high.size[0] / 2 - center_crop_size / 2)
        upper = int(im_high.size[1] / 2 - center_crop_size / 2)
        right = left + center_crop_size
        lower = upper + center_crop_size
        im_cropped = im_high.crop((left, upper, right, lower))
        im_cropped = im_cropped.resize((output_size, output_size),
                                       resample=Image.LANCZOS)
        logger.info('crop image done!')
        return im_cropped
    def create_samples(self, N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
        voxel_origin = np.array(voxel_origin) - cube_length / 2
        voxel_size = cube_length / (N - 1)

        overall_index = torch.arange(0, N**3, 1, out=torch.LongTensor())
        samples = torch.zeros(N**3, 3)

        samples[:, 2] = overall_index % N
        samples[:, 1] = (overall_index.float() / N) % N
        samples[:, 0] = ((overall_index.float() / N) / N) % N

        samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
        samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
        samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]

        return samples.unsqueeze(0), voxel_origin, voxel_size
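
    # Sketch: create_samples(N=4) enumerates a 4x4x4 grid centred on the
    # origin, returning coordinates of shape (1, 64, 3) plus the voxel origin
    # and spacing; inference calls it with N=self.shape_res to sample densities
    # for marching cubes.
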
    def numpy_array_to_video(self, numpy_list, video_out_path):
        assert len(numpy_list) > 0
        video_height = numpy_list[0].shape[0]
        video_width = numpy_list[0].shape[1]

        out_video_size = (video_width, video_height)
        output_video_fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
        video_write_capture = cv2.VideoWriter(video_out_path,
                                              output_video_fourcc, 30,
                                              out_video_size)

        for frame in numpy_list:
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            video_write_capture.write(frame_bgr)

        video_write_capture.release()
    def inference(self, image_path, save_dir):
        basename = os.path.basename(image_path).split('.')[0]
        img = Image.open(image_path).convert('RGB')
        img_array = np.array(img)
        img_bgr = img_array[:, :, ::-1]
        landmark = self.detect_face(img_array)
        if landmark is None:
            logger.warning('No face detected in the image!')
            return
        f5p = self.get_f5p(landmark, img_bgr)

        logger.info('f5p is:{}'.format(f5p))
        img_cropped = self.crop_image(img, f5p)
        img_cropped.save(os.path.join(save_dir, 'crop.jpg'))

        in_image = self.image_transform(img_cropped).unsqueeze(0).to(
            self.device)
        input = torch.cat((in_image, self.coord), 1)

        save_video_path = os.path.join(save_dir, f'{basename}.mp4')
        pred_imgs = []

        for frame_idx in range(self.num_frames):
            cam_pivot = torch.tensor([0, 0, 0.2], device=self.device)

            cam2world_pose = LookAtPoseSampler.sample(
                3.14 / 2 + self.yaw_range
                * np.sin(2 * 3.14 * frame_idx / self.num_frames),
                3.14 / 2 - 0.05 + self.pitch_range
                * np.cos(2 * 3.14 * frame_idx / self.num_frames),
                cam_pivot,
                radius=self.cam_radius,
                device=self.device)

            camera_params = torch.cat([
                cam2world_pose.reshape(-1, 16),
                self.intrinsics.reshape(-1, 9)
            ], 1)

            conditioning_cam2world_pose = LookAtPoseSampler.sample(
                np.pi / 2,
                np.pi / 2,
                cam_pivot,
                radius=self.cam_radius,
                device=self.device)
            conditioning_params = torch.cat([
                conditioning_cam2world_pose.reshape(-1, 16),
                self.intrinsics.reshape(-1, 9)
            ], 1)

            z = torch.from_numpy(np.random.randn(1,
                                                 self.z_dim)).to(self.device)

            with torch.no_grad():
                ws = self.netG.mapping(
                    z,
                    conditioning_params,
                    truncation_psi=self.truncation_psi,
                    truncation_cutoff=self.truncation_cutoff)

                planes, pred_depth, pred_feature, pred_rgb, pred_sr, _, _, _, _ = self.model(
                    ws, input, camera_params, None)

            pred_img = (pred_sr.permute(0, 2, 3, 1) * 127.5 + 128).clamp(
                0, 255).to(torch.uint8)
            pred_img = pred_img.squeeze().cpu().numpy()
            if self.save_images:
                cv2.imwrite(
                    os.path.join(save_dir, '{}.jpg'.format(frame_idx)),
                    pred_img[:, :, ::-1])
            pred_imgs.append(pred_img)

        self.numpy_array_to_video(pred_imgs, save_video_path)

        if self.save_shape:
            max_batch = 1000000

            samples, voxel_origin, voxel_size = self.create_samples(
                N=self.shape_res,
                voxel_origin=[0, 0, 0],
                cube_length=self.box_warp)
            samples = samples.to(z.device)
            sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1),
                                 device=z.device)
            transformed_ray_directions_expanded = torch.zeros(
                (samples.shape[0], max_batch, 3), device=z.device)
            transformed_ray_directions_expanded[..., -1] = -1

            head = 0
            with torch.no_grad():
                while head < samples.shape[1]:
                    torch.manual_seed(0)
                    sigma = self.model.sample(
                        samples[:, head:head + max_batch],
                        transformed_ray_directions_expanded[:, :samples.
                                                            shape[1] - head],
                        planes)['sigma']
                    sigmas[:, head:head + max_batch] = sigma
                    head += max_batch

            sigmas = sigmas.reshape((self.shape_res, self.shape_res,
                                     self.shape_res)).cpu().numpy()
            sigmas = np.flip(sigmas, 0)

            pad = int(30 * self.shape_res / 256)
            pad_value = -1000
            sigmas[:pad] = pad_value
            sigmas[-pad:] = pad_value
            sigmas[:, :pad] = pad_value
            sigmas[:, -pad:] = pad_value
            sigmas[:, :, :pad] = pad_value
            sigmas[:, :, -pad:] = pad_value
            convert_sdf_samples_to_ply(
                np.transpose(sigmas, (2, 1, 0)), [0, 0, 0],
                1,
                os.path.join(save_dir, f'{basename}.ply'),
                level=10)

        logger.info('model inference done')
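
    # Usage sketch (illustrative paths): writes crop.jpg, <name>.mp4 and, when
    # save_shape is enabled, <name>.ply into save_dir.
    #
    #   model.inference('portrait.jpg', './output')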
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
"""

import math

import torch
import torch.nn as nn

from .volumetric_rendering import math_utils

class GaussianCameraPoseSampler:
    """
    Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
    Camera is specified as looking at the origin.
    If horizontal and vertical stddev (specified in radians) are zero, gives a
    deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
    The coordinate system is specified with y-up, z-forward, x-left.
    Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
    vertical mean is the polar angle (angle from the y axis) in radians.
    A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.

    Example:
        For a camera pose looking at the origin with the camera at position [0, 0, 1]:
        cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
    """

    @staticmethod
    def sample(horizontal_mean,
               vertical_mean,
               horizontal_stddev=0,
               vertical_stddev=0,
               radius=1,
               batch_size=1,
               device='cpu'):
        h = torch.randn((batch_size, 1),
                        device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn(
            (batch_size, 1), device=device) * vertical_stddev + vertical_mean
        v = torch.clamp(v, 1e-5, math.pi - 1e-5)

        theta = h
        v = v / math.pi
        phi = torch.arccos(1 - 2 * v)

        camera_origins = torch.zeros((batch_size, 3), device=device)

        camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
                                                                     - theta)
        camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
                                                                     - theta)
        camera_origins[:, 1:2] = radius * torch.cos(phi)

        forward_vectors = math_utils.normalize_vecs(-camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)

class LookAtPoseSampler:
    """
    Same as GaussianCameraPoseSampler, except the
    camera is specified as looking at 'lookat_position', a 3-vector.

    Example:
        For a camera pose looking at the origin with the camera at position [0, 0, 1]:
        cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
    """

    @staticmethod
    def sample(horizontal_mean,
               vertical_mean,
               lookat_position,
               horizontal_stddev=0,
               vertical_stddev=0,
               radius=1,
               batch_size=1,
               device='cpu'):
        h = torch.randn((batch_size, 1),
                        device=device) * horizontal_stddev + horizontal_mean
        v = torch.randn(
            (batch_size, 1), device=device) * vertical_stddev + vertical_mean
        v = torch.clamp(v, 1e-5, math.pi - 1e-5)

        theta = h
        v = v / math.pi
        phi = torch.arccos(1 - 2 * v)

        camera_origins = torch.zeros((batch_size, 3), device=device)

        camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
                                                                     - theta)
        camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
                                                                     - theta)
        camera_origins[:, 1:2] = radius * torch.cos(phi)

        # forward_vectors = math_utils.normalize_vecs(-camera_origins)
        forward_vectors = math_utils.normalize_vecs(lookat_position
                                                    - camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)

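# Usage sketch (matches the orbit built in ImageControl3dPortrait.inference):
#
#   pose = LookAtPoseSampler.sample(math.pi / 2, math.pi / 2,
#                                   torch.tensor([0., 0., 0.2]), radius=2.7)
#   pose.shape  # -> (1, 4, 4) cam2world matrix
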
class UniformCameraPoseSampler:
    """
    Same as GaussianCameraPoseSampler, except the
    pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.

    Example:
        For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:

        cam2worlds = UniformCameraPoseSampler.sample(
            math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
    """

    @staticmethod
    def sample(horizontal_mean,
               vertical_mean,
               horizontal_stddev=0,
               vertical_stddev=0,
               radius=1,
               batch_size=1,
               device='cpu'):
        h = (torch.rand((batch_size, 1), device=device) * 2
             - 1) * horizontal_stddev + horizontal_mean
        v = (torch.rand((batch_size, 1), device=device) * 2
             - 1) * vertical_stddev + vertical_mean
        v = torch.clamp(v, 1e-5, math.pi - 1e-5)

        theta = h
        v = v / math.pi
        phi = torch.arccos(1 - 2 * v)

        camera_origins = torch.zeros((batch_size, 3), device=device)

        camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
                                                                     - theta)
        camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
                                                                     - theta)
        camera_origins[:, 1:2] = radius * torch.cos(phi)

        forward_vectors = math_utils.normalize_vecs(-camera_origins)
        return create_cam2world_matrix(forward_vectors, camera_origins)

def create_cam2world_matrix(forward_vector, origin):
    """
    Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
    Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
    """

    forward_vector = math_utils.normalize_vecs(forward_vector)
    up_vector = torch.tensor([0, 1, 0],
                             dtype=torch.float,
                             device=origin.device).expand_as(forward_vector)

    right_vector = -math_utils.normalize_vecs(
        torch.cross(up_vector, forward_vector, dim=-1))
    up_vector = math_utils.normalize_vecs(
        torch.cross(forward_vector, right_vector, dim=-1))

    rotation_matrix = torch.eye(
        4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0],
                                                     1, 1)
    rotation_matrix[:, :3, :3] = torch.stack(
        (right_vector, up_vector, forward_vector), axis=-1)

    translation_matrix = torch.eye(
        4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0],
                                                     1, 1)
    translation_matrix[:, :3, 3] = origin
    cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
    assert (cam2world.shape[1:] == (4, 4))
    return cam2world

def FOV_to_intrinsics(fov_degrees, device='cpu'):
    """
    Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
    Note the intrinsics are returned as normalized by image size, rather than in pixel units.
    Assumes principal point is at image center.
    """

    focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
    intrinsics = torch.tensor(
        [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
        device=device)
    return intrinsics
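
# Sketch: with a FOV of 18.837 degrees (the value commonly used with EG3D
# checkpoints) this yields a normalized focal length of roughly 4.26:
#
#   FOV_to_intrinsics(18.837)[0, 0]  # -> ~4.26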
File diff suppressed because it is too large
@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.)

Takes as input an .mrc file and extracts a mesh.

Ex.
python shape_utils.py my_shape.mrc
Ex.
python shape_utils.py myshapes_directory --level=12
"""

import numpy as np
import plyfile
import skimage.measure


def convert_sdf_samples_to_ply(numpy_3d_sdf_tensor,
                               voxel_grid_origin,
                               voxel_size,
                               ply_filename_out,
                               offset=None,
                               scale=None,
                               level=0.0):

    verts, faces, normals, values = skimage.measure.marching_cubes(
        numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3)
    mesh_points = np.zeros_like(verts)
    mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
    mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
    mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]

    if scale is not None:
        mesh_points = mesh_points / scale
    if offset is not None:
        mesh_points = mesh_points - offset

    num_verts = verts.shape[0]
    num_faces = faces.shape[0]

    verts_tuple = np.zeros((num_verts, ),
                           dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])

    for i in range(0, num_verts):
        verts_tuple[i] = tuple(mesh_points[i, :])

    faces_building = []
    for i in range(0, num_faces):
        faces_building.append(((faces[i, :].tolist(), )))
    faces_tuple = np.array(
        faces_building, dtype=[('vertex_indices', 'i4', (3, ))])

    el_verts = plyfile.PlyElement.describe(verts_tuple, 'vertex')
    el_faces = plyfile.PlyElement.describe(faces_tuple, 'face')

    ply_data = plyfile.PlyData([el_verts, el_faces])
    ply_data.write(ply_filename_out)
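
# Usage sketch (illustrative path; mirrors the call in
# ImageControl3dPortrait.inference):
#
#   convert_sdf_samples_to_ply(sigmas, [0, 0, 0], 1, 'out/shape.ply', level=10)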
@@ -0,0 +1,493 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Superresolution network architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""

import numpy as np
import torch

from modelscope.ops.image_control_3d_portrait.torch_utils import (misc,
                                                                  persistence)
from modelscope.ops.image_control_3d_portrait.torch_utils.ops import upfirdn2d
from .networks_stylegan2 import (Conv2dLayer, SynthesisBlock, SynthesisLayer,
                                 ToRGBLayer)

# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8X(torch.nn.Module):

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 512

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(
            channels,
            128,
            w_dim=512,
            resolution=256,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(
            128,
            64,
            w_dim=512,
            resolution=512,
            img_channels=3,
            is_last=True,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(
                rgb,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb
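
# Shape sketch (illustrative): each SuperresolutionHybrid* variant upsamples
# the raw neural-rendered feature/RGB pair, doubling resolution per upsampling
# SynthesisBlock. For the 8X model above:
#
#   sr = SuperresolutionHybrid8X(channels=32, img_resolution=512,
#                                sr_num_fp16_res=4, sr_antialias=True)
#   out = sr(rgb128, feat128, ws)  # (N,3,128,128)+(N,32,128,128) -> (N,3,512,512)
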
# for 256x256 generation
@persistence.persistent_class
class SuperresolutionHybrid4X(torch.nn.Module):

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 256
        use_fp16 = sr_num_fp16_res > 0
        self.sr_antialias = sr_antialias
        self.input_resolution = 128
        self.block0 = SynthesisBlockNoUp(
            channels,
            128,
            w_dim=512,
            resolution=128,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(
            128,
            64,
            w_dim=512,
            resolution=256,
            img_channels=3,
            is_last=True,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(
                rgb,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

# for 128 x 128 generation
@persistence.persistent_class
class SuperresolutionHybrid2X(torch.nn.Module):

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 128

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 64
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlockNoUp(
            channels,
            128,
            w_dim=512,
            resolution=64,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(
            128,
            64,
            w_dim=512,
            resolution=128,
            img_channels=3,
            is_last=True,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(
                rgb,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

# TODO: Delete (here for backwards compatibility with old 256x256 models)
@persistence.persistent_class
class SuperresolutionHybridDeepfp32(torch.nn.Module):

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 256
        use_fp16 = sr_num_fp16_res > 0

        self.input_resolution = 128
        self.block0 = SynthesisBlockNoUp(
            channels,
            128,
            w_dim=512,
            resolution=128,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(
            128,
            64,
            w_dim=512,
            resolution=256,
            img_channels=3,
            is_last=True,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] < self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False)
            rgb = torch.nn.functional.interpolate(
                rgb,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb

@persistence.persistent_class
class SynthesisBlockNoUp(torch.nn.Module):

    def __init__(
        self,
        in_channels,  # Number of input channels, 0 = first block.
        out_channels,  # Number of output channels.
        w_dim,  # Intermediate latent (W) dimensionality.
        resolution,  # Resolution of this block.
        img_channels,  # Number of output color channels.
        is_last,  # Is this the last block?
        architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations.
        conv_clamp=256,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16=False,  # Use FP16 for this block?
        fp16_channels_last=False,  # Use channels-last memory format with FP16?
        fused_modconv_default=True,  # Default value of fused_modconv.
        **layer_kwargs,  # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution]))

        if in_channels != 0:
            self.conv0 = SynthesisLayer(
                in_channels,
                out_channels,
                w_dim=w_dim,
                resolution=resolution,
                conv_clamp=conv_clamp,
                channels_last=self.channels_last,
                **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(
            out_channels,
            out_channels,
            w_dim=w_dim,
            resolution=resolution,
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
            **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(
                out_channels,
                img_channels,
                w_dim=w_dim,
                conv_clamp=conv_clamp,
                channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            self.skip = Conv2dLayer(
                in_channels,
                out_channels,
                kernel_size=1,
                bias=False,
                up=2,
                resample_filter=resample_filter,
                channels_last=self.channels_last)

    def forward(self,
                x,
                img,
                ws,
                force_fp32=False,
                fused_modconv=None,
                update_emas=False,
                **layer_kwargs):
        _ = update_emas  # unused
        misc.assert_shape(ws,
                          [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))
        if ws.device.type != 'cuda':
            force_fp32 = True
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        # Input.
        if self.in_channels == 0:
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(
                x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(
                x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == 'resnet':
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(
                x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(
                x,
                next(w_iter),
                fused_modconv=fused_modconv,
                gain=np.sqrt(0.5),
                **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(
                x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(
                x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        # if img is not None:
        #     misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
        #     img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(
                dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
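
# Design note (sketch): unlike the upstream SynthesisBlock, SynthesisBlockNoUp
# keeps the spatial resolution fixed (the upsample2d call is commented out in
# its forward pass), which is why the 2X/4X/Deepfp32 models pair one NoUp block
# with one upsampling SynthesisBlock to reach their target resolution.
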
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8XDC(torch.nn.Module):

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 512

        use_fp16 = sr_num_fp16_res > 0
        self.input_resolution = 128
        self.sr_antialias = sr_antialias
        self.block0 = SynthesisBlock(
            channels,
            256,
            w_dim=512,
            resolution=256,
            img_channels=3,
            is_last=False,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)
        self.block1 = SynthesisBlock(
            256,
            128,
            w_dim=512,
            resolution=512,
            img_channels=3,
            is_last=True,
            use_fp16=use_fp16,
            conv_clamp=(256 if use_fp16 else None),
            **block_kwargs)

    def forward(self, rgb, x, ws, **block_kwargs):
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        if x.shape[-1] != self.input_resolution:
            x = torch.nn.functional.interpolate(
                x,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)
            rgb = torch.nn.functional.interpolate(
                rgb,
                size=(self.input_resolution, self.input_resolution),
                mode='bilinear',
                align_corners=False,
                antialias=self.sr_antialias)

        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        x, rgb = self.block1(x, rgb, ws, **block_kwargs)
        return rgb
@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch

from modelscope.ops.image_control_3d_portrait.torch_utils import persistence
from .networks_stylegan2 import FullyConnectedLayer
from .networks_stylegan2 import Generator as StyleGAN2Backbone
from .superresolution import SuperresolutionHybrid8XDC
from .volumetric_rendering.ray_sampler import RaySampler
from .volumetric_rendering.renderer import ImportanceRenderer

@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):

    def __init__(
        self,
        z_dim,  # Input latent (Z) dimensionality.
        c_dim,  # Conditioning label (C) dimensionality.
        w_dim,  # Intermediate latent (W) dimensionality.
        img_resolution,  # Output resolution.
        img_channels,  # Number of output color channels.
        sr_num_fp16_res=0,
        mapping_kwargs={},  # Arguments for MappingNetwork.
        rendering_kwargs={},
        sr_kwargs={},
        **synthesis_kwargs,  # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()
        self.backbone = StyleGAN2Backbone(
            z_dim,
            c_dim,
            w_dim,
            img_resolution=256,
            img_channels=32 * 3,
            mapping_kwargs=mapping_kwargs,
            **synthesis_kwargs)
        self.superresolution = SuperresolutionHybrid8XDC(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
            **sr_kwargs)
        self.decoder = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs

        self._last_planes = None
def mapping(self,
|
||||
z,
|
||||
c,
|
||||
truncation_psi=1,
|
||||
truncation_cutoff=None,
|
||||
update_emas=False):
|
||||
if self.rendering_kwargs['c_gen_conditioning_zero']:
|
||||
c = torch.zeros_like(c)
|
||||
return self.backbone.mapping(
|
||||
z,
|
||||
c * self.rendering_kwargs.get('c_scale', 0),
|
||||
truncation_psi=truncation_psi,
|
||||
truncation_cutoff=truncation_cutoff,
|
||||
update_emas=update_emas)
|
||||
|
||||
def synthesis(self,
|
||||
ws,
|
||||
c,
|
||||
neural_rendering_resolution=None,
|
||||
update_emas=False,
|
||||
cache_backbone=False,
|
||||
use_cached_backbone=False,
|
||||
**synthesis_kwargs):
|
||||
cam2world_matrix = c[:, :16].view(-1, 4, 4)
|
||||
intrinsics = c[:, 16:25].view(-1, 3, 3)
|
||||
|
||||
if neural_rendering_resolution is None:
|
||||
neural_rendering_resolution = self.neural_rendering_resolution
|
||||
else:
|
||||
self.neural_rendering_resolution = neural_rendering_resolution
|
||||
|
||||
# Create a batch of rays for volume rendering
|
||||
ray_origins, ray_directions = self.ray_sampler(
|
||||
cam2world_matrix, intrinsics, neural_rendering_resolution)
|
||||
|
||||
# Create triplanes by running StyleGAN backbone
|
||||
N, M, _ = ray_origins.shape
|
||||
if use_cached_backbone and self._last_planes is not None:
|
||||
planes = self._last_planes
|
||||
else:
|
||||
planes = self.backbone.synthesis(
|
||||
ws, update_emas=update_emas, **synthesis_kwargs)
|
||||
if cache_backbone:
|
||||
self._last_planes = planes
|
||||
|
||||
# Reshape output into three 32-channel planes
|
||||
planes = planes.view(
|
||||
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
||||
|
||||
# Perform volume rendering
|
||||
feature_samples, depth_samples, weights_samples = self.renderer(
|
||||
planes, self.decoder, ray_origins, ray_directions,
|
||||
self.rendering_kwargs) # channels last
|
||||
|
||||
# Reshape into 'raw' neural-rendered image
|
||||
H = W = self.neural_rendering_resolution
|
||||
feature_image = feature_samples.permute(0, 2, 1).reshape(
|
||||
N, feature_samples.shape[-1], H, W).contiguous()
|
||||
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
|
||||
|
||||
# Run superresolution to get final image
|
||||
rgb_image = feature_image[:, :3]
|
||||
sr_image = self.superresolution(
|
||||
rgb_image,
|
||||
feature_image,
|
||||
ws,
|
||||
noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
|
||||
**{
|
||||
k: synthesis_kwargs[k]
|
||||
for k in synthesis_kwargs.keys() if k != 'noise_mode'
|
||||
})
|
||||
|
||||
return {
|
||||
'image': sr_image,
|
||||
'image_raw': rgb_image,
|
||||
'image_depth': depth_image
|
||||
}
|
||||
|
||||
def sample(self,
|
||||
coordinates,
|
||||
directions,
|
||||
z,
|
||||
c,
|
||||
truncation_psi=1,
|
||||
truncation_cutoff=None,
|
||||
update_emas=False,
|
||||
**synthesis_kwargs):
|
||||
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
|
||||
ws = self.mapping(
|
||||
z,
|
||||
c,
|
||||
truncation_psi=truncation_psi,
|
||||
truncation_cutoff=truncation_cutoff,
|
||||
update_emas=update_emas)
|
||||
planes = self.backbone.synthesis(
|
||||
ws, update_emas=update_emas, **synthesis_kwargs)
|
||||
planes = planes.view(
|
||||
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
||||
return self.renderer.run_model(planes, self.decoder, coordinates,
|
||||
directions, self.rendering_kwargs)
|
||||
|
||||
def sample_mixed(self,
|
||||
coordinates,
|
||||
directions,
|
||||
ws,
|
||||
truncation_psi=1,
|
||||
truncation_cutoff=None,
|
||||
update_emas=False,
|
||||
**synthesis_kwargs):
|
||||
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
|
||||
planes = self.backbone.synthesis(
|
||||
ws, update_emas=update_emas, **synthesis_kwargs)
|
||||
planes = planes.view(
|
||||
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
||||
return self.renderer.run_model(planes, self.decoder, coordinates,
|
||||
directions, self.rendering_kwargs)
|
||||
|
||||
def forward(self,
|
||||
z,
|
||||
c,
|
||||
truncation_psi=1,
|
||||
truncation_cutoff=None,
|
||||
neural_rendering_resolution=None,
|
||||
update_emas=False,
|
||||
cache_backbone=False,
|
||||
use_cached_backbone=False,
|
||||
**synthesis_kwargs):
|
||||
# Render a batch of generated images.
|
||||
ws = self.mapping(
|
||||
z,
|
||||
c,
|
||||
truncation_psi=truncation_psi,
|
||||
truncation_cutoff=truncation_cutoff,
|
||||
update_emas=update_emas)
|
||||
return self.synthesis(
|
||||
ws,
|
||||
c,
|
||||
update_emas=update_emas,
|
||||
neural_rendering_resolution=neural_rendering_resolution,
|
||||
cache_backbone=cache_backbone,
|
||||
use_cached_backbone=use_cached_backbone,
|
||||
**synthesis_kwargs)
|
||||
|
||||
|
||||
class OSGDecoder(torch.nn.Module):
|
||||
|
||||
def __init__(self, n_features, options):
|
||||
super().__init__()
|
||||
self.hidden_dim = 64
|
||||
|
||||
self.net = torch.nn.Sequential(
|
||||
FullyConnectedLayer(
|
||||
n_features,
|
||||
self.hidden_dim,
|
||||
lr_multiplier=options['decoder_lr_mul']), torch.nn.Softplus(),
|
||||
FullyConnectedLayer(
|
||||
self.hidden_dim,
|
||||
1 + options['decoder_output_dim'],
|
||||
lr_multiplier=options['decoder_lr_mul']))
|
||||
|
||||
def forward(self, sampled_features, ray_directions):
|
||||
# Aggregate features
|
||||
sampled_features = sampled_features.mean(1)
|
||||
x = sampled_features
|
||||
|
||||
N, M, C = x.shape
|
||||
x = x.view(N * M, C)
|
||||
|
||||
x = self.net(x)
|
||||
x = x.view(N, M, -1)
|
||||
rgb = torch.sigmoid(x[..., 1:]) * (
|
||||
1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
|
||||
sigma = x[..., 0:1]
|
||||
return {'rgb': rgb, 'sigma': sigma}
|
||||
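For orientation, a hedged usage sketch of the generator (not part of the diff). The rendering_kwargs keys shown are exactly the ones read by the methods above, the values are illustrative, and the 25-dim conditioning packs a flattened 4x4 cam2world matrix plus a flattened 3x3 intrinsics matrix, as `synthesis()` unpacks them:

import torch

G = TriPlaneGenerator(
    z_dim=512, c_dim=25, w_dim=512, img_resolution=512, img_channels=3,
    rendering_kwargs={
        'sr_antialias': True,
        'c_gen_conditioning_zero': False,
        'c_scale': 1.0,
        'superresolution_noise_mode': 'const',
        'ray_start': 2.25, 'ray_end': 3.3, 'box_warp': 1.0,
        'depth_resolution': 48, 'depth_resolution_importance': 48,
        'disparity_space_sampling': False,
        'clamp_mode': 'softplus',
    })
z = torch.randn(1, 512)
c = torch.cat([torch.eye(4).flatten(), torch.eye(3).flatten()])[None]  # (1, 25)
out = G(z, c)
# out['image']: (1, 3, 512, 512); out['image_raw']: (1, 3, 64, 64);
# out['image_depth']: (1, 1, 64, 64) at the 64x64 neural rendering resolution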
@@ -0,0 +1,697 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from functools import partial

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from .networks_stylegan2 import FullyConnectedLayer
from .superresolution import SuperresolutionHybrid8XDC
from .volumetric_rendering.ray_sampler import RaySampler
from .volumetric_rendering.renderer import ImportanceRenderer


class DWConv(nn.Module):

    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class Mlp(nn.Module):

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f'dim {dim} should be divided by num_heads {num_heads}.'

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(
                dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads,
                              C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads,
                                     C // self.num_heads).permute(
                                         2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
                                    C // self.num_heads).permute(
                                        2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            sr_ratio=sr_ratio)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self,
                 img_size=224,
                 patch_size=7,
                 stride=4,
                 in_chans=3,
                 embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[
            1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class Encoder_low(nn.Module):

    def __init__(self,
                 img_size=64,
                 depth=5,
                 in_chans=256,
                 embed_dims=1024,
                 num_head=4,
                 mlp_ratio=2,
                 sr_ratio=1,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.depth = depth

        self.deeplabnet = smp.DeepLabV3(
            encoder_name='resnet34',
            encoder_depth=5,
            encoder_weights=None,
            decoder_channels=256,
            in_channels=5,
            classes=1)

        self.deeplabnet.encoder.conv1 = nn.Conv2d(
            5,
            64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False)
        self.deeplabnet.segmentation_head = nn.Sequential()
        self.deeplabnet.encoder.bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer1[0].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer1[0].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer1[1].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer1[1].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer1[2].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer1[2].bn2 = nn.Sequential()

        self.deeplabnet.encoder.layer2[0].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer2[0].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer2[0].downsample[1] = nn.Sequential()
        self.deeplabnet.encoder.layer2[1].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer2[1].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer2[2].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer2[2].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer2[3].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer2[3].bn2 = nn.Sequential()

        self.deeplabnet.encoder.layer3[0].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[0].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer3[0].downsample[1] = nn.Sequential()
        self.deeplabnet.encoder.layer3[1].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[1].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer3[2].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[2].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer3[3].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[3].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer3[4].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[4].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer3[5].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer3[5].bn2 = nn.Sequential()

        self.deeplabnet.encoder.layer4[0].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer4[0].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer4[0].downsample[1] = nn.Sequential()
        self.deeplabnet.encoder.layer4[1].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer4[1].bn2 = nn.Sequential()
        self.deeplabnet.encoder.layer4[2].bn1 = nn.Sequential()
        self.deeplabnet.encoder.layer4[2].bn2 = nn.Sequential()

        self.deeplabnet.decoder[0].convs[0][1] = nn.Sequential()
        self.deeplabnet.decoder[0].convs[1][1] = nn.Sequential()
        self.deeplabnet.decoder[0].convs[2][1] = nn.Sequential()
        self.deeplabnet.decoder[0].convs[3][1] = nn.Sequential()
        self.deeplabnet.decoder[0].convs[4][2] = nn.Sequential()
        self.deeplabnet.decoder[0].project[1] = nn.Sequential()
        self.deeplabnet.decoder[2] = nn.Sequential()

        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size,
            patch_size=3,
            stride=2,
            in_chans=in_chans,
            embed_dim=embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        cur = 0
        self.vit_block = nn.ModuleList([
            Block(
                dim=embed_dims,
                num_heads=num_head,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratio) for i in range(depth)
        ])
        self.norm1 = norm_layer(embed_dims)
        self.ps = nn.PixelShuffle(2)

        self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, input):
        B = input.shape[0]

        f_low = self.deeplabnet(input)
        x, H, W = self.patch_embed(f_low)

        for i, blk in enumerate(self.vit_block):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        x = self.ps(x)

        x = self.relu1(self.conv1(self.upsample1(x)))
        x = self.relu2(self.conv2(self.upsample2(x)))
        x = self.conv3(x)

        return x


class Encoder_high(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.LeakyReLU(0.01)
        self.conv2 = nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.LeakyReLU(0.01)
        self.conv3 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.LeakyReLU(0.01)
        self.conv4 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.LeakyReLU(0.01)
        self.conv5 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.LeakyReLU(0.01)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        x = self.relu5(self.conv5(x))

        return x


class MixFeature(nn.Module):

    def __init__(self,
                 img_size=256,
                 depth=1,
                 in_chans=128,
                 embed_dims=1024,
                 num_head=2,
                 mlp_ratio=2,
                 sr_ratio=2,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.conv1 = nn.Conv2d(192, 256, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.LeakyReLU(0.01)
        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.LeakyReLU(0.01)

        self.patch_embed = OverlapPatchEmbed(
            img_size=img_size,
            patch_size=3,
            stride=2,
            in_chans=in_chans,
            embed_dim=embed_dims)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        cur = 0
        self.vit_block = nn.ModuleList([
            Block(
                dim=embed_dims,
                num_heads=num_head,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[cur + i],
                norm_layer=norm_layer,
                sr_ratio=sr_ratio) for i in range(depth)
        ])
        self.norm1 = norm_layer(embed_dims)
        self.ps = nn.PixelShuffle(2)

        self.conv3 = nn.Conv2d(352, 256, kernel_size=3, stride=1, padding=1)
        self.relu3 = nn.LeakyReLU(0.01)
        self.conv4 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.relu4 = nn.LeakyReLU(0.01)
        self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.relu5 = nn.LeakyReLU(0.01)
        self.conv6 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x_low, x_high):
        x = torch.cat((x_low, x_high), 1)
        B = x.shape[0]

        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))

        x, H, W = self.patch_embed(x)

        for i, blk in enumerate(self.vit_block):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        x = self.ps(x)

        x = torch.cat((x, x_low), 1)
        x = self.relu3(self.conv3(x))
        x = self.relu4(self.conv4(x))
        x = self.relu5(self.conv5(x))
        x = self.conv6(x)

        return x


class OSGDecoder(torch.nn.Module):

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64

        self.net = torch.nn.Sequential(
            FullyConnectedLayer(
                n_features,
                self.hidden_dim,
                lr_multiplier=options['decoder_lr_mul']), torch.nn.Softplus(),
            FullyConnectedLayer(
                self.hidden_dim,
                1 + options['decoder_output_dim'],
                lr_multiplier=options['decoder_lr_mul']))

    def forward(self, sampled_features, ray_directions):
        sampled_features = sampled_features.mean(1)
        x = sampled_features

        N, M, C = x.shape
        x = x.view(N * M, C)

        x = self.net(x)
        x = x.view(N, M, -1)
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]
        return {'rgb': rgb, 'sigma': sigma}


class TriplaneEncoder(nn.Module):

    def __init__(self,
                 img_resolution,
                 sr_num_fp16_res=0,
                 rendering_kwargs={},
                 sr_kwargs={}):
        super().__init__()
        self.encoder_low = Encoder_low(
            img_size=64,
            depth=5,
            in_chans=256,
            embed_dims=1024,
            num_head=4,
            mlp_ratio=2,
            sr_ratio=1,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6))
        self.encoder_high = Encoder_high()
        self.mix = MixFeature(
            img_size=256,
            depth=1,
            in_chans=128,
            embed_dims=1024,
            num_head=2,
            mlp_ratio=2,
            sr_ratio=2,
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6))

        self.renderer = ImportanceRenderer()
        self.ray_sampler = RaySampler()
        self.superresolution = SuperresolutionHybrid8XDC(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
            **sr_kwargs)
        self.decoder = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })
        self.neural_rendering_resolution = 128
        self.rendering_kwargs = rendering_kwargs

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def gen_interfeats(self, ws, planes, camera_params):
        planes = planes.view(
            len(planes), 3, 32, planes.shape[-2], planes.shape[-1])

        cam2world_matrix = camera_params[:, :16].view(-1, 4, 4)
        intrinsics = camera_params[:, 16:25].view(-1, 3, 3)
        H = W = self.neural_rendering_resolution
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix, intrinsics, self.neural_rendering_resolution)
        N, M, _ = ray_origins.shape
        feature_samples, depth_samples, weights_samples = self.renderer(
            planes, self.decoder, ray_origins, ray_directions,
            self.rendering_kwargs)
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        rgb_image = feature_image[:, :3]
        sr_image = self.superresolution(
            rgb_image, feature_image, ws, noise_mode='const')

        return depth_image, feature_image, rgb_image, sr_image

    def sample(self, coordinates, directions, planes):
        planes = planes.view(
            len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def forward(self, ws, x, camera_ref, camera_mv):
        f = self.encoder_low(x)
        f_high = self.encoder_high(x)
        planes = self.mix(f, f_high)

        depth_ref, feature_ref, rgb_ref, sr_ref = self.gen_interfeats(
            ws, planes, camera_ref)
        if camera_mv is not None:
            depth_mv, feature_mv, rgb_mv, sr_mv = self.gen_interfeats(
                ws, planes, camera_mv)
        else:
            depth_mv = feature_mv = rgb_mv = sr_mv = None

        return planes, depth_ref, feature_ref, rgb_ref, sr_ref, depth_mv, feature_mv, rgb_mv, sr_mv


def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}
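A hedged shape walk-through for the encoder path above (not part of the diff). The 512x512, 5-channel input is inferred from the conv layers (`in_channels=5`) and from `MixFeature`'s `img_size=256` patch embedding, i.e. half the input side; treat it as an assumption rather than a documented contract:

import torch

enc = TriplaneEncoder(
    img_resolution=512,
    rendering_kwargs={
        'sr_antialias': True,
        'ray_start': 2.25, 'ray_end': 3.3, 'box_warp': 1.0,
        'depth_resolution': 48, 'depth_resolution_importance': 48,
        'disparity_space_sampling': False, 'clamp_mode': 'softplus',
    })
x = torch.randn(1, 5, 512, 512)      # 5-channel conditioning input (assumed)
ws = torch.randn(1, 14, 512)
camera = torch.randn(1, 25)          # 16 extrinsics + 9 intrinsics entries
outs = enc(ws, x, camera_ref=camera, camera_mv=None)
planes = outs[0]                     # (1, 96, 256, 256): three 32-ch planes
print(get_parameter_number(enc))     # {'Total': ..., 'Trainable': ...}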
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# empty
@@ -0,0 +1,137 @@
# MIT License

# Copyright (c) 2022 Petr Kellnhofer

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch


def transform_vectors(matrix: torch.Tensor,
                      vectors4: torch.Tensor) -> torch.Tensor:
    """
    Left-multiplies MxM @ NxM. Returns NxM.
    """
    res = torch.matmul(vectors4, matrix.T)
    return res


def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
    """
    Normalize vector lengths.
    """
    return vectors / (torch.norm(vectors, dim=-1, keepdim=True))


def torch_dot(x: torch.Tensor, y: torch.Tensor):
    """
    Dot product of two tensors.
    """
    return (x * y).sum(-1)


def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor,
                       box_side_length):
    """
    Author: Petr Kellnhofer
    Intersects rays with the [-1, 1] NDC volume.
    Returns min and max distance of entry.
    Returns -1 for no intersection.
    https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
    """
    o_shape = rays_o.shape
    rays_o = rays_o.detach().reshape(-1, 3)
    rays_d = rays_d.detach().reshape(-1, 3)

    temp_min_1 = -1 * (box_side_length / 2)
    temp_min_2 = -1 * (box_side_length / 2)
    temp_min_3 = -1 * (box_side_length / 2)
    bb_min = [temp_min_1, temp_min_2, temp_min_3]
    temp_max_1 = 1 * (box_side_length / 2)
    temp_max_2 = 1 * (box_side_length / 2)
    temp_max_3 = 1 * (box_side_length / 2)
    bb_max = [temp_max_1, temp_max_2, temp_max_3]
    bounds = torch.tensor([bb_min, bb_max],
                          dtype=rays_o.dtype,
                          device=rays_o.device)
    is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)

    # Precompute inverse for stability.
    invdir = 1 / rays_d
    sign = (invdir < 0).long()

    # Intersect with YZ plane.
    tmin = (bounds.index_select(0, sign[..., 0])[..., 0]
            - rays_o[..., 0]) * invdir[..., 0]
    tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0]
            - rays_o[..., 0]) * invdir[..., 0]

    # Intersect with XZ plane.
    tymin = (bounds.index_select(0, sign[..., 1])[..., 1]
             - rays_o[..., 1]) * invdir[..., 1]
    tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1]
             - rays_o[..., 1]) * invdir[..., 1]

    # Resolve parallel rays.
    is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False

    # Use the shortest intersection.
    tmin = torch.max(tmin, tymin)
    tmax = torch.min(tmax, tymax)

    # Intersect with XY plane.
    tzmin = (bounds.index_select(0, sign[..., 2])[..., 2]
             - rays_o[..., 2]) * invdir[..., 2]
    tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2]
             - rays_o[..., 2]) * invdir[..., 2]

    # Resolve parallel rays.
    is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False

    # Use the shortest intersection.
    tmin = torch.max(tmin, tzmin)
    tmax = torch.min(tmax, tzmax)

    # Mark invalid.
    tmin[torch.logical_not(is_valid)] = -1
    tmax[torch.logical_not(is_valid)] = -2

    return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)


def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
    """
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
    Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
    """
    # create a tensor of 'num' steps from 0 to 1
    steps = torch.arange(
        num, dtype=torch.float32, device=start.device) / (
            num - 1)

    # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
    # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
    #   "cannot statically infer the expected size of a list in this context", hence the code below
    for i in range(start.ndim):
        steps = steps.unsqueeze(-1)

    # the output starts at 'start' and increments until 'stop' in each dimension
    out = start[None] + steps * (stop - start)[None]

    return out
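A small sanity check of the two helpers above (not part of the diff; values chosen so the intersection can be verified by hand):

import torch

rays_o = torch.tensor([[[0.0, 0.0, -3.0]]])   # (N=1, M=1, 3)
rays_d = torch.tensor([[[0.0, 0.0, 1.0]]])    # unit ray marching toward +z
t_min, t_max = get_ray_limits_box(rays_o, rays_d, box_side_length=2)
# t_min = 2.0 and t_max = 4.0: the ray enters the [-1, 1]^3 box at z = -1
# and leaves it at z = +1.
depths = linspace(t_min, t_max, 16)           # (16, 1, 1, 1) sample depths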
@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The ray marcher takes the raw output of the implicit representation and
uses the volume rendering equation to produce composited colors and depths.
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MipRayMarcher2(nn.Module):

    def __init__(self):
        super().__init__()

    def run_forward(self, colors, densities, depths, rendering_options):
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2

        if rendering_options['clamp_mode'] == 'softplus':
            densities_mid = F.softplus(
                densities_mid
                - 1)  # activation bias of -1 makes things initialize better
        else:
            assert False, 'MipRayMarcher only supports `clamp_mode`=`softplus`!'

        density_delta = densities_mid * deltas

        alpha = 1 - torch.exp(-density_delta)

        alpha_shifted = torch.cat(
            [torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

        # clip the composite to min/max range of depths
        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths),
                                      torch.max(depths))

        if rendering_options.get('white_back', False):
            composite_rgb = composite_rgb + 1 - weight_total

        composite_rgb = composite_rgb * 2 - 1  # Scale to (-1, 1)

        return composite_rgb, composite_depth, weights

    def forward(self, colors, densities, depths, rendering_options):
        composite_rgb, composite_depth, weights = self.run_forward(
            colors, densities, depths, rendering_options)

        return composite_rgb, composite_depth, weights
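The compositing in run_forward is the standard discrete volume-rendering equation, alpha_i = 1 - exp(-sigma_i * delta_i) and w_i = alpha_i * prod_{j<i}(1 - alpha_j), evaluated on midpoint samples. A toy invocation (illustrative numbers, not part of the diff) shows the expected (batch, rays, samples, channels) tensor layout:

import torch

marcher = MipRayMarcher2()
colors = torch.tensor([[[[0.2], [0.8], [0.5]]]])     # (1, 1, 3, 1)
densities = torch.tensor([[[[1.0], [1.0], [1.0]]]])  # pre-softplus sigmas
depths = torch.tensor([[[[1.0], [1.5], [2.0]]]])
rgb, depth, weights = marcher(colors, densities, depths,
                              {'clamp_mode': 'softplus'})
# rgb lies in (-1, 1) because of the final `* 2 - 1` rescale; weights has one
# entry per midpoint segment, i.e. shape (1, 1, 2, 1).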
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The ray sampler is a module that takes in camera matrices and a resolution and returns batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""

import torch


class RaySampler(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = \
            None, None, None, None, None

    def forward(self, cam2world_matrix, intrinsics, resolution):
        """
        Create batches of rays and return origins and directions.

        cam2world_matrix: (N, 4, 4)
        intrinsics: (N, 3, 3)
        resolution: int

        ray_origins: (N, M, 3)
        ray_dirs: (N, M, 3)
        """
        N, M = cam2world_matrix.shape[0], resolution**2
        cam_locs_world = cam2world_matrix[:, :3, 3]
        fx = intrinsics[:, 0, 0]
        fy = intrinsics[:, 1, 1]
        cx = intrinsics[:, 0, 2]
        cy = intrinsics[:, 1, 2]
        sk = intrinsics[:, 0, 1]

        uv = torch.stack(
            torch.meshgrid(
                torch.arange(
                    resolution,
                    dtype=torch.float32,
                    device=cam2world_matrix.device),
                torch.arange(
                    resolution,
                    dtype=torch.float32,
                    device=cam2world_matrix.device))) * (1. / resolution) + (
                        0.5 / resolution)
        uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
        uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)

        x_cam = uv[:, :, 0].view(N, -1)
        y_cam = uv[:, :, 1].view(N, -1)
        z_cam = torch.ones((N, M), device=cam2world_matrix.device)

        x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)
                  * sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1)
                  * y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
        y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam

        cam_rel_points = torch.stack(
            (x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)

        world_rel_points = torch.bmm(cam2world_matrix,
                                     cam_rel_points.permute(0, 2, 1)).permute(
                                         0, 2, 1)[:, :, :3]

        ray_dirs = world_rel_points - cam_locs_world[:, None, :]
        ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)

        ray_origins = cam_locs_world.unsqueeze(1).repeat(
            1, ray_dirs.shape[1], 1)

        return ray_origins, ray_dirs
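A usage sketch (not part of the diff) with an identity camera pose, i.e. the camera at the origin looking down +z in the OpenCV convention, and illustrative normalized intrinsics:

import torch

sampler = RaySampler()
cam2world = torch.eye(4).unsqueeze(0)           # (1, 4, 4)
intrinsics = torch.tensor([[[1.0, 0.0, 0.5],
                            [0.0, 1.0, 0.5],
                            [0.0, 0.0, 1.0]]])  # fx = fy = 1, cx = cy = 0.5
origins, dirs = sampler(cam2world, intrinsics, resolution=64)
assert origins.shape == (1, 64 * 64, 3)         # all zeros for this pose
assert dirs.shape == (1, 64 * 64, 3)            # unit-norm ray directions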
@@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""

import math

import torch
import torch.nn as nn

from . import math_utils
from .ray_marcher import MipRayMarcher2


def generate_planes():
    """
    Defines planes by the three vectors that form the "axes" of the
    plane. Should work with arbitrary number of planes and planes of
    arbitrary orientation.
    """
    return torch.tensor(
        [[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
         [[0, 0, 1], [1, 0, 0], [0, 1, 0]]],
        dtype=torch.float32)


def project_onto_planes(planes, coordinates):
    """
    Does a projection of a 3D point onto a batch of 2D planes,
    returning 2D plane coordinates.

    Takes plane axes of shape n_planes, 3, 3
    Takes coordinates of shape N, M, 3
    returns projections of shape N*n_planes, M, 2
    """
    N, M, C = coordinates.shape
    n_planes, _, _ = planes.shape
    coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1,
                                                  -1).reshape(
                                                      N * n_planes, M, 3)
    inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(
        N, -1, -1, -1).reshape(N * n_planes, 3, 3).to(coordinates.device)
    projections = torch.bmm(coordinates, inv_planes)
    return projections[..., :2]


def sample_from_planes(plane_axes,
                       plane_features,
                       coordinates,
                       mode='bilinear',
                       padding_mode='zeros',
                       box_warp=None):
    assert padding_mode == 'zeros'
    N, n_planes, C, H, W = plane_features.shape
    _, M, _ = coordinates.shape
    plane_features = plane_features.view(N * n_planes, C, H, W)

    coordinates = (2 / box_warp) * coordinates  # TODO: add specific box bounds

    projected_coordinates = project_onto_planes(plane_axes,
                                                coordinates).unsqueeze(1)
    output_features = torch.nn.functional.grid_sample(
        plane_features,
        projected_coordinates.float(),
        mode=mode,
        padding_mode=padding_mode,
        align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
    return output_features


def sample_from_3dgrid(grid, coordinates):
    """
    Expects coordinates in shape (batch_size, num_points_per_batch, 3)
    Expects grid in shape (1, channels, H, W, D)
    (Also works if grid has batch size)
    Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
    """
    batch_size, n_coords, n_dims = coordinates.shape
    sampled_features = torch.nn.functional.grid_sample(
        grid.expand(batch_size, -1, -1, -1, -1),
        coordinates.reshape(batch_size, 1, 1, -1, n_dims),
        mode='bilinear',
        padding_mode='zeros',
        align_corners=False)
    N, C, H, W, D = sampled_features.shape
    sampled_features = sampled_features.permute(0, 4, 3, 2,
                                                1).reshape(N, H * W * D, C)
    return sampled_features


class ImportanceRenderer(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ray_marcher = MipRayMarcher2()
        self.plane_axes = generate_planes()

    def forward(self, planes, decoder, ray_origins, ray_directions,
                rendering_options):
        self.plane_axes = self.plane_axes.to(ray_origins.device)

        if rendering_options['ray_start'] == rendering_options[
                'ray_end'] == 'auto':
            ray_start, ray_end = math_utils.get_ray_limits_box(
                ray_origins,
                ray_directions,
                box_side_length=rendering_options['box_warp'])
            is_ray_valid = ray_end > ray_start
            if torch.any(is_ray_valid).item():
                ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
                ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
            depths_coarse = self.sample_stratified(
                ray_origins, ray_start, ray_end,
                rendering_options['depth_resolution'],
                rendering_options['disparity_space_sampling'])
        else:
            # Create stratified depth samples
            depths_coarse = self.sample_stratified(
                ray_origins, rendering_options['ray_start'],
                rendering_options['ray_end'],
                rendering_options['depth_resolution'],
                rendering_options['disparity_space_sampling'])

        batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape

        # Coarse Pass
        sample_coordinates = (
            ray_origins.unsqueeze(-2)
            + depths_coarse * ray_directions.unsqueeze(-2)).reshape(
                batch_size, -1, 3)
        sample_directions = ray_directions.unsqueeze(-2).expand(
            -1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)

        out = self.run_model(planes, decoder, sample_coordinates,
                             sample_directions, rendering_options)
        colors_coarse = out['rgb']
        densities_coarse = out['sigma']
        colors_coarse = colors_coarse.reshape(batch_size, num_rays,
                                              samples_per_ray,
                                              colors_coarse.shape[-1])
        densities_coarse = densities_coarse.reshape(batch_size, num_rays,
                                                    samples_per_ray, 1)

        # Fine Pass
        N_importance = rendering_options['depth_resolution_importance']
        if N_importance > 0:
            _, _, weights = self.ray_marcher(colors_coarse, densities_coarse,
                                             depths_coarse, rendering_options)

            depths_fine = self.sample_importance(depths_coarse, weights,
                                                 N_importance)

            sample_directions = ray_directions.unsqueeze(-2).expand(
                -1, -1, N_importance, -1).reshape(batch_size, -1, 3)
            sample_coordinates = (
                ray_origins.unsqueeze(-2)
                + depths_fine * ray_directions.unsqueeze(-2)).reshape(
                    batch_size, -1, 3)

            out = self.run_model(planes, decoder, sample_coordinates,
                                 sample_directions, rendering_options)
            colors_fine = out['rgb']
            densities_fine = out['sigma']
            colors_fine = colors_fine.reshape(batch_size, num_rays,
                                              N_importance,
                                              colors_fine.shape[-1])
            densities_fine = densities_fine.reshape(batch_size, num_rays,
                                                    N_importance, 1)

            all_depths, all_colors, all_densities = self.unify_samples(
                depths_coarse, colors_coarse, densities_coarse, depths_fine,
                colors_fine, densities_fine)

            # Aggregate
            rgb_final, depth_final, weights = self.ray_marcher(
                all_colors, all_densities, all_depths, rendering_options)
        else:
            rgb_final, depth_final, weights = self.ray_marcher(
                colors_coarse, densities_coarse, depths_coarse,
                rendering_options)

        return rgb_final, depth_final, weights.sum(2)

    def run_model(self, planes, decoder, sample_coordinates, sample_directions,
                  options):
        sampled_features = sample_from_planes(
            self.plane_axes,
            planes,
            sample_coordinates,
            padding_mode='zeros',
            box_warp=options['box_warp'])

        out = decoder(sampled_features, sample_directions)
        if options.get('density_noise', 0) > 0:
            out['sigma'] += torch.randn_like(
                out['sigma']) * options['density_noise']
        return out

    def sort_samples(self, all_depths, all_colors, all_densities):
        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(
            all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2,
                                     indices.expand(-1, -1, -1, 1))
        return all_depths, all_colors, all_densities

    def unify_samples(self, depths1, colors1, densities1, depths2, colors2,
                      densities2):
        all_depths = torch.cat([depths1, depths2], dim=-2)
        all_colors = torch.cat([colors1, colors2], dim=-2)
        all_densities = torch.cat([densities1, densities2], dim=-2)

        _, indices = torch.sort(all_depths, dim=-2)
        all_depths = torch.gather(all_depths, -2, indices)
        all_colors = torch.gather(
            all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
        all_densities = torch.gather(all_densities, -2,
                                     indices.expand(-1, -1, -1, 1))

        return all_depths, all_colors, all_densities

    def sample_stratified(self,
                          ray_origins,
                          ray_start,
                          ray_end,
                          depth_resolution,
                          disparity_space_sampling=False):
        """
        Return depths of approximately uniformly spaced samples along rays.
        """
        N, M, _ = ray_origins.shape
        if disparity_space_sampling:
            depths_coarse = torch.linspace(
                0, 1, depth_resolution,
                device=ray_origins.device).reshape(1, 1, depth_resolution,
                                                   1).repeat(N, M, 1, 1)
            depth_delta = 1 / (depth_resolution - 1)
            depths_coarse += torch.rand_like(depths_coarse) * depth_delta
            depths_coarse = 1. / (1. / ray_start * (1. - depths_coarse)
                                  + 1. / ray_end * depths_coarse)
        else:
            if type(ray_start) == torch.Tensor:
                depths_coarse = math_utils.linspace(ray_start, ray_end,
                                                    depth_resolution).permute(
                                                        1, 2, 0, 3)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta[
                    ..., None]
            else:
                depths_coarse = torch.linspace(
                    ray_start,
                    ray_end,
                    depth_resolution,
                    device=ray_origins.device).reshape(1, 1, depth_resolution,
                                                       1).repeat(N, M, 1, 1)
                depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
                depths_coarse += torch.rand_like(depths_coarse) * depth_delta

        return depths_coarse

    def sample_importance(self, z_vals, weights, N_importance):
        """
        Return depths of importance sampled points along rays. See NeRF importance sampling for more.
        """
        with torch.no_grad():
            batch_size, num_rays, samples_per_ray, _ = z_vals.shape

            z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
            weights = weights.reshape(
                batch_size * num_rays,
                -1)  # -1 to account for loss of 1 sample in MipRayMarcher

            # smooth weights
            weights = torch.nn.functional.max_pool1d(
                weights.unsqueeze(1).float(), 2, 1, padding=1)
            weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
            weights = weights + 0.01

            z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
            importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
                                                N_importance).detach().reshape(
                                                    batch_size, num_rays,
                                                    N_importance, 1)
        return importance_z_vals

    def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
        """
        Sample @N_importance samples from @bins with distribution defined by @weights.
        Inputs:
            bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
            weights: (N_rays, N_samples_)
            N_importance: the number of samples to draw from the distribution
            det: deterministic or not
            eps: a small number to prevent division by zero
        Outputs:
            samples: the sampled samples
        """
        N_rays, N_samples_ = weights.shape
        weights = weights + eps  # prevent division by zero (don't do inplace op!)
        pdf = weights / torch.sum(
            weights, -1, keepdim=True)  # (N_rays, N_samples_)
        cdf = torch.cumsum(
            pdf, -1)  # (N_rays, N_samples), cumulative distribution function
        cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf],
                        -1)  # (N_rays, N_samples_+1)
        # padded to 0~1 inclusive

        if det:
            u = torch.linspace(0, 1, N_importance, device=bins.device)
            u = u.expand(N_rays, N_importance)
        else:
            u = torch.rand(N_rays, N_importance, device=bins.device)
        u = u.contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp_min(inds - 1, 0)
        above = torch.clamp_max(inds, N_samples_)

        inds_sampled = torch.stack([below, above],
                                   -1).view(N_rays, 2 * N_importance)
        cdf_g = torch.gather(cdf, 1,
                             inds_sampled).view(N_rays, N_importance, 2)
        bins_g = torch.gather(bins, 1,
                              inds_sampled).view(N_rays, N_importance, 2)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom[denom < eps] = 1

        samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (
            bins_g[..., 1] - bins_g[..., 0])
        return samples
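sample_pdf above is classic NeRF inverse-CDF importance sampling; a toy check (illustrative numbers, not part of the diff) shows how the drawn samples concentrate where the coarse weights are large:

import torch

renderer = ImportanceRenderer()
bins = torch.tensor([[0.0, 1.0, 2.0, 3.0]])   # (N_rays=1, N_samples_+1=4)
weights = torch.tensor([[0.0, 1.0, 0.0]])     # all mass on the middle bin
samples = renderer.sample_pdf(bins, weights, N_importance=8, det=True)
# Apart from the two endpoint samples (u=0 and u=1), every sample lands
# inside [1.0, 2.0], where the weight mass is.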
modelscope/models/cv/image_view_transform/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .image_view_transform_infer import ImageViewTransform

else:
    _import_structure = {'image_view_transform_infer': ['ImageViewTransform']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
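The LazyImportModule registration above means importing the package is cheap; the submodule and its heavy dependencies load only when an attribute is first accessed. A rough illustration (assuming the layout in this patch):

# Hypothetical illustration of what the lazy module buys you: the heavy
# submodule (and its torch/diffusers imports) only loads on first use.
from modelscope.models.cv import image_view_transform as ivt

# No model code has run yet; attribute access triggers the real import:
model_cls = ivt.ImageViewTransform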
@@ -0,0 +1,219 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import math
import os
import sys
import time
from contextlib import nullcontext
from functools import partial

import cv2
import diffusers  # 0.12.1
import fire
import numpy as np
import rich
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from torch import autocast
from torchvision import transforms

from modelscope.fileio import load
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .ldm.ddim import DDIMSampler
from .util import instantiate_from_config, load_and_preprocess

logger = get_logger()


def load_model_from_config(model, config, ckpt, device, verbose=False):
    print(f'Loading model from {ckpt}')
    pl_sd = torch.load(ckpt, map_location='cpu')
    if 'global_step' in pl_sd:
        print(f'Global Step: {pl_sd["global_step"]}')
    sd = pl_sd['state_dict']
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    model.to(device)
    model.eval()
    return model


@MODELS.register_module(
    Tasks.image_view_transform, module_name=Models.image_view_transform)
class ImageViewTransform(TorchModel):
    """initialize the image view translation model from the `model_dir` path.

    Args:
        model_dir (str): the model path.
    """

    def __init__(self, model_dir, device='cpu', *args, **kwargs):

        super().__init__(model_dir=model_dir, device=device, *args, **kwargs)

        self.device = torch.device(
            device if torch.cuda.is_available() else 'cpu')

        config = os.path.join(model_dir,
                              'sd-objaverse-finetune-c_concat-256.yaml')
        ckpt = os.path.join(model_dir, 'zero123-xl.ckpt')
        config = OmegaConf.load(config)
        self.model = None
        self.model = load_model_from_config(
            self.model, config, ckpt, device=self.device)

    def forward(self, model_path, x, y):
        pred_results = _infer(self.model, model_path, x, y, self.device)
        return pred_results


def infer(genmodel, model_path, image_path, target_view_path, device):
    output_ims = genmodel(model_path, image_path, target_view_path)
    return output_ims


@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps,
                 n_samples, scale, ddim_eta, x, y, z):
    precision_scope = autocast if precision == 'autocast' else nullcontext
    with precision_scope('cuda'):
        with model.ema_scope():
            c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
            T = torch.tensor([
                math.radians(x),
                math.sin(math.radians(y)),
                math.cos(math.radians(y)), z
            ])
            T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
            c = torch.cat([c, T], dim=-1)
            c = model.cc_projection(c)
            cond = {}
            cond['c_crossattn'] = [c]
            cond['c_concat'] = [
                model.encode_first_stage(
                    (input_im.to(c.device))).mode().detach().repeat(
                        n_samples, 1, 1, 1)
            ]
            if scale != 1.0:
                uc = {}
                uc['c_concat'] = [
                    torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)
                ]
                uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
            else:
                uc = None

            shape = [4, h // 8, w // 8]
            samples_ddim, _ = sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=n_samples,
                shape=shape,
                verbose=False,
                unconditional_guidance_scale=scale,
                unconditional_conditioning=uc,
                eta=ddim_eta,
                x_T=None)
            # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
            x_samples_ddim = model.decode_first_stage(samples_ddim)
            return torch.clamp(
                (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()


def preprocess_image(models, input_im, preprocess, carvekit_path):
    '''
    :param input_im (PIL Image).
    :return input_im (H, W, 3) array in [0, 1].
    '''

    print('old input_im:', input_im.size)

    if preprocess:

        # model_carvekit = create_carvekit_interface()
        model_carvekit = torch.load(carvekit_path)
        input_im = load_and_preprocess(model_carvekit, input_im)
        input_im = (input_im / 255.0).astype(np.float32)
        # (H, W, 3) array in [0, 1].
    else:
        input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
        input_im = np.asarray(input_im, dtype=np.float32) / 255.0
        alpha = input_im[:, :, 3:4]
        white_im = np.ones_like(input_im)
        input_im = alpha * input_im + (1.0 - alpha) * white_im

        input_im = input_im[:, :, 0:3]
        # (H, W, 3) array in [0, 1].

    return input_im


def main_run(models,
             device,
             return_what,
             x=0.0,
             y=0.0,
             z=0.0,
             raw_im=None,
             carvekit_path=None,
             preprocess=True,
             scale=3.0,
             n_samples=4,
             ddim_steps=50,
             ddim_eta=1.0,
             precision='fp32',
             h=256,
             w=256):
    '''
    :param raw_im (PIL Image).
    '''

    raw_im.thumbnail([1536, 1536], Image.Resampling.LANCZOS)
    input_im = preprocess_image(models, raw_im, preprocess, carvekit_path)

    if 'gen' in return_what:
        input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
        input_im = input_im * 2 - 1
        input_im = transforms.functional.resize(input_im, [h, w])

        sampler = DDIMSampler(models)
        # used_x = -x  # NOTE: Polar makes more sense in Basile's opinion this way!
        used_x = x  # NOTE: Set this way for consistency.
        x_samples_ddim = sample_model(input_im, models, sampler, precision, h,
                                      w, ddim_steps, n_samples, scale,
                                      ddim_eta, used_x, y, z)

        output_ims = []
        for x_sample in x_samples_ddim:
            image = x_sample.detach().cpu().squeeze().numpy()
            image = np.transpose(image, (1, 2, 0)) * 255
            image = np.uint8(image)
            bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
            output_ims.append(bgr)

        return output_ims


def _infer(genmodel, model_path, image_path, target_view_path, device):
    if isinstance(image_path, str):
        raw_image = load(image_path)
        print(type(raw_image))
    else:
        raw_image = image_path
    if isinstance(target_view_path, str):
        views = load(target_view_path)
    else:
        views = target_view_path
    # views = views.astype(np.float32)
    carvekit_path = os.path.join(model_path, 'carvekit.pth')
    output_ims = main_run(genmodel, device, 'angles_gen', views[0], views[1],
                          views[2], raw_image, carvekit_path, views[3],
                          views[4], views[5], views[6], views[7])
    return output_ims
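For reference, a hypothetical end-to-end call (paths and view values are placeholders; the model directory must hold the YAML config, zero123-xl.ckpt and carvekit.pth referenced above):

from modelscope.models.cv.image_view_transform import ImageViewTransform

# '/path/to/model' is a placeholder for a downloaded model directory.
model = ImageViewTransform('/path/to/model', device='cuda')
# views packs (x, y, z, preprocess, scale, n_samples, ddim_steps, ddim_eta)
# in the order main_run is called with in _infer above.
views = [0.0, 30.0, 0.0, True, 3.0, 4, 50, 1.0]
images = model('/path/to/model', 'input.png', views)  # list of BGR arrays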
294
modelscope/models/cv/image_view_transform/ldm/attention.py
Normal file
@@ -0,0 +1,294 @@
import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn

from .util_diffusion import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(nn.Linear(
            dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(project_in, nn.Dropout(dropout),
                                 nn.Linear(inner_dim, dim_out))

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv,
            'b (qkv heads c) h w -> qkv b heads c (h w)',
            heads=self.heads,
            qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(
            out,
            'b heads c (h w) -> b (heads c) h w',
            heads=self.heads,
            h=h,
            w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False):
        super().__init__()
        self.disable_self_attn = disable_self_attn
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else
            None)  # is a self-attention if not self.disable_self_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        return checkpoint(self._forward, (x, context), self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(
            self.norm1(x),
            context=context if self.disable_self_attn else None) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 disable_self_attn=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim,
                disable_self_attn=disable_self_attn) for d in range(depth)
        ])

        self.proj_out = zero_module(
            nn.Conv2d(
                inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        x = self.proj_out(x)
        return x + x_in
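As a sanity check on the attention shapes (illustrative dimensions, assuming the classes above are importable):

import torch

attn = CrossAttention(query_dim=64, context_dim=32, heads=4, dim_head=16)
x = torch.randn(2, 100, 64)        # (batch, query tokens, query_dim)
ctx = torch.randn(2, 77, 32)       # (batch, context tokens, context_dim)
out = attn(x, context=ctx)
assert out.shape == (2, 100, 64)   # output keeps the query shape

# With context=None the same module acts as self-attention, which is how
# BasicTransformerBlock.attn1 is used when disable_self_attn is False.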
555
modelscope/models/cv/image_view_transform/ldm/autoencoder.py
Executable file
@@ -0,0 +1,555 @@
from contextlib import contextmanager

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from packaging import version
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from torch.optim.lr_scheduler import LambdaLR

from ..util import instantiate_from_config
from .distributions import DiagonalGaussianDistribution
from .ema import LitEma
from .model import Decoder, Encoder


class VQModel(pl.LightningModule):

    def __init__(
            self,
            ddconfig,
            lossconfig,
            n_embed,
            embed_dim,
            ckpt_path=None,
            ignore_keys=[],
            image_key='image',
            colorize_nlabels=None,
            monitor=None,
            batch_resize_range=None,
            scheduler_config=None,
            lr_g_factor=1.0,
            remap=None,
            sane_index_shape=False,  # tell vector quantizer to return indices as bhw
            use_ema=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(
            n_embed,
            embed_dim,
            beta=0.25,
            remap=remap,
            sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(
                f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
            )

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f'{context}: Switched to EMA weights')
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f'{context}: Restored training weights')

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(
            f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys: {missing}')
            print(f'Unexpected Keys: {unexpected}')

    def on_train_batch_end(self, *args, **kwargs):
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        return quant, emb_loss, info

    def encode_to_prequant(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, quant):
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code_b):
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b)
        return dec

    def forward(self, input, return_pred_indices=False):
        quant, diff, (_, _, ind) = self.encode(input)
        dec = self.decode(quant)
        if return_pred_indices:
            return dec, diff, ind
        return dec, diff

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        if self.batch_resize_range is not None:
            lower_size = self.batch_resize_range[0]
            upper_size = self.batch_resize_range[1]
            if self.global_step <= 4:
                # do the first few batches with max size to avoid later oom
                new_resize = upper_size
            else:
                new_resize = np.random.choice(
                    np.arange(lower_size, upper_size + 16, 16))
            if new_resize != x.shape[2]:
                x = F.interpolate(x, size=new_resize, mode='bicubic')
            x = x.detach()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if optimizer_idx == 0:
            # autoencode
            aeloss, log_dict_ae = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train',
                predicted_indices=ind)

            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True)
            return aeloss

        if optimizer_idx == 1:
            # discriminator
            discloss, log_dict_disc = self.loss(
                qloss,
                x,
                xrec,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')
            self.log_dict(
                log_dict_disc,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=True)
            return discloss

    def validation_step(self, batch, batch_idx):
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix='_ema')
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=''):
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        aeloss, log_dict_ae = self.loss(
            qloss,
            x,
            xrec,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val' + suffix,
            predicted_indices=ind)

        discloss, log_dict_disc = self.loss(
            qloss,
            x,
            xrec,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val' + suffix,
            predicted_indices=ind)
        rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
        self.log(
            f'val{suffix}/rec_loss',
            rec_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True)
        self.log(
            f'val{suffix}/aeloss',
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
            sync_dist=True)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f'val{suffix}/rec_loss']
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr_d = self.learning_rate
        lr_g = self.lr_g_factor * self.learning_rate
        print('lr_d', lr_d)
        print('lr_g', lr_g)
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters())
            + list(self.quantize.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr_g,
            betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print('Setting up LambdaLR scheduler...')
            scheduler = [
                {
                    'scheduler':
                    LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
                {
                    'scheduler':
                    LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                },
            ]
            return [opt_ae, opt_disc], scheduler
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log['inputs'] = x
            return log
        xrec, _ = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log['inputs'] = x
        log['reconstructions'] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(x)
                if x.shape[1] > 3:
                    xrec_ema = self.to_rgb(xrec_ema)
                log['reconstructions_ema'] = xrec_ema
        return log

    def to_rgb(self, x):
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class VQModelInterface(VQModel):

    def __init__(self, embed_dim, *args, **kwargs):
        super().__init__(embed_dim=embed_dim, *args, **kwargs)
        self.embed_dim = embed_dim

    def encode(self, x):
        h = self.encoder(x)
        h = self.quant_conv(h)
        return h

    def decode(self, h, force_not_quantize=False):
        # also go through quantization layer
        if not force_not_quantize:
            quant, emb_loss, info = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec


class AutoencoderKL(pl.LightningModule):

    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=[],
        image_key='image',
        colorize_nlabels=None,
        monitor=None,
    ):
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        assert ddconfig['double_z']
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig['z_channels'], 1)
        self.embed_dim = embed_dim
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer('colorize',
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def init_from_ckpt(self, path, ignore_keys=list()):
        sd = torch.load(path, map_location='cpu')['state_dict']
        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print('Deleting key {} from state_dict.'.format(k))
                    del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f'Restored from {path}')

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z):
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1,
                      2).to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')
            self.log(
                'aeloss',
                aeloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True)
            self.log_dict(
                log_dict_ae,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split='train')

            self.log(
                'discloss',
                discloss,
                prog_bar=True,
                logger=True,
                on_step=True,
                on_epoch=True)
            self.log_dict(
                log_dict_disc,
                prog_bar=False,
                logger=True,
                on_step=True,
                on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(
            inputs,
            reconstructions,
            posterior,
            0,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val')

        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            1,
            self.global_step,
            last_layer=self.get_last_layer(),
            split='val')

        self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        return self.log_dict

    def configure_optimizers(self):
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters())
            + list(self.quant_conv.parameters())
            + list(self.post_quant_conv.parameters()),
            lr=lr,
            betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(
            self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log['samples'] = self.decode(torch.randn_like(posterior.sample()))
            log['reconstructions'] = xrec
        log['inputs'] = x
        return log

    def to_rgb(self, x):
        assert self.image_key == 'segmentation'
        if not hasattr(self, 'colorize'):
            self.register_buffer('colorize',
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x


class IdentityFirstStage(torch.nn.Module):

    def __init__(self, *args, vq_interface=False, **kwargs):
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def quantize(self, x, *args, **kwargs):
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
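A quick illustration of the HWC-to-NCHW conversion that get_input performs on a batch dict (dummy data only):

import torch

# get_input above converts (B, H, W, C) batches to contiguous NCHW floats.
batch = {'image': torch.rand(2, 256, 256, 3)}   # (B, H, W, C)
x = batch['image'].permute(0, 3, 1, 2).to(
    memory_format=torch.contiguous_format).float()
assert x.shape == (2, 3, 256, 256)              # (B, C, H, W)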
433
modelscope/models/cv/image_view_transform/ldm/ddim.py
Executable file
@@ -0,0 +1,433 @@
"""SAMPLING ONLY."""

from functools import partial

import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm

from .sampling_util import (norm_thresholding, renorm_thresholding,
                            spatial_norm_thresholding)
from .util_diffusion import (extract_into_tensor,
                             make_ddim_sampling_parameters,
                             make_ddim_timesteps, noise_like)


class DDIMSampler(object):

    def __init__(self, model, schedule='linear', **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def to(self, device):
        """Same as `to` in a torch module.
        Don't really understand why this isn't a module in the first place."""
        for k, v in self.__dict__.items():
            if isinstance(v, torch.Tensor):
                new_v = getattr(self, k).to(device)
                setattr(self, k, new_v)

    def register_buffer(self, name, attr):
        if type(attr) == torch.Tensor:
            if attr.device != torch.device('cuda'):
                attr = attr.to(torch.device('cuda'))
        setattr(self, name, attr)

    def make_schedule(self,
                      ddim_num_steps,
                      ddim_discretize='uniform',
                      ddim_eta=0.,
                      verbose=True):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[
            0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas',
                             np.sqrt(1. - ddim_alphas))
        alpha_1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        alpha_2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            alpha_1 * alpha_2)
        self.register_buffer('ddim_sigmas_for_original_num_steps',
                             sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               dynamic_threshold=None,
               **kwargs):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list):
                    ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(
                        f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
                    )

            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
                    )

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            dynamic_threshold=dynamic_threshold,
        )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self,
                      cond,
                      shape,
                      x_T=None,
                      ddim_use_original_steps=False,
                      callback=None,
                      timesteps=None,
                      quantize_denoised=False,
                      mask=None,
                      x0=None,
                      img_callback=None,
                      log_every_t=100,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      dynamic_threshold=None,
                      t_start=-1):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(
                min(timesteps / self.ddim_timesteps.shape[0], 1)
                * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        timesteps = timesteps[:t_start]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = reversed(range(
            0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
            0]
        print(f'Running DDIM Sampling with {total_steps} timesteps')

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b, ), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(
                    x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                dynamic_threshold=dynamic_threshold)
            img, pred_x0 = outs
            if callback:
                img = callback(i, img, pred_x0)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self,
                      x,
                      c,
                      t,
                      index,
                      repeat_noise=False,
                      use_original_steps=False,
                      quantize_denoised=False,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      dynamic_threshold=None):
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            if isinstance(c, dict):
                assert isinstance(unconditional_conditioning, dict)
                c_in = dict()
                for k in c:
                    if isinstance(c[k], list):
                        c_in[k] = [
                            torch.cat(
                                [unconditional_conditioning[k][i], c[k][i]])
                            for i in range(len(c[k]))
                        ]
                    else:
                        c_in[k] = torch.cat(
                            [unconditional_conditioning[k], c[k]])
            else:
                c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (
                e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == 'eps'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
                                               **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        if use_original_steps:
            alphas_prev = self.model.alphas_cumprod_prev
            sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
        else:
            alphas_prev = self.ddim_alphas_prev
            sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1),
                                       sqrt_one_minus_alphas[index],
                                       device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        if dynamic_threshold is not None:
            pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)

        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device,
                                     repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

    @torch.no_grad()
    def encode(self,
               x0,
               c,
               t_enc,
               use_original_steps=False,
               return_intermediates=None,
               unconditional_guidance_scale=1.0,
               unconditional_conditioning=None):
        num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[
            0]

        assert t_enc <= num_reference_steps
        num_steps = t_enc

        if use_original_steps:
            alphas_next = self.alphas_cumprod[:num_steps]
            alphas = self.alphas_cumprod_prev[:num_steps]
        else:
            alphas_next = self.ddim_alphas[:num_steps]
            alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])

        x_next = x0
        intermediates = []
        inter_steps = []
        for i in tqdm(range(num_steps), desc='Encoding Image'):
            t = torch.full((x0.shape[0], ),
                           i,
                           device=self.model.device,
                           dtype=torch.long)
            if unconditional_guidance_scale == 1.:
                noise_pred = self.model.apply_model(x_next, t, c)
            else:
                assert unconditional_conditioning is not None
                e_t_uncond, noise_pred = torch.chunk(
                    self.model.apply_model(
                        torch.cat((x_next, x_next)), torch.cat((t, t)),
                        torch.cat((unconditional_conditioning, c))), 2)
                noise_pred = e_t_uncond + unconditional_guidance_scale * (
                    noise_pred - e_t_uncond)

            xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
            alp_1 = (1 / alphas_next[i] - 1).sqrt()
            alp_2 = (1 / alphas[i] - 1).sqrt()
            weighted_noise_pred = alphas_next[i].sqrt() * (
                alp_1 - alp_2) * noise_pred
            x_next = xt_weighted + weighted_noise_pred
            if return_intermediates and i % (num_steps // return_intermediates
                                             ) == 0 and i < num_steps - 1:
                intermediates.append(x_next)
                inter_steps.append(i)
            elif return_intermediates and i >= num_steps - 2:
                intermediates.append(x_next)
                inter_steps.append(i)

        out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
        if return_intermediates:
            out.update({'intermediates': intermediates})
        return x_next, out

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        # fast, but does not allow for exact reconstruction
        # t serves as an index to gather the correct alphas
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
            + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
            * noise)

    @torch.no_grad()
    def decode(self,
               x_latent,
               cond,
               t_start,
               unconditional_guidance_scale=1.0,
               unconditional_conditioning=None,
               use_original_steps=False):

        timesteps = np.arange(self.ddpm_num_timesteps
                              ) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f'Running DDIM Sampling with {total_steps} timesteps')

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0], ),
                            step,
                            device=x_latent.device,
                            dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning)
        return x_dec
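The sigma schedule computed in make_schedule follows the DDIM formulation: sigma_t = eta * sqrt((1 - alpha_prev) / (1 - alpha_t)) * sqrt(1 - alpha_t / alpha_prev). A toy numeric check (values are illustrative):

import torch

# Toy check of the sigma formula used in make_schedule above.
a_t, a_prev, eta = torch.tensor(0.5), torch.tensor(0.8), 1.0
sigma = eta * torch.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))
# eta=1 recovers DDPM-like stochasticity; eta=0 makes p_sample_ddim
# deterministic, since the added noise term vanishes.
print(sigma)  # ~0.387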
2553
modelscope/models/cv/image_view_transform/ldm/ddpm.py
Executable file
File diff suppressed because it is too large
@@ -0,0 +1,92 @@
import numpy as np
import torch


class AbstractDistribution:

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()


class DiracDistribution(AbstractDistribution):

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value


class DiagonalGaussianDistribution(object):

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(
            self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(
                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                    dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar
            + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two Gaussians.
    Shapes are automatically broadcast, so batches can be compared to
    scalars, among other use cases.
    """
    tensor = None
    for obj in (mean1, logvar1, mean2, logvar2):
        if isinstance(obj, torch.Tensor):
            tensor = obj
            break
    assert tensor is not None, 'at least one argument must be a Tensor'

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = [
        x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
        for x in (logvar1, logvar2)
    ]
    comp_1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
    comp_2 = ((mean1 - mean2)**2) * torch.exp(-logvar2)
    return 0.5 * (comp_1 + comp_2)
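A hedged consistency check between the two KL paths above, with the definitions in scope: DiagonalGaussianDistribution.kl() against a standard normal should agree with normal_kl summed over the non-batch dims.

import torch

params = torch.randn(2, 8, 4, 4)              # (B, 2*C, H, W)
dist = DiagonalGaussianDistribution(params)
kl_a = dist.kl()                              # (B,)
kl_b = normal_kl(dist.mean, dist.logvar,
                 torch.zeros_like(dist.mean),
                 torch.zeros_like(dist.logvar)).sum(dim=[1, 2, 3])
assert torch.allclose(kl_a, kl_b, atol=1e-5)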
84
modelscope/models/cv/image_view_transform/ldm/ema.py
Normal file
@@ -0,0 +1,84 @@
import torch
from torch import nn


class LitEma(nn.Module):

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int)
            if use_num_upates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:

                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            num_1 = (1 + self.num_updates)
            num_2 = (10 + self.num_updates)
            decay = min(self.decay, num_1 / num_2)

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    shadow_params[sname] = shadow_params[sname].type_as(
                        m_param[key])
                    param_1 = (shadow_params[sname] - m_param[key])
                    shadow_params[sname].sub_(one_minus_decay * param_1)
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(
                    shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
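A typical EMA round-trip with LitEma, matching how ema_scope in autoencoder.py uses store/copy_to/restore (toy model, illustrative only):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
ema = LitEma(model, decay=0.999)

# ... a training step updates model parameters ...
ema(model)                       # fold current weights into the shadow copy

ema.store(model.parameters())    # stash the raw training weights
ema.copy_to(model)               # evaluate with the smoothed EMA weights
ema.restore(model.parameters())  # and put the training weights back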
131
modelscope/models/cv/image_view_transform/ldm/helpers.py
Normal file
@@ -0,0 +1,131 @@
# https://github.com/eladrich/pixel2style2pixel

from collections import namedtuple

import torch
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d,
                      Module, PReLU, ReLU, Sequential, Sigmoid)


class Flatten(Module):

    def forward(self, input):
        return input.view(input.size(0), -1)


def l2_norm(input, axis=1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output


class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    pass


def get_block(in_channel, depth, num_units, stride=2):
    return [Bottleneck(in_channel, depth, stride)
            ] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]


def get_blocks(num_layers):
    if num_layers == 50:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=4),
            get_block(in_channel=128, depth=256, num_units=14),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 100:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=13),
            get_block(in_channel=128, depth=256, num_units=30),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    elif num_layers == 152:
        blocks = [
            get_block(in_channel=64, depth=64, num_units=3),
            get_block(in_channel=64, depth=128, num_units=8),
            get_block(in_channel=128, depth=256, num_units=36),
            get_block(in_channel=256, depth=512, num_units=3)
        ]
    else:
        raise ValueError(
            'Invalid number of layers: {}. Must be one of [50, 100, 152]'.
            format(num_layers))
    return blocks


class SEModule(Module):

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(
            channels,
            channels // reduction,
            kernel_size=1,
            padding=0,
            bias=False)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(
            channels // reduction,
            channels,
            kernel_size=1,
            padding=0,
            bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        module_input = x
        x = self.avg_pool(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return module_input * x


class bottleneck_IR(Module):

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class bottleneck_IR_SE(Module):

    def __init__(self, in_channel, depth, stride):
        super(bottleneck_IR_SE, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth))
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth), SEModule(depth, 16))

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut
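
For orientation, the stage layout produced by `get_blocks` can be checked directly; a small sketch assuming the definitions above are importable:

blocks = get_blocks(50)
print([len(stage) for stage in blocks])  # [3, 4, 14, 3] -> 24 bottlenecks total
print(blocks[0][0])  # Bottleneck(in_channel=64, depth=64, stride=2); remaining units in a stage use stride 1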
modelscope/models/cv/image_view_transform/ldm/id_loss.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn

from .model_irse import Backbone


class IDFeatures(nn.Module):

    def __init__(self, model_path):
        super(IDFeatures, self).__init__()
        print('Loading ResNet ArcFace')
        self.facenet = Backbone(
            input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
        self.facenet.load_state_dict(
            torch.load(model_path, map_location='cpu'))
        self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
        self.facenet.eval()

    def forward(self, x, crop=False):
        # Not sure of the image range here
        if crop:
            x = torch.nn.functional.interpolate(x, (256, 256), mode='area')
            x = x[:, :, 35:223, 32:220]
        x = self.face_pool(x)
        x_feats = self.facenet(x)
        return x_feats
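
A usage sketch for the identity feature extractor (the weights path is a placeholder for ArcFace IR-SE50 weights; inputs are assumed to be roughly in [-1, 1]):

import torch

id_model = IDFeatures('path/to/model_ir_se50.pth')  # placeholder path
with torch.no_grad():
    feats = id_model(torch.randn(2, 3, 512, 512), crop=True)
print(feats.shape)  # torch.Size([2, 512]), L2-normalized by the backbone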
modelscope/models/cv/image_view_transform/ldm/model.py (new file, 961 lines)
@@ -0,0 +1,961 @@
# pytorch_diffusion + derived encoder decoder
import math

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange

from ..util import instantiate_from_config
from .attention import LinearAttention


def get_timestep_embedding(timesteps, embedding_dim):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models:
    From Fairseq.
    Build sinusoidal embeddings.
    This matches the implementation in tensor2tensor, but differs slightly
    from the description in Section 3.5 of "Attention Is All You Need".
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
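
# Illustrative note, not part of the committed file: for even embedding_dim d
# the function above returns a (len(timesteps), d) tensor laid out as
# [sin(t * f_0), ..., sin(t * f_{d/2-1}), cos(t * f_0), ..., cos(t * f_{d/2-1})]
# with frequencies f_i = 10000 ** (-i / (d/2 - 1)); e.g.
# get_timestep_embedding(torch.arange(4), 8) has shape (4, 8).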

def nonlinearity(x):
    # swish
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(
            x, scale_factor=2.0, mode='nearest')
        if self.with_conv:
            x = self.conv(x)
        return x


class Downsample(nn.Module):

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):

    def __init__(self,
                 *,
                 in_channels,
                 out_channels=None,
                 conv_shortcut=False,
                 dropout,
                 temb_channels=512):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1)
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0)

    def forward(self, x, temb):
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]

        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock(nn.Module):

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w)
        q = q.permute(0, 2, 1)  # b,hw,c
        k = k.reshape(b, c, h * w)  # b,c,hw
        w_ = torch.bmm(q, k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = v.reshape(b, c, h * w)
        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(
            v, w_)  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
        h_ = h_.reshape(b, c, h, w)

        h_ = self.proj_out(h_)

        return x + h_


def make_attn(in_channels, attn_type='vanilla'):
    assert attn_type in ['vanilla', 'linear',
                         'none'], f'attn_type {attn_type} unknown'
    print(
        f"making attention of type '{attn_type}' with {in_channels} in_channels"
    )
    if attn_type == 'vanilla':
        return AttnBlock(in_channels)
    elif attn_type == 'none':
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Model(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 use_timestep=True,
                 use_linear_attn=False,
                 attn_type='vanilla'):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch, self.temb_ch),
                torch.nn.Linear(self.temb_ch, self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x, t=None, context=None):
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = nonlinearity(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
                                                              dim=1), temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        return self.conv_out.weight


class Encoder(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 double_z=True,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignore_kwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        curr_res = resolution
        in_ch_mult = (1, ) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1)

    def forward(self, x):
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):

    def __init__(self,
                 *,
                 ch,
                 out_ch,
                 ch_mult=(1, 2, 4, 8),
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 in_channels,
                 resolution,
                 z_channels,
                 give_pre_end=False,
                 tanh_out=False,
                 use_linear_attn=False,
                 attn_type='vanilla',
                 **ignorekwargs):
        super().__init__()
        if use_linear_attn:
            attn_type = 'linear'
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print('Working with z of shape {} = {} dimensions.'.format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z):

        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class SimpleDecoder(nn.Module):

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList([
            nn.Conv2d(in_channels, in_channels, 1),
            ResnetBlock(
                in_channels=in_channels,
                out_channels=2 * in_channels,
                temb_channels=0,
                dropout=0.0),
            ResnetBlock(
                in_channels=2 * in_channels,
                out_channels=4 * in_channels,
                temb_channels=0,
                dropout=0.0),
            ResnetBlock(
                in_channels=4 * in_channels,
                out_channels=2 * in_channels,
                temb_channels=0,
                dropout=0.0),
            nn.Conv2d(2 * in_channels, in_channels, 1),
            Upsample(in_channels, with_conv=True)
        ])
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        for i, layer in enumerate(self.model):
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = nonlinearity(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 ch,
                 num_res_blocks,
                 resolution,
                 ch_mult=(2, 2),
                 dropout=0.0):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2**(self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout))
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # upsampling
        h = x
        for k, i_level in enumerate(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[k](h)
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):

    def __init__(self,
                 factor,
                 in_channels,
                 mid_channels,
                 out_channels,
                 depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
        self.res_block1 = nn.ModuleList([
            ResnetBlock(
                in_channels=mid_channels,
                out_channels=mid_channels,
                temb_channels=0,
                dropout=0.0) for _ in range(depth)
        ])
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList([
            ResnetBlock(
                in_channels=mid_channels,
                out_channels=mid_channels,
                temb_channels=0,
                dropout=0.0) for _ in range(depth)
        ])

        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(int(round(x.shape[2] * self.factor)),
                  int(round(x.shape[3] * self.factor))))
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x


class MergedRescaleEncoder(nn.Module):

    def __init__(self,
                 in_channels,
                 ch,
                 resolution,
                 out_ch,
                 num_res_blocks,
                 attn_resolutions,
                 dropout=0.0,
                 resamp_with_conv=True,
                 ch_mult=(1, 2, 4, 8),
                 rescale_factor=1.0,
                 rescale_module_depth=1):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None)
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth)

    def forward(self, x):
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):

    def __init__(self,
                 z_channels,
                 out_ch,
                 resolution,
                 num_res_blocks,
                 attn_resolutions,
                 ch,
                 ch_mult=(1, 2, 4, 8),
                 dropout=0.0,
                 resamp_with_conv=True,
                 rescale_factor=1.0,
                 rescale_module_depth=1):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch)
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth)

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):

    def __init__(self,
                 in_size,
                 out_size,
                 in_channels,
                 out_channels,
                 ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1. + (out_size % in_size)
        print(
            f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}'
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels)
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)])

    def forward(self, x):
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):

    def __init__(self, in_channels=None, learned=False, mode='bilinear'):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            print(
                f'Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode'
            )
            raise NotImplementedError()
            assert in_channels is not None
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=4, stride=2, padding=1)

    def forward(self, x, scale_factor=1.0):
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(
                x,
                mode=self.mode,
                align_corners=False,
                scale_factor=scale_factor)
        return x


class FirstStagePostProcessor(nn.Module):

    def __init__(self,
                 ch_mult: list,
                 in_channels,
                 pretrained_model: nn.Module = None,
                 reshape=False,
                 n_channels=None,
                 dropout=0.,
                 pretrained_config=None):
        super().__init__()
        if pretrained_config is None:
            assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(
            in_channels, n_channels, kernel_size=3, stride=1, padding=1)

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(
                ResnetBlock(
                    in_channels=ch_in,
                    out_channels=m * n_channels,
                    dropout=dropout))
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        c = self.pretrained_model.encode(x)
        # NOTE: DiagonalGaussianDistribution must already be in scope here;
        # the file as committed does not import it explicitly.
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self, x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z, 'b c h w -> b (h w) c')
        return z
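
To make the Encoder/Decoder shape contract concrete, a round-trip sketch with toy hyperparameters (illustrative only; `double_z=True` doubles the z channels to hold the posterior mean and logvar):

import torch

enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4, double_z=True)
dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
              attn_resolutions=[], in_channels=3, resolution=64,
              z_channels=4)
x = torch.randn(1, 3, 64, 64)
moments = enc(x)     # (1, 8, 32, 32): mean and logvar stacked channel-wise
z = moments[:, :4]   # use the mean as a stand-in for sampling
print(dec(z).shape)  # torch.Size([1, 3, 64, 64])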
modelscope/models/cv/image_view_transform/ldm/model_irse.py (new file, 92 lines)
@@ -0,0 +1,92 @@
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
                      Module, PReLU, Sequential)

from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks,
                      l2_norm)


class Backbone(Module):

    def __init__(self,
                 input_size,
                 num_layers,
                 mode='ir',
                 drop_ratio=0.4,
                 affine=True):
        super(Backbone, self).__init__()
        assert input_size in [112, 224], 'input_size should be 112 or 224'
        assert num_layers in [50, 100,
                              152], 'num_layers should be 50, 100 or 152'
        assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
        blocks = get_blocks(num_layers)
        if mode == 'ir':
            unit_module = bottleneck_IR
        elif mode == 'ir_se':
            unit_module = bottleneck_IR_SE
        self.input_layer = Sequential(
            Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
            PReLU(64))
        if input_size == 112:
            self.output_layer = Sequential(
                BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
                Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine))
        else:
            self.output_layer = Sequential(
                BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
                Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine))

        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    unit_module(bottleneck.in_channel, bottleneck.depth,
                                bottleneck.stride))
        self.body = Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return l2_norm(x)


def IR_50(input_size):
    """Constructs an ir-50 model."""
    model = Backbone(
        input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_101(input_size):
    """Constructs an ir-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_152(input_size):
    """Constructs an ir-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
    return model


def IR_SE_50(input_size):
    """Constructs an ir_se-50 model."""
    model = Backbone(
        input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_101(input_size):
    """Constructs an ir_se-101 model."""
    model = Backbone(
        input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
    return model


def IR_SE_152(input_size):
    """Constructs an ir_se-152 model."""
    model = Backbone(
        input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
    return model
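
A shape check for the IR-SE backbone, assuming the file above is importable:

import torch

net = IR_SE_50(112).eval()  # ArcFace-style 112x112 input
with torch.no_grad():
    emb = net(torch.randn(2, 3, 112, 112))
print(emb.shape)        # torch.Size([2, 512])
print(emb.norm(dim=1))  # ~1.0 per row, since forward() ends with l2_norm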
modelscope/models/cv/image_view_transform/ldm/modules.py (new file, 668 lines)
@@ -0,0 +1,668 @@
import random
from functools import partial

import clip
import kornia
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel,
                          T5EncoderModel, T5Tokenizer)

from ..util import default, instantiate_from_config
from .id_loss import IDFeatures
from .util_diffusion import extract_into_tensor, make_beta_schedule, noise_like
from .x_transformer import Encoder, TransformerWrapper


class AbstractEncoder(nn.Module):

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):

    def encode(self, x):
        return x


class FaceClipEncoder(AbstractEncoder):

    def __init__(self, augment=True, retreival_key=None):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        self.augment = augment
        self.retreival_key = retreival_key

    def forward(self, img):
        encodings = []
        with torch.no_grad():
            x_offset = 125
            if self.retreival_key:
                # Assumes retrieved image are packed into the second half of channels
                face = img[:, 3:, 190:440, x_offset:(512 - x_offset)]
                other = img[:, :3, ...].clone()
            else:
                face = img[:, :, 190:440, x_offset:(512 - x_offset)]
                other = img.clone()

            if self.augment:
                face = K.RandomHorizontalFlip()(face)

            other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0
            encodings = [
                self.encoder.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros(
                (1, 2, 768),
                device=self.encoder.model.visual.conv1.weight.device)

        return self(img)


class FaceIdClipEncoder(AbstractEncoder):

    def __init__(self):
        super().__init__()
        self.encoder = FrozenCLIPImageEmbedder()
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.id = FrozenFaceEncoder(
            '/home/jpinkney/code/stable-diffusion/model_ir_se50.pth',
            augment=True)

    def forward(self, img):
        encodings = []
        with torch.no_grad():
            face = kornia.geometry.resize(
                img, (256, 256), interpolation='bilinear', align_corners=True)

            other = img.clone()
            other[:, :, 184:452, 122:396] *= 0
            encodings = [
                self.id.encode(face),
                self.encoder.encode(other),
            ]

        return torch.cat(encodings, dim=1)

    def encode(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros(
                (1, 2, 768),
                device=self.encoder.model.visual.conv1.weight.device)

        return self(img)


class ClassEmbedder(nn.Module):

    def __init__(self, embed_dim, n_classes=1000, key='class'):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)

    def forward(self, batch, key=None):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        c = self.embedding(c)
        return c


class TransformerEmbedder(AbstractEncoder):
    """Some transformer encoder layers"""

    def __init__(self,
                 n_embed,
                 n_layer,
                 vocab_size,
                 max_seq_len=77,
                 device='cuda'):
        super().__init__()
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer))

    def forward(self, tokens):
        tokens = tokens.to(self.device)  # meh
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, x):
        return self(x)


class BERTTokenizer(AbstractEncoder):
    """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""

    def __init__(self, device='cuda', vq_interface=True, max_length=77):
        super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.device = device
        self.vq_interface = vq_interface
        self.max_length = max_length

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding='max_length',
            return_tensors='pt')
        tokens = batch_encoding['input_ids'].to(self.device)
        return tokens

    @torch.no_grad()
    def encode(self, text):
        tokens = self(text)
        if not self.vq_interface:
            return tokens
        return None, None, [None, None, tokens]

    def decode(self, text):
        return text


class BERTEmbedder(AbstractEncoder):
    """Uses the BERT tokenizer model and adds some transformer encoder layers"""

    def __init__(self,
                 n_embed,
                 n_layer,
                 vocab_size=30522,
                 max_seq_len=77,
                 device='cuda',
                 use_tokenizer=True,
                 embedding_dropout=0.0):
        super().__init__()
        self.use_tknz_fn = use_tokenizer
        if self.use_tknz_fn:
            self.tknz_fn = BERTTokenizer(
                vq_interface=False, max_length=max_seq_len)
        self.device = device
        self.transformer = TransformerWrapper(
            num_tokens=vocab_size,
            max_seq_len=max_seq_len,
            attn_layers=Encoder(dim=n_embed, depth=n_layer),
            emb_dropout=embedding_dropout)

    def forward(self, text):
        if self.use_tknz_fn:
            tokens = self.tknz_fn(text)
        else:
            tokens = text
        z = self.transformer(tokens, return_embeddings=True)
        return z

    def encode(self, text):
        # output of length 77
        return self(text)


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self,
                 version='google/t5-v1_1-large',
                 device='cuda',
                 max_length=77
                 ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding='max_length',
            return_tensors='pt')
        tokens = batch_encoding['input_ids'].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class FrozenFaceEncoder(AbstractEncoder):

    def __init__(self, model_path, augment=False):
        super().__init__()
        self.loss_fn = IDFeatures(model_path)
        # face encoder is frozen
        for p in self.loss_fn.parameters():
            p.requires_grad = False
        # Mapper is trainable
        self.mapper = torch.nn.Linear(512, 768)
        p = 0.25
        if augment:
            self.augment = K.AugmentationSequential(
                K.RandomHorizontalFlip(p=0.5),
                K.RandomEqualize(p=p),
                # K.RandomPlanckianJitter(p=p),
                # K.RandomPlasmaBrightness(p=p),
                # K.RandomPlasmaContrast(p=p),
                # K.ColorJiggle(0.02, 0.2, 0.2, p=p),
            )
        else:
            self.augment = False
    def forward(self, img):
        if isinstance(img, list):
            # Uncondition
            return torch.zeros((1, 1, 768), device=self.mapper.weight.device)

        if self.augment:  # self.augment is False when augmentation is disabled
            # Transforms require 0-1
            img = self.augment((img + 1) / 2)
            img = 2 * img - 1

        feat = self.loss_fn(img, crop=True)
        feat = self.mapper(feat.unsqueeze(1))
        return feat
    def encode(self, img):
        return self(img)


class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    def __init__(self,
                 version='openai/clip-vit-large-patch14',
                 device='cuda',
                 max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()

        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding='max_length',
            return_tensors='pt')
        tokens = batch_encoding['input_ids'].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


class ClipImageProjector(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    """

    def __init__(self,
                 version='openai/clip-vit-large-patch14',
                 max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.model = CLIPVisionModel.from_pretrained(version)
        self.model.train()
        self.max_length = max_length  # TODO: typical value?
        self.antialias = True
        self.mapper = torch.nn.Linear(1024, 768)
        self.register_buffer(
            'mean',
            torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
            persistent=False)
        self.register_buffer(
            'std',
            torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
            persistent=False)
        null_cond = self.get_null_cond(version, max_length)
        self.register_buffer('null_cond', null_cond)

    @torch.no_grad()
    def get_null_cond(self, version, max_length):
        device = self.mean.device
        embedder = FrozenCLIPEmbedder(
            version=version, device=device, max_length=max_length)
        null_cond = embedder([''])
        return null_cond

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(
            x, (224, 224),
            interpolation='bicubic',
            align_corners=True,
            antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        if isinstance(x, list):
            return self.null_cond
        # x is assumed to be in range [-1,1]
        x = self.preprocess(x)
        outputs = self.model(pixel_values=x)
        last_hidden_state = outputs.last_hidden_state
        last_hidden_state = self.mapper(last_hidden_state)
        return F.pad(
            last_hidden_state,
            [0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0])

    def encode(self, im):
        return self(im)


class ProjectedFrozenCLIPEmbedder(AbstractEncoder):

    def __init__(self,
                 version='openai/clip-vit-large-patch14',
                 device='cuda',
                 max_length=77):  # clip-vit-base-patch32
        super().__init__()
        self.embedder = FrozenCLIPEmbedder(
            version=version, device=device, max_length=max_length)
        self.projection = torch.nn.Linear(768, 768)

    def forward(self, text):
        z = self.embedder(text)
        return self.projection(z)

    def encode(self, text):
        return self(text)


class FrozenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """

    def __init__(
        self,
        model='ViT-L/14',
        jit=False,
        device='cpu',
        antialias=False,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        self.register_buffer(
            'mean',
            torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
            persistent=False)
        self.register_buffer(
            'std',
            torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
            persistent=False)

    def preprocess(self, x):
        # Expects inputs in the range -1, 1
        x = kornia.geometry.resize(
            x, (224, 224),
            interpolation='bicubic',
            align_corners=True,
            antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, 768, device=device)
        return self.model.encode_image(self.preprocess(x)).float()

    def encode(self, im):
        return self(im).unsqueeze(1)


class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
    """
    Uses the CLIP image encoder.
    Not actually frozen... If you want that set cond_stage_trainable=False in cfg
    """

    def __init__(
        self,
        model='ViT-L/14',
        jit=False,
        device='cpu',
        antialias=True,
        max_crops=5,
    ):
        super().__init__()
        self.model, _ = clip.load(name=model, device=device, jit=jit)
        # We don't use the text part so delete it
        del self.model.transformer
        self.antialias = antialias
        self.register_buffer(
            'mean',
            torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
            persistent=False)
        self.register_buffer(
            'std',
            torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
            persistent=False)
        self.max_crops = max_crops

    def preprocess(self, x):

        # Expects inputs in the range -1, 1
        randcrop = transforms.RandomResizedCrop(
            224, scale=(0.085, 1.0), ratio=(1, 1))
        max_crops = self.max_crops
        patches = []
        crops = [randcrop(x) for _ in range(max_crops)]
        patches.extend(crops)
        x = torch.cat(patches, dim=0)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x):
        # x is assumed to be in range [-1,1]
        if isinstance(x, list):
            # [""] denotes condition dropout for ucg
            device = self.model.visual.conv1.weight.device
            return torch.zeros(1, self.max_crops, 768, device=device)
        batch_tokens = []
        for im in x:
            patches = self.preprocess(im.unsqueeze(0))
            tokens = self.model.encode_image(patches).float()
            for t in tokens:
                if random.random() < 0.1:
                    t *= 0
            batch_tokens.append(tokens.unsqueeze(0))

        return torch.cat(batch_tokens, dim=0)

    def encode(self, im):
        return self(im)


class SpatialRescaler(nn.Module):

    def __init__(self,
                 n_stages=1,
                 method='bilinear',
                 multiplier=0.5,
                 in_channels=3,
                 out_channels=None,
                 bias=False):
        super().__init__()
        self.n_stages = n_stages
        assert self.n_stages >= 0
        assert method in [
            'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'
        ]
        self.multiplier = multiplier
        self.interpolator = partial(
            torch.nn.functional.interpolate, mode=method)
        self.remap_output = out_channels is not None
        if self.remap_output:
            print(
                f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
            )
            self.channel_mapper = nn.Conv2d(
                in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        for stage in range(self.n_stages):
            x = self.interpolator(x, scale_factor=self.multiplier)

        if self.remap_output:
            x = self.channel_mapper(x)
        return x

    def encode(self, x):
        return self(x)


class LowScaleEncoder(nn.Module):

    def __init__(self,
                 model_config,
                 linear_start,
                 linear_end,
                 timesteps=1000,
                 max_noise_level=250,
                 output_size=64,
                 scale_factor=1.0):
        super().__init__()
        self.max_noise_level = max_noise_level
        self.model = instantiate_from_config(model_config)
        self.augmentation_schedule = self.register_schedule(
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end)
        self.out_size = output_size
        self.scale_factor = scale_factor

    def register_schedule(self,
                          beta_schedule='linear',
                          timesteps=1000,
                          linear_start=1e-4,
                          linear_end=2e-2,
                          cosine_s=8e-3):
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[
            0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod - 1)))

    def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
                * x_start
                + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
                                      x_start.shape) * noise)
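    # Illustrative note, not part of the committed file: q_sample draws from
    # the forward-diffusion marginal
    #     q(x_t | x_0) = N(sqrt(alphabar_t) * x_0, (1 - alphabar_t) * I),
    # i.e. x_t = sqrt(alphabar_t) * x_0 + sqrt(1 - alphabar_t) * eps,
    # using the sqrt_alphas_cumprod buffers registered above.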
|
||||
def forward(self, x):
|
||||
z = self.model.encode(x).sample()
|
||||
z = z * self.scale_factor
|
||||
noise_level = torch.randint(
|
||||
0, self.max_noise_level, (x.shape[0], ), device=x.device).long()
|
||||
z = self.q_sample(z, noise_level)
|
||||
if self.out_size is not None:
|
||||
z = torch.nn.functional.interpolate(
|
||||
z, size=self.out_size,
|
||||
mode='nearest') # TODO: experiment with mode
|
||||
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
z = z / self.scale_factor
|
||||
return self.model.decode(z)
|
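# A minimal sketch of the module above (model_config is an assumed first-stage
# autoencoder config): forward() noises the encoded latent with
# q(z_t | z_0) = sqrt(a_bar_t) * z_0 + sqrt(1 - a_bar_t) * eps for a random
# t < max_noise_level, then resizes it to out_size.
#   enc = LowScaleEncoder(model_config, linear_start=1e-4, linear_end=2e-2)
#   z_noisy, noise_level = enc(torch.randn(2, 3, 256, 256))
#   # z_noisy: noised latent at out_size; noise_level: (2,) long tensor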
||||
1010    modelscope/models/cv/image_view_transform/ldm/openaimodel.py    Normal file (diff suppressed because it is too large)
349    modelscope/models/cv/image_view_transform/ldm/plms.py    Executable file
@@ -0,0 +1,349 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .sampling_util import norm_thresholding
|
||||
from .util_diffusion import (make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps, noise_like)
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
|
||||
def __init__(self, model, schedule='linear', **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device('cuda'):
|
||||
attr = attr.to(torch.device('cuda'))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
def to_torch(x):
|
||||
return x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
alp_1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
|
||||
alp_2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
alp_1 * alp_2)
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(
|
||||
0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
print(f'Running PLMS Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b, ),
|
||||
time_range[min(i + 1,
|
||||
len(time_range) - 1)],
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat([
|
||||
unconditional_conditioning[k][i], c[k][i]
|
||||
]) for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat(
|
||||
[unconditional_conditioning[k], c[k]])
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in,
|
||||
c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||
**corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
if use_original_steps:
|
||||
alphas_prev = self.model.alphas_cumprod_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
alphas_prev = self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1),
|
||||
alphas_prev[index],
|
||||
device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2]
|
||||
- 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
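# The branches above are the standard Adams-Bashforth multistep coefficients:
# with k stored eps estimates, e_t_prime is
#   k=1: (3*e_t - e_{t-1}) / 2
#   k=2: (23*e_t - 16*e_{t-1} + 5*e_{t-2}) / 12
#   k=3: (55*e_t - 59*e_{t-1} + 37*e_{t-2} - 9*e_{t-3}) / 24
# A minimal sampling sketch (model and latent shape are assumptions):
#   sampler = PLMSSampler(model)
#   samples, inter = sampler.sample(S=50, batch_size=4, shape=(4, 32, 32),
#                                   conditioning=cond, verbose=False)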
||||
51    modelscope/models/cv/image_view_transform/ldm/sampling_util.py    Executable file
@@ -0,0 +1,51 @@
|
||||
import numpy as np
|
||||
import torch
from einops import rearrange
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
|
||||
)
|
||||
return x[(..., ) + (None, ) * dims_to_append]
|
||||
|
||||
|
||||
def renorm_thresholding(x0, value):
|
||||
# renorm
|
||||
pred_max = x0.max()
|
||||
pred_min = x0.min()
|
||||
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
|
||||
pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
|
||||
|
||||
s = torch.quantile(
|
||||
rearrange(pred_x0, 'b ... -> b (...)').abs(), value, dim=-1)
|
||||
s.clamp_(min=1.0)
|
||||
s = s.view(-1, *((1, ) * (pred_x0.ndim - 1)))
|
||||
|
||||
# clip by threshold
|
||||
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
|
||||
|
||||
# temporary hack: numpy on cpu
|
||||
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(),
|
||||
s.cpu().numpy()) / s.cpu().numpy()
|
||||
pred_x0 = torch.tensor(pred_x0).to(x0.device)
|
||||
|
||||
# re-renorm to original range
|
||||
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
|
||||
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
|
||||
return pred_x0
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(
|
||||
x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
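# A minimal sketch of the two variants above (values are illustrative):
#   x0 = torch.randn(2, 4, 32, 32) * 3.0
#   y1 = norm_thresholding(x0, 1.0)          # rescales by per-sample RMS norm
#   y2 = spatial_norm_thresholding(x0, 1.0)  # rescales by per-pixel channel norm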
||||
308    modelscope/models/cv/image_view_transform/ldm/util_diffusion.py    Normal file
@@ -0,0 +1,308 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
|
||||
from ..util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule,
|
||||
n_timestep,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3):
|
||||
if schedule == 'linear':
|
||||
betas = (
|
||||
torch.linspace(
|
||||
linear_start**0.5,
|
||||
linear_end**0.5,
|
||||
n_timestep,
|
||||
dtype=torch.float64)**2)
|
||||
|
||||
elif schedule == 'cosine':
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
|
||||
+ cosine_s)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == 'sqrt_linear':
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == 'sqrt':
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
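# A quick sanity check of the linear schedule (defaults match the callers above):
#   betas = make_beta_schedule('linear', 1000, linear_start=1e-4, linear_end=2e-2)
#   betas.shape                   # (1000,), increasing from ~1e-4 to 2e-2
#   np.cumprod(1. - betas)[-1]    # alpha_bar at t=T, close to 0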
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method,
|
||||
num_ddim_timesteps,
|
||||
num_ddpm_timesteps,
|
||||
verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
|
||||
num_ddim_timesteps))**2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
||||
)
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
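# Example with the 'uniform' method: the stride is c = 1000 // 50 = 20, so the
# shifted timesteps come out as 1, 21, 41, ..., 981:
#   steps = make_ddim_timesteps('uniform', 50, 1000, verbose=False)
#   steps[0], steps[-1]   # (1, 981)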
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums,
|
||||
ddim_timesteps,
|
||||
eta,
|
||||
verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]]
|
||||
+ alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||
alpha_1 = (1 - alphas_prev) / (1 - alphas)
|
||||
alpha_2 = (1 - alphas / alphas_prev)
|
||||
sigmas = eta * np.sqrt(alpha_1 * alpha_2)
|
||||
if verbose:
|
||||
print(
|
||||
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
|
||||
)
|
||||
print(
|
||||
f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
|
||||
)
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
|
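# A minimal sketch: gather one scalar per batch element from a schedule buffer
# and reshape it so it broadcasts against an image batch:
#   a = torch.linspace(0., 1., 1000)           # e.g. a cumprod buffer
#   t = torch.tensor([0, 999])
#   extract_into_tensor(a, t, (2, 3, 64, 64))  # shape (2, 1, 1, 1)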
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [
|
||||
x.detach().requires_grad_(True) for x in ctx.input_tensors
|
||||
]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
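# A minimal usage sketch of checkpoint() above (layer and input are
# illustrative): activations are recomputed in backward instead of stored.
#   layer = nn.Linear(16, 16)
#   x = torch.randn(4, 16, requires_grad=True)
#   y = checkpoint(layer, (x, ), layer.parameters(), flag=True)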
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
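# A shape sketch: half the channels carry cos terms and half sin terms, with
# frequencies geometrically spaced from 1 down to 1/max_period:
#   t = torch.arange(4)
#   emb = timestep_embedding(t, dim=128)  # shape (4, 128)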
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(
|
||||
c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
|
||||
def repeat_noise():
|
||||
return torch.randn((1, *shape[1:]),
|
||||
device=device).repeat(shape[0],
|
||||
*((1, ) * (len(shape) - 1)))
|
||||
|
||||
def noise():
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
return repeat_noise() if repeat else noise()
|
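# A minimal sketch: with repeat=True a single noise sample is tiled across the
# batch, so every batch element receives identical noise:
#   n = noise_like((4, 3, 8, 8), torch.device('cpu'), repeat=True)
#   torch.equal(n[0], n[1])   # True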
||||
680    modelscope/models/cv/image_view_transform/ldm/x_transformer.py    Normal file
@@ -0,0 +1,680 @@
|
||||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
||||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, reduce, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates',
|
||||
['pre_softmax_attn', 'post_softmax_attn'])
|
||||
|
||||
LayerIntermediates = namedtuple('Intermediates',
|
||||
['hiddens', 'attn_intermediates'])
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
self.init_()
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.emb.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.arange(x.shape[1], device=x.device)
|
||||
return self.emb(n)[None, :, :]
|
||||
|
||||
|
||||
class FixedPositionalEmbedding(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x, seq_dim=1, offset=0):
|
||||
t = torch.arange(
|
||||
x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
||||
return emb[None, :, :]
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def always(val):
|
||||
|
||||
def inner(*args, **kwargs):
|
||||
return val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def not_equals(val):
|
||||
|
||||
def inner(x):
|
||||
return x != val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def equals(val):
|
||||
|
||||
def inner(x):
|
||||
return x == val
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def max_neg_value(tensor):
|
||||
return -torch.finfo(tensor.dtype).max
|
||||
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val, )
|
||||
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(
|
||||
partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(
|
||||
map(lambda x: (x[0][len(prefix):], x[1]),
|
||||
tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
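# Example (hypothetical kwargs): split off the 'attn_'-prefixed keys with the
# prefix stripped, leaving the rest untouched:
#   groupby_prefix_and_trim('attn_', {'attn_dropout': 0.1, 'ff_mult': 4})
#   # -> ({'dropout': 0.1}, {'ff_mult': 4})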
||||
|
||||
|
||||
# classes
|
||||
class Scale(nn.Module):
|
||||
|
||||
def __init__(self, value, fn):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.value, *rest)
|
||||
|
||||
|
||||
class Rezero(nn.Module):
|
||||
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.g = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.g, *rest)
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim**-0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, dim, eps=1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim**-0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
|
||||
def forward(self, x, residual):
|
||||
return x + residual
|
||||
|
||||
|
||||
class GRUGating(nn.Module):
|
||||
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gru = nn.GRUCell(dim, dim)
|
||||
|
||||
def forward(self, x, residual):
|
||||
gated_output = self.gru(
|
||||
rearrange(x, 'b n d -> (b n) d'),
|
||||
rearrange(residual, 'b n d -> (b n) d'))
|
||||
|
||||
return gated_output.reshape_as(x)
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(
|
||||
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# attention.
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.,
|
||||
on_attn=False):
|
||||
super().__init__()
|
||||
if use_entmax15:
|
||||
raise NotImplementedError(
|
||||
'Check out entmax activation instead of softmax activation!')
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.causal = causal
|
||||
self.mask = mask
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# talking heads
|
||||
self.talking_heads = talking_heads
|
||||
if talking_heads:
|
||||
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
# explicit topk sparse attention
|
||||
self.sparse_topk = sparse_topk
|
||||
|
||||
# entmax
|
||||
|
||||
self.attn_fn = F.softmax
|
||||
|
||||
# add memory key / values
|
||||
self.num_mem_kv = num_mem_kv
|
||||
if num_mem_kv > 0:
|
||||
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
|
||||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
self.to_out = nn.Sequential(nn.Linear(
|
||||
inner_dim, dim
|
||||
* 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None):
|
||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
||||
kv_input = default(context, x)
|
||||
|
||||
q_input = x
|
||||
k_input = kv_input
|
||||
v_input = kv_input
|
||||
|
||||
if exists(mem):
|
||||
k_input = torch.cat((mem, k_input), dim=-2)
|
||||
v_input = torch.cat((mem, v_input), dim=-2)
|
||||
|
||||
if exists(sinusoidal_emb):
|
||||
# in shortformer, the query would start at a position offset depending on the past cached memory
|
||||
offset = k_input.shape[-2] - q_input.shape[-2]
|
||||
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
||||
k_input = k_input + sinusoidal_emb(k_input)
|
||||
|
||||
q = self.to_q(q_input)
|
||||
k = self.to_k(k_input)
|
||||
v = self.to_v(v_input)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
input_mask = None
|
||||
if any(map(exists, (mask, context_mask))):
|
||||
q_mask = default(mask, lambda: torch.ones(
|
||||
(b, n), device=device).bool())
|
||||
k_mask = q_mask if not exists(context) else context_mask
|
||||
k_mask = default(
|
||||
k_mask, lambda: torch.ones(
|
||||
(b, k.shape[-2]), device=device).bool())
|
||||
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
||||
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
||||
input_mask = q_mask * k_mask
|
||||
|
||||
if self.num_mem_kv > 0:
|
||||
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b),
|
||||
(self.mem_k, self.mem_v))
|
||||
k = torch.cat((mem_k, k), dim=-2)
|
||||
v = torch.cat((mem_v, v), dim=-2)
|
||||
if exists(input_mask):
|
||||
input_mask = F.pad(
|
||||
input_mask, (self.num_mem_kv, 0), value=True)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = max_neg_value(dots)
|
||||
|
||||
if exists(prev_attn):
|
||||
dots = dots + prev_attn
|
||||
|
||||
pre_softmax_attn = dots
|
||||
|
||||
if talking_heads:
|
||||
dots = einsum('b h i j, h k -> b k i j', dots,
|
||||
self.pre_softmax_proj).contiguous()
|
||||
|
||||
if exists(rel_pos):
|
||||
dots = rel_pos(dots)
|
||||
|
||||
if exists(input_mask):
|
||||
dots.masked_fill_(~input_mask, mask_value)
|
||||
del input_mask
|
||||
|
||||
if self.causal:
|
||||
i, j = dots.shape[-2:]
|
||||
r = torch.arange(i, device=device)
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(
|
||||
r, 'j -> () () () j')
|
||||
mask = F.pad(mask, (j - i, 0), value=False)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
||||
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
||||
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
||||
mask = dots < vk
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = self.attn_fn(dots, dim=-1)
|
||||
post_softmax_attn = attn
|
||||
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if talking_heads:
|
||||
attn = einsum('b h i j, h k -> b k i j', attn,
|
||||
self.post_softmax_proj).contiguous()
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
intermediates = Intermediates(
|
||||
pre_softmax_attn=pre_softmax_attn,
|
||||
post_softmax_attn=post_softmax_attn)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
||||
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
||||
|
||||
attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.has_pos_emb = position_infused_attn
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(
|
||||
dim) if position_infused_attn else None
|
||||
self.rotary_pos_emb = always(None)
|
||||
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
|
||||
self.rel_pos = None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.residual_attn = residual_attn
|
||||
self.cross_residual_attn = cross_residual_attn
|
||||
|
||||
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
||||
norm_class = RMSNorm if use_rmsnorm else norm_class
|
||||
norm_fn = partial(norm_class, dim)
|
||||
|
||||
norm_fn = nn.Identity if use_rezero else norm_fn
|
||||
branch_fn = Rezero if use_rezero else None
|
||||
|
||||
if cross_attend and not only_cross:
|
||||
default_block = ('a', 'c', 'f')
|
||||
elif cross_attend and only_cross:
|
||||
default_block = ('c', 'f')
|
||||
else:
|
||||
default_block = ('a', 'f')
|
||||
|
||||
if macaron:
|
||||
default_block = ('f', ) + default_block
|
||||
|
||||
if exists(custom_layers):
|
||||
layer_types = custom_layers
|
||||
elif exists(par_ratio):
|
||||
par_depth = depth * len(default_block)
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(
|
||||
default_block
|
||||
) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f', ) * (
|
||||
par_width - len(default_block))
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f', ) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a', ) * sandwich_coef + default_block * (
|
||||
depth - sandwich_coef) + ('f', ) * sandwich_coef
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
self.layer_types = layer_types
|
||||
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
if layer_type == 'a':
|
||||
layer = Attention(
|
||||
dim, heads=heads, causal=causal, **attn_kwargs)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
layer = FeedForward(dim, **ff_kwargs)
|
||||
layer = layer if not macaron else Scale(0.5, layer)
|
||||
else:
|
||||
raise Exception(f'invalid layer type {layer_type}')
|
||||
|
||||
if isinstance(layer, Attention) and exists(branch_fn):
|
||||
layer = branch_fn(layer)
|
||||
|
||||
if gate_residual:
|
||||
residual_fn = GRUGating(dim)
|
||||
else:
|
||||
residual_fn = Residual()
|
||||
|
||||
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False):
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
prev_attn = None
|
||||
prev_cross_attn = None
|
||||
|
||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
|
||||
zip(self.layer_types, self.layers)):
|
||||
is_last = ind == (len(self.layers) - 1)
|
||||
|
||||
if layer_type == 'a':
|
||||
hiddens.append(x)
|
||||
layer_mem = mems.pop(0)
|
||||
|
||||
residual = x
|
||||
|
||||
if self.pre_norm:
|
||||
x = norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(
|
||||
x,
|
||||
mask=mask,
|
||||
sinusoidal_emb=self.pia_pos_emb,
|
||||
rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn,
|
||||
mem=layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(
|
||||
x,
|
||||
context=context,
|
||||
mask=mask,
|
||||
context_mask=context_mask,
|
||||
prev_attn=prev_cross_attn)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
x = residual_fn(out, residual)
|
||||
|
||||
if layer_type in ('a', 'c'):
|
||||
intermediates.append(inter)
|
||||
|
||||
if layer_type == 'a' and self.residual_attn:
|
||||
prev_attn = inter.pre_softmax_attn
|
||||
elif layer_type == 'c' and self.cross_residual_attn:
|
||||
prev_cross_attn = inter.pre_softmax_attn
|
||||
|
||||
if not self.pre_norm and not is_last:
|
||||
x = norm(x)
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens, attn_intermediates=intermediates)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(AttentionLayers):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
||||
super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.,
|
||||
emb_dropout=0.,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True):
|
||||
super().__init__()
|
||||
assert isinstance(
|
||||
attn_layers, AttentionLayers
|
||||
), 'attention layers must be one of Encoder or Decoder'
|
||||
|
||||
dim = attn_layers.dim
|
||||
emb_dim = default(emb_dim, dim)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_mem_len = max_mem_len
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = nn.Linear(emb_dim,
|
||||
dim) if emb_dim != dim else nn.Identity()
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = nn.Linear(
|
||||
dim, num_tokens
|
||||
) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(
|
||||
torch.randn(num_memory_tokens, dim))
|
||||
|
||||
# let funnel encoder know number of memory tokens, if specified
|
||||
if hasattr(attn_layers, 'num_memory_tokens'):
|
||||
attn_layers.num_memory_tokens = num_memory_tokens
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
**kwargs):
|
||||
b, num_mem = x.shape[0], self.num_memory_tokens
|
||||
x = self.token_emb(x)
|
||||
x += self.pos_emb(x)
|
||||
x = self.emb_dropout(x)
|
||||
|
||||
x = self.project_emb(x)
|
||||
|
||||
if num_mem > 0:
|
||||
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
||||
x = torch.cat((mem, x), dim=1)
|
||||
|
||||
# auto-handle masking after appending memory tokens
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (num_mem, 0), value=True)
|
||||
|
||||
x, intermediates = self.attn_layers(
|
||||
x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
||||
x = self.norm(x)
|
||||
|
||||
mem, x = x[:, :num_mem], x[:, num_mem:]
|
||||
|
||||
out = self.to_logits(x) if not return_embeddings else x
|
||||
|
||||
if return_mems:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(
|
||||
map(lambda pair: torch.cat(pair, dim=-2), zip(
|
||||
mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(
|
||||
map(lambda t: t[..., -self.max_mem_len:, :].detach(),
|
||||
new_mems))
|
||||
return out, new_mems
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(
|
||||
map(lambda t: t.post_softmax_attn,
|
||||
intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
|
||||
return out
|
||||
297    modelscope/models/cv/image_view_transform/util.py    Executable file
@@ -0,0 +1,297 @@
|
||||
import importlib
|
||||
import os
|
||||
import time
|
||||
from inspect import isfunction
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torchvision
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from torch import optim
|
||||
|
||||
|
||||
def pil_rectangle_crop(im):
|
||||
width, height = im.size # Get dimensions
|
||||
|
||||
if width <= height:
|
||||
left = 0
|
||||
right = width
|
||||
top = (height - width) / 2
|
||||
bottom = (height + width) / 2
|
||||
else:
|
||||
|
||||
top = 0
|
||||
bottom = height
|
||||
left = (width - height) / 2
|
||||
right = (width + height) / 2
|
||||
|
||||
# Crop the center of the image
|
||||
im = im.crop((left, top, right, bottom))
|
||||
return im
|
||||
|
||||
|
||||
def add_margin(pil_img, color, size=256):
|
||||
width, height = pil_img.size
|
||||
result = Image.new(pil_img.mode, (size, size), color)
|
||||
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
|
||||
return result
|
||||
|
||||
|
||||
# def create_carvekit_interface():
|
||||
# # Check doc strings for more information
|
||||
# interface = HiInterface(
|
||||
# object_type='object', # Can be "object" or "hairs-like".
|
||||
# batch_size_seg=5,
|
||||
# batch_size_matting=1,
|
||||
# device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
# seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
|
||||
# matting_mask_size=2048,
|
||||
# trimap_prob_threshold=231,
|
||||
# trimap_dilation=30,
|
||||
# trimap_erosion_iters=5,
|
||||
# fp16=False)
|
||||
|
||||
# return interface
|
||||
|
||||
|
||||
def load_and_preprocess(interface, input_im):
|
||||
'''
|
||||
:param input_im (PIL Image).
|
||||
:return image (H, W, 3) uint8 array in [0, 255].
|
||||
'''
|
||||
# See https://github.com/Ir1d/image-background-remove-tool
|
||||
image = input_im.convert('RGB')
|
||||
|
||||
image_without_background = interface([image])[0]
|
||||
image_without_background = np.array(image_without_background)
|
||||
est_seg = image_without_background > 127
|
||||
image = np.array(image)
|
||||
foreground = est_seg[:, :, -1].astype(np.bool_)
|
||||
image[~foreground] = [255., 255., 255.]
|
||||
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
|
||||
image = image[y:y + h, x:x + w, :]
|
||||
image = PIL.Image.fromarray(np.array(image))
|
||||
|
||||
# resize image so its long edge is at most 200, then pad to 256
|
||||
image.thumbnail([200, 200], Image.Resampling.LANCZOS)
|
||||
image = add_margin(image, (255, 255, 255), size=256)
|
||||
image = np.array(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new('RGB', wh, color='white')
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = '\n'.join(xc[bi][start:start + nc]
|
||||
for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill='black', font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Can't encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if 'target' not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == '__is_unconditional__':
|
||||
return None
|
||||
raise KeyError('Expected key `target` to instantiate.')
|
||||
return get_obj_from_str(config['target'])(**config.get('params', dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit('.', 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
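# A minimal sketch (the target path is illustrative): a config is a dict with a
# dotted 'target' class path plus optional 'params' forwarded to the constructor:
#   cfg = {'target': 'torch.nn.Conv2d',
#          'params': {'in_channels': 3, 'out_channels': 8, 'kernel_size': 1}}
#   layer = instantiate_from_config(cfg)  # -> Conv2d(3, 8, kernel_size=(1, 1))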
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self,
|
||||
params,
|
||||
lr=1.e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1.e-8,
|
||||
weight_decay=1.e-2,
|
||||
amsgrad=False,
|
||||
ema_decay=0.9999,
|
||||
ema_power=1.,
|
||||
param_names=()):
|
||||
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError('Invalid learning rate: {}'.format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError('Invalid epsilon value: {}'.format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError('Invalid beta parameter at index 0: {}'.format(
|
||||
betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError('Invalid beta parameter at index 1: {}'.format(
|
||||
betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(
|
||||
'Invalid weight_decay value: {}'.format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
ema_decay=ema_decay,
|
||||
ema_power=ema_power,
|
||||
param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
'AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
|
||||
for param, ema_param in zip(params_with_grad,
|
||||
ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(
|
||||
param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
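# A minimal usage sketch (model is illustrative): alongside the AdamW update,
# the optimizer keeps an EMA copy of every parameter in state['param_exp_avg'].
#   model = nn.Linear(8, 8)
#   opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.9999)
#   loss = model(torch.randn(2, 8)).pow(2).mean()
#   loss.backward(); opt.step()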
||||
@@ -36,6 +36,7 @@ class OCRDetection(TorchModel):
|
||||
self.return_polygon = cfgs.model.inference_kwargs.return_polygon
|
||||
self.backbone = cfgs.model.backbone
|
||||
self.detector = None
|
||||
self.onnx_export = False
|
||||
if self.backbone == 'resnet50':
|
||||
self.detector = VLPTModel()
|
||||
elif self.backbone == 'resnet18':
|
||||
@@ -62,11 +63,20 @@ class OCRDetection(TorchModel):
            org_shape (`List`): image original shape,
                value is [height, width].
        """
        pred = self.detector(input['img'])
        if type(input) is dict:
            pred = self.detector(input['img'])
        else:
            # for onnx convert
            input = {'img': input, 'org_shape': [800, 800]}
            pred = self.detector(input['img'])
        return {'results': pred, 'org_shape': input['org_shape']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        pred = inputs['results'][0]

        if self.onnx_export:
            return pred

        height, width = inputs['org_shape']
        segmentation = pred > self.thresh
        if self.return_polygon:
@@ -164,15 +164,17 @@ def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
    return boxes, scores


def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height, is_numpy=False):
    """
    _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
    """

    assert _bitmap.size(0) == 1
    bitmap = _bitmap.cpu().numpy()[0]
    pred = pred.cpu().detach().numpy()[0]
    if is_numpy:
        bitmap = _bitmap[0]
        pred = pred[0]
    else:
        bitmap = _bitmap.cpu().numpy()[0]
        pred = pred.cpu().detach().numpy()[0]
    height, width = bitmap.shape
    boxes = []
    scores = []
@@ -109,8 +109,8 @@ class OCRRecognition(TorchModel):
        with open(dict_path, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            cnt = 1
            # ConvNextViT model start from index=2
            if self.do_chunking:
            # ConvNextViT and LightweightEdge model start from index=2
            if cfgs.model.recognizer == 'ConvNextViT' or cfgs.model.recognizer == 'LightweightEdge':
                cnt += 1
            for line in lines:
                line = line.strip('\n')
@@ -2,6 +2,7 @@

from collections import OrderedDict

import torch
import torch.nn as nn

from .nas_block import plnas_linear_mix_se

@@ -16,27 +17,20 @@ class LightweightEdge(nn.Module):

    def __init__(self):
        super(LightweightEdge, self).__init__()
        self.FeatureExtraction = plnas_linear_mix_se(3, 123)
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
            (None, 1))  # Transform final (imgH/16-1) -> 1
        self.dropout = nn.Dropout(0.3)
        self.Prediction = nn.Sequential(
            OrderedDict([
                ('fc1', nn.Linear(123, 120)),
                ('bn', nn.BatchNorm1d(120)),
                ('fc2', nn.Linear(120, 7642)),
            ]))
        self.our_nas_model = plnas_linear_mix_se(1, 128)
        self.embed_dim = 128
        self.head = nn.Linear(self.embed_dim, 7644)

    def forward(self, input):
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(
            visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)
        visual_feature = self.dropout(visual_feature)
        prediction = self.Prediction.fc1(visual_feature.contiguous())
        b, t, c = prediction.shape
        prediction = self.Prediction.bn(prediction.view(b * t,
                                                        c)).view(b, t, c)
        prediction = self.Prediction.fc2(prediction)

        # RGB2GRAY
        input = input[:, 0:1, :, :] * 0.2989 + input[:, 1:2, :, :] * 0.5870 \
            + input[:, 2:3, :, :] * 0.1140
        x = self.our_nas_model(input)
        x = torch.squeeze(x, 2)
        x = torch.transpose(x, 1, 2)
        b, s, e = x.size()
        x = x.reshape(b * s, e)
        prediction = self.head(x).view(b, s, -1)
        return prediction
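The channel weights in the RGB2GRAY step are the standard BT.601 luma coefficients; an equivalent vectorized form on a dummy batch (shapes are illustrative, not from the model):

import torch
img = torch.rand(2, 3, 32, 100)
w = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1)
gray = (img * w).sum(1, keepdim=True)  # (2, 1, 32, 100), same result as the slice sum above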
@@ -126,7 +126,7 @@ def plnas_linear_mix_se(input_channel, output_channel):

    stride_stages = [(2, 2), (2, 1), (2, 1), (2, 1)]
    n_cell_stages = [5, 5, 5, 5]
    width_stages = [32, 64, 96, 123]
    width_stages = [32, 64, 96, 128]
    conv_op_ids = [
        2, 23, 24, 26, 2, 2, 11, 27, 27, 27, 27, 2, 0, 2, 16, 10, 27, 2, 2, 2,
        22, 10, 27, 3
@@ -26,7 +26,9 @@ class YOLOXONNX(object):
        options.intra_op_num_threads = 1
        options.inter_op_num_threads = 1
        self.ort_session = ort.InferenceSession(
            self.onnx_path, sess_options=options)
            self.onnx_path,
            sess_options=options,
            providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.with_p6 = False
        self.multi_detect = multi_detect
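onnxruntime walks the providers list in order and falls back to the next entry when one is unavailable, so the session above also loads on CPU-only hosts; a minimal sketch (the model path is a placeholder):

import onnxruntime as ort
sess = ort.InferenceSession(
    'model.onnx',
    providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
print(sess.get_providers())  # reports which providers were actually enabled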
@@ -16,6 +16,7 @@ from modelscope.models.cv.s2net_panorama_depth_estimation.networks.util_helper i
    compute_hp_info, render_depth_map)
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger

logger = get_logger()

@@ -35,8 +36,7 @@ class PanoramaDepthEstimation(TorchModel):
        """
        super().__init__(model_dir, **kwargs)
        if 'device' in kwargs:
            self.device = torch.device('cuda' if 'gpu' in
                                       kwargs['device'] else 'cpu')
            self.device = create_device(kwargs['device'])
        else:
            self.device = torch.device(
                'cuda' if torch.cuda.is_available() else 'cpu')
@@ -9,7 +9,8 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_

from .common import Upsample, resize


@@ -11,7 +11,8 @@ from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from torch import nn


@@ -8,7 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_

from .common import resize
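The three hunks above all track the same upstream change: timm moved these helpers from timm.models.layers to timm.layers. If both timm generations had to be supported, a hedged compatibility shim could look like:

try:
    from timm.layers.drop import drop_path
    from timm.layers.weight_init import trunc_normal_
except ImportError:  # older timm still exposes them under timm.models.layers
    from timm.models.layers import drop_path, trunc_normal_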
modelscope/models/cv/text_texture_generation/Tex2Texture.py (new file, 660 lines)
@@ -0,0 +1,660 @@
# Copyright © Alibaba, Inc. and its affiliates.
# The implementation here is modified based on StableDiffusionControlNetInpaintPipeline,
# originally Apache 2.0 License and publicly available at
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

import os
from typing import Any, Callable, Dict, List, Optional, Union

import cv2
import numpy as np
import PIL
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
                       EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                       StableDiffusionControlNetImg2ImgPipeline,
                       StableDiffusionControlNetInpaintPipeline,
                       StableDiffusionInpaintPipeline, StableDiffusionPipeline,
                       UNet2DConditionModel)
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import (deprecate, is_accelerate_available,
                             is_accelerate_version, is_compiled_module,
                             logging, randn_tensor, replace_example_docstring)
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj

from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.text_texture_generation.lib2.camera import *
from modelscope.models.cv.text_texture_generation.lib2.init_view import *
from modelscope.models.cv.text_texture_generation.utils import *
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
        >>> from diffusers.utils import load_image
        >>> import numpy as np
        >>> import torch

        >>> init_image = load_image(image_path)
        >>> init_image = init_image.resize((512, 512))
        >>> generator = torch.Generator(device="cpu").manual_seed(1)
        >>> mask_image = load_image(mask_path)
        >>> mask_image = mask_image.resize((512, 512))
        >>> def make_inpaint_condition(image, image_mask):
        ...     image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        ...     image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        ...     assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
        ...     image[image_mask > 0.5] = -1.0  # set as masked pixel
        ...     image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
        ...     image = torch.from_numpy(image)
        ...     return image
        >>> control_image = make_inpaint_condition(init_image, mask_image)
        >>> controlnet = ControlNetModel.from_pretrained(
        ...     "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
        ... )
        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
        ...     "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
        ... )
        >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        >>> pipe.enable_model_cpu_offload()
        >>> image = pipe(
        ...     "a handsome man with ray-ban sunglasses",
        ...     num_inference_steps=20,
        ...     generator=generator,
        ...     eta=1.0,
        ...     image=init_image,
        ...     mask_image=mask_image,
        ...     control_image=control_image,
        ... ).images[0]
        ```
"""


@MODELS.register_module(
    Tasks.text_texture_generation, module_name=Models.text_texture_generation)
class Tex2Texture(TorchModel):

    def __init__(self, model_dir, *args, **kwargs):
        """The Tex2Texture is modified based on TEXTure and Text2Tex, publicly available at
        https://github.com/TEXTurePaper/TEXTurePaper &
        https://github.com/daveredrum/Text2Tex
        Args:
            model_dir: the root directory of the model files
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            logger.info('Use GPU: {}'.format(self.device))
        else:
            print('no gpu available')
            exit()

        model_path = model_dir + '/base_model/'
        controlmodel_path = model_dir + '/control_model/'
        inpaintmodel_path = model_dir + '/inpaint_model/'
        torch_dtype = kwargs.get('torch_dtype', torch.float16)
        self.controlnet = ControlNetModel.from_pretrained(
            controlmodel_path, torch_dtype=torch_dtype).to(self.device)
        self.inpaintmodel = StableDiffusionInpaintPipeline.from_pretrained(
            inpaintmodel_path,
            torch_dtype=torch_dtype,
        ).to(self.device)
        self.pipe = StableDiffusionControlinpaintPipeline.from_pretrained(
            model_path, controlnet=self.controlnet,
            torch_dtype=torch_dtype).to(self.device)
        logger.info('model load over')
    def init_mesh(self, mesh_path):
        verts, faces, aux = load_obj(mesh_path, device=self.device)
        mesh = load_objs_as_meshes([mesh_path], device=self.device)
        return mesh, verts, faces, aux

    def normalize_mesh(self, mesh):
        bbox = mesh.get_bounding_boxes()
        num_verts = mesh.verts_packed().shape[0]
        mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
        mesh = mesh.offset_verts(-mesh_center)
        lens = bbox[0, :, 1] - bbox[0, :, 0]
        max_len = lens.max()
        scale = 0.9 / max_len
        scale = scale.unsqueeze(0).repeat(num_verts)
        # mesh.scale_verts_(scale)
        new_mesh = mesh.scale_verts(scale)
        return new_mesh.verts_packed(), new_mesh, mesh_center, scale

    def save_normalized_obj(self, verts, faces, aux, path='normalized.obj'):
        print('=> saving normalized mesh file...')
        obj_path = path
        save_obj(
            obj_path,
            verts=verts,
            faces=faces.verts_idx,
            decimal_places=5,
            verts_uvs=aux.verts_uvs,
            faces_uvs=faces.textures_idx,
            texture_map=aux.texture_images[list(aux.texture_images.keys())[0]])

    def mesh_normalized(self, mesh_path, save_path='normalized.obj'):
        mesh, verts, faces, aux = self.init_mesh(mesh_path)
        verts, mesh, mesh_center, scale = self.normalize_mesh(mesh)
        self.save_normalized_obj(verts, faces, aux, save_path)
        return mesh, verts, faces, aux, mesh_center, scale
def prepare_mask_and_masked_image(image,
                                  mask,
                                  height,
                                  width,
                                  return_image=False):
    if image is None:
        raise ValueError('`image` input cannot be undefined.')

    if mask is None:
        raise ValueError('`mask_image` input cannot be undefined.')

    if isinstance(image, torch.Tensor):
        if not isinstance(mask, torch.Tensor):
            raise TypeError(
                f'`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not'
            )

        # Batch single image
        if image.ndim == 3:
            assert image.shape[
                0] == 3, 'Image outside a batch should be of shape (3, H, W)'
            image = image.unsqueeze(0)

        # Batch and add channel dim for single mask
        if mask.ndim == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)

        # Batch single mask or add channel dim
        if mask.ndim == 3:
            # Single batched mask, no channel dim or single mask not batched but channel dim
            if mask.shape[0] == 1:
                mask = mask.unsqueeze(0)

            # Batched masks no channel dim
            else:
                mask = mask.unsqueeze(1)

        assert image.ndim == 4 and mask.ndim == 4, 'Image and Mask must have 4 dimensions'
        assert image.shape[-2:] == mask.shape[
            -2:], 'Image and Mask must have the same spatial dimensions'
        assert image.shape[0] == mask.shape[
            0], 'Image and Mask must have the same batch size'

        # Check image is in [-1, 1]
        if image.min() < -1 or image.max() > 1:
            raise ValueError('Image should be in [-1, 1] range')

        # Check mask is in [0, 1]
        if mask.min() < 0 or mask.max() > 1:
            raise ValueError('Mask should be in [0, 1] range')

        # Binarize mask
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1

        # Image as float32
        image = image.to(dtype=torch.float32)
    elif isinstance(mask, torch.Tensor):
        raise TypeError(
            f'`mask` is a torch.Tensor but `image` (type: {type(image)}) is not'
        )
    else:
        # preprocess image
        if isinstance(image, (PIL.Image.Image, np.ndarray)):
            image = [image]
        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t the passed height and width
            image = [
                i.resize((width, height), resample=PIL.Image.LANCZOS)
                for i in image
            ]
            image = [np.array(i.convert('RGB'))[None, :] for i in image]
            image = np.concatenate(image, axis=0)
        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
            image = np.concatenate([i[None, :] for i in image], axis=0)

        image = image.transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0

        # preprocess mask
        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
            mask = [mask]

        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
            mask = [
                i.resize((width, height), resample=PIL.Image.LANCZOS)
                for i in mask
            ]
            mask = np.concatenate(
                [np.array(m.convert('L'))[None, None, :] for m in mask],
                axis=0)
            mask = mask.astype(np.float32) / 255.0
        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)

        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask)

    masked_image = image * (mask < 0.5)

    # n.b. ensure backwards compatibility as old function does not return image
    if return_image:
        return mask, masked_image, image

    return mask, masked_image
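A minimal usage sketch for prepare_mask_and_masked_image; the file names here are hypothetical placeholders:

from PIL import Image

init_image = Image.open('object_view.png')  # RGB render of the current view
mask_image = Image.open('new_region.png')   # white where texture is missing

mask, masked_image, image = prepare_mask_and_masked_image(
    init_image, mask_image, height=512, width=512, return_image=True)
# image: (1, 3, 512, 512) float tensor in [-1, 1];
# mask: (1, 1, 512, 512), binarized to {0, 1};
# masked_image keeps pixel values only where mask < 0.5.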
class StableDiffusionControlinpaintPipeline(
        StableDiffusionControlNetInpaintPipeline):

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        image: Union[torch.Tensor, PIL.Image.Image] = None,
        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
        control_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray,
                             List[torch.FloatTensor], List[PIL.Image.Image],
                             List[np.ndarray], ] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 1.0,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor],
                                    None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
        guess_mode: bool = False,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
                    `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
                specified in init, images must be passed as a list such that each element of the list can be correctly
                batched for input to a single controlnet.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            strength (`float`, *optional*, defaults to 1.):
                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
                `strength`. The number of denoising steps depends on the amount of noise initially added. When
                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
                portion of the reference `image`.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
                to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
                corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
                than for [`~StableDiffusionControlNetPipeline.__call__`].
            guess_mode (`bool`, *optional*, defaults to `False`):
                In this mode, the ControlNet encoder will try its best to recognize the content of the input image even
                if you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height, width = self._default_height_width(height, width, image)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            control_image,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
            controlnet_conditioning_scale,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        controlnet = self.controlnet._orig_mod if is_compiled_module(
            self.controlnet) else self.controlnet

        if isinstance(controlnet, MultiControlNetModel) and isinstance(
                controlnet_conditioning_scale, float):
            controlnet_conditioning_scale = [controlnet_conditioning_scale
                                             ] * len(controlnet.nets)

        global_pool_conditions = (
            controlnet.config.global_pool_conditions if isinstance(
                controlnet, ControlNetModel) else
            controlnet.nets[0].config.global_pool_conditions)
        guess_mode = guess_mode or global_pool_conditions

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get('scale', None)
            if cross_attention_kwargs is not None else None)
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare image
        if isinstance(controlnet, ControlNetModel):
            control_image = self.prepare_control_image(
                image=control_image,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=controlnet.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
                guess_mode=guess_mode,
            )
        elif isinstance(controlnet, MultiControlNetModel):
            control_images = []

            for control_image_ in control_image:
                control_image_ = self.prepare_control_image(
                    image=control_image_,
                    width=width,
                    height=height,
                    batch_size=batch_size * num_images_per_prompt,
                    num_images_per_prompt=num_images_per_prompt,
                    device=device,
                    dtype=controlnet.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

                control_images.append(control_image_)

            control_image = control_images
        else:
            assert False

        # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
        mask, masked_image, init_image = prepare_mask_and_masked_image(
            image, mask_image, height, width, return_image=True)

        # 5. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps=num_inference_steps,
            strength=strength,
            device=device)
        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
        latent_timestep = timesteps[:1].repeat(batch_size
                                               * num_images_per_prompt)
        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
        is_strength_max = strength == 1.0

        # 6. Prepare latent variables
        num_channels_latents = self.vae.config.latent_channels
        num_channels_unet = self.unet.config.in_channels
        return_image_latents = num_channels_unet == 4
        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        # 7. Prepare mask latent variables
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # controlnet(s) inference
                if guess_mode and do_classifier_free_guidance:
                    # Infer ControlNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, t)
                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    controlnet_prompt_embeds = prompt_embeds

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond=control_image,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=guess_mode,
                    return_dict=False,
                )

                if guess_mode and do_classifier_free_guidance:
                    # Inferred ControlNet only for the conditional batch.
                    # To apply the output of ControlNet to both the unconditional and conditional batches,
                    # add 0 to the unconditional batch to keep it unchanged.
                    down_block_res_samples = [
                        torch.cat([torch.zeros_like(d), d])
                        for d in down_block_res_samples
                    ]
                    mid_block_res_sample = torch.cat([
                        torch.zeros_like(mid_block_res_sample),
                        mid_block_res_sample
                    ])

                # predict the noise residual
                if num_channels_unet == 9:
                    latent_model_input = torch.cat(
                        [latent_model_input, mask, masked_image_latents],
                        dim=1)

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    **extra_step_kwargs,
                    return_dict=False)[0]

                if num_channels_unet == 4:
                    init_latents_proper = image_latents[:1]
                    init_mask = mask[:1]

                    if i < len(timesteps) - 1:
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor([t]))

                    latents = (1 - init_mask
                               ) * init_latents_proper + init_mask * latents

                if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order
                                               == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # If we do sequential model offloading, let's offload unet and controlnet
        # manually for max memory savings
        if hasattr(
                self,
                'final_offload_hook') and self.final_offload_hook is not None:
            self.unet.to('cpu')
            self.controlnet.to('cpu')
            torch.cuda.empty_cache()

        if not output_type == 'latent':
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize)

        if hasattr(
                self,
                'final_offload_hook') and self.final_offload_hook is not None:
            self.final_offload_hook.offload()
        if not return_dict:
            return (image, has_nsfw_concept)
        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept)
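A hedged sketch of how the subclassed pipeline is driven (the prompt and the prepared images are placeholders; model is a Tex2Texture instance whose __init__ built self.pipe):

result = model.pipe(
    'a weathered bronze statue',
    image=init_image,
    mask_image=mask_image,
    control_image=control_image,
    num_inference_steps=30,
    strength=1.0,
    guidance_scale=7.5,
).images[0]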
modelscope/models/cv/text_texture_generation/lib2/camera.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# customized
import sys

import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from sklearn.metrics.pairwise import cosine_similarity

from modelscope.models.cv.text_texture_generation.lib2.init_view import \
    VIEWPOINTS

sys.path.append('.')

# ---------------- UTILS ----------------------


def degree_to_radian(d):
    return d * np.pi / 180


def radian_to_degree(r):
    return 180 * r / np.pi


def xyz_to_polar(xyz):
    """ assume y-axis is the up axis """

    x, y, z = xyz

    theta = 180 * np.arccos(z) / np.pi
    phi = 180 * np.arccos(y) / np.pi

    return theta, phi


def polar_to_xyz(theta, phi, dist):
    """ assume y-axis is the up axis """

    theta = degree_to_radian(theta)
    phi = degree_to_radian(phi)

    x = np.sin(phi) * np.sin(theta) * dist
    y = np.cos(phi) * dist
    z = np.sin(phi) * np.cos(theta) * dist

    return [x, y, z]
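A quick sanity check of the conversion (values chosen for illustration): theta=0 with phi=90 places the point on the +z axis, and theta=90 rotates it onto +x:

print(polar_to_xyz(0, 90, 1.0))   # ~[0.0, 0.0, 1.0]
print(polar_to_xyz(90, 90, 2.0))  # ~[2.0, 0.0, 0.0]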
# ---------------- VIEWPOINTS ----------------------


def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
    """ return the binary mask of viewpoints to be filtered """

    filter_mask = [0 for _ in viewpoints.keys()]
    for i, v in viewpoints.items():
        x_v, y_v, z_v = polar_to_xyz(v['azim'], 90 - v['elev'], v['dist'])

        for _, pv in pre_viewpoints.items():
            x_pv, y_pv, z_pv = polar_to_xyz(pv['azim'], 90 - pv['elev'],
                                            pv['dist'])
            sim = cosine_similarity(
                np.array([[x_v, y_v, z_v]]), np.array([[x_pv, y_pv, z_pv]]))[0,
                                                                             0]

            if sim > 0.9:
                filter_mask[i] = 1

    return filter_mask
def init_viewpoints(init_dist,
                    init_elev,
                    init_azim,
                    use_principle=True,
                    use_shapenet=False,
                    use_objaverse=False):
    sample_space = 12
    (dist_list, elev_list, azim_list,
     sector_list) = init_predefined_viewpoints(sample_space, init_dist,
                                               init_elev)

    # punishments for views -> in case always selecting the same view
    view_punishments = [1 for _ in range(len(dist_list))]

    if use_principle:
        (dist_list, elev_list, azim_list, sector_list,
         view_punishments) = init_principle_viewpoints(dist_list, elev_list,
                                                       azim_list, sector_list,
                                                       view_punishments,
                                                       use_shapenet,
                                                       use_objaverse)
    azim_list = [v - init_azim for v in azim_list]
    elev_list = [v - init_elev for v in elev_list]

    return dist_list, elev_list, azim_list, sector_list, view_punishments


def init_principle_viewpoints(dist_list,
                              elev_list,
                              azim_list,
                              sector_list,
                              view_punishments,
                              use_shapenet=False,
                              use_objaverse=False):
    if use_shapenet:
        key = 'shapenet'

        pre_elev_list = [v for v in VIEWPOINTS[key]['elev']]
        pre_azim_list = [v for v in VIEWPOINTS[key]['azim']]
        pre_sector_list = [v for v in VIEWPOINTS[key]['sector']]

        num_principle = 10
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]

    elif use_objaverse:
        key = 'objaverse'

        pre_elev_list = [v for v in VIEWPOINTS[key]['elev']]
        pre_azim_list = [v for v in VIEWPOINTS[key]['azim']]
        pre_sector_list = [v for v in VIEWPOINTS[key]['sector']]

        num_principle = 10
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]
    else:
        num_principle = 12
        pre_elev_list = [v for v in VIEWPOINTS[num_principle]['elev']]
        pre_azim_list = [v for v in VIEWPOINTS[num_principle]['azim']]
        pre_sector_list = [v for v in VIEWPOINTS[num_principle]['sector']]
        pre_dist_list = [dist_list[0] for _ in range(num_principle)]
        pre_view_punishments = [0 for _ in range(num_principle)]

    dist_list = pre_dist_list + dist_list
    elev_list = pre_elev_list + elev_list
    azim_list = pre_azim_list + azim_list
    sector_list = pre_sector_list + sector_list
    view_punishments = pre_view_punishments + view_punishments

    return dist_list, elev_list, azim_list, sector_list, view_punishments


def init_predefined_viewpoints(sample_space, init_dist, init_elev):
    viewpoints = VIEWPOINTS[sample_space]

    assert sample_space == len(viewpoints['sector'])

    dist_list = [init_dist
                 for _ in range(sample_space)]  # always the same dist
    elev_list = [viewpoints['elev'][i] for i in range(sample_space)]
    azim_list = [viewpoints['azim'][i] for i in range(sample_space)]
    sector_list = [viewpoints['sector'][i] for i in range(sample_space)]

    return dist_list, elev_list, azim_list, sector_list
def init_camera(dist, elev, azim, image_size, device):
    R, T = look_at_view_transform(dist, elev, azim)
    image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
    T[0][2] = dist
    cameras = PerspectiveCameras(
        R=R, T=T, device=device, image_size=image_size)

    return cameras
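A hedged sketch wiring the helpers together: take the first predefined viewpoint and build its camera (image_size=768 and the 'cuda' device are illustrative values, not from the source):

dist_list, elev_list, azim_list, sector_list, _ = init_viewpoints(1.0, 0, 0)
camera = init_camera(dist_list[0], elev_list[0], azim_list[0], 768, 'cuda')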
modelscope/models/cv/text_texture_generation/lib2/init_view.py (new file, 229 lines)
@@ -0,0 +1,229 @@
PALETTE = {
    0: [255, 255, 255],  # white - background
    1: [204, 50, 50],  # red - old
    2: [231, 180, 22],  # yellow - update
    3: [45, 201, 55]  # green - new
}

QUAD_WEIGHTS = {
    0: 0,  # background
    1: 0.1,  # old
    2: 0.5,  # update
    3: 1  # new
}

VIEWPOINTS = {
    2: {
        'azim': [0, 180],
        'elev': [0, 0],
        'sector': ['front', 'back']
    },
    4: {
        'azim': [45, 315, 135, 225],
        'elev': [0, 0, 0, 0],
        'sector': ['front right', 'front left', 'back right', 'back left']
    },
    6: {
        'azim': [0, 90, 270, 0, 180, 0],
        'elev': [0, 0, 0, 90, 0, -90],
        'sector': ['front', 'right', 'left', 'top', 'back', 'bottom']
    },
    10: {
        'azim': [270, 315, 225, 0, 180, 45, 135, 90, 270, 270],
        'elev': [15, 15, 15, 15, 15, 15, 15, 15, 90, -90],
        'sector': [
            'front', 'front right', 'front left', 'right', 'left',
            'back right', 'back left', 'back', 'top', 'bottom'
        ]
    },
    12: {
        'azim': [0, 45, 315, 135, 225, 180, 45, 315, 90, 270, 90, 270],
        'elev': [0, 0, 0, 0, 0, 0, 30, 30, 15, 15, 90, -90],
        'sector': [
            'front', 'front right', 'front left', 'back right', 'back left',
            'back', 'front right', 'front left', 'right', 'left', 'top',
            'bottom'
        ]
    },
    36: {
        'azim': [
            45, 315, 135, 225, 0, 45, 315, 90, 270, 135, 225, 180, 0, 45,
            315, 90, 270, 135, 225, 180, 22.5, 337.5, 67.5, 292.5, 112.5,
            247.5, 157.5, 202.5, 22.5, 337.5, 67.5, 292.5, 112.5, 247.5,
            157.5, 202.5
        ],
        'elev': [
            0, 0, 0, 0, 30, 30, 30, 30, 30, 30, 30, 30, 60, 60, 60, 60, 60,
            60, 60, 60, 15, 15, 15, 15, 15, 15, 15, 15, 45, 45, 45, 45, 45,
            45, 45, 45
        ],
        'sector': [
            'front right', 'front left', 'back right', 'back left', 'front',
            'front right', 'front left', 'right', 'left', 'back right',
            'back left', 'back', 'top front', 'top right', 'top left',
            'top right', 'top left', 'top right', 'top left', 'top back',
            'front right', 'front left', 'front right', 'front left',
            'back right', 'back left', 'back right', 'back left',
            'front right', 'front left', 'front right', 'front left',
            'back right', 'back left', 'back right', 'back left'
        ]
    }
}
modelscope/models/cv/text_texture_generation/lib2/projection.py (new file, 655 lines)
@@ -0,0 +1,655 @@
import os
import random
# customized
import sys
from typing import NamedTuple, Sequence

import cv2
import numpy as np
import torch
from PIL import Image
from pytorch3d.io import save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (AmbientLights, MeshRasterizer,
                                MeshRendererWithFragments,
                                RasterizationSettings, SoftPhongShader,
                                TexturesUV)
from pytorch3d.renderer.mesh.shader import ShaderBase
from torchvision import transforms
from tqdm import tqdm

from modelscope.models.cv.text_texture_generation.lib2.camera import \
    init_camera
from modelscope.models.cv.text_texture_generation.lib2.init_view import *
from modelscope.models.cv.text_texture_generation.lib2.viusel import (
    visualize_outputs, visualize_quad_mask)

sys.path.append('.')
class BlendParams(NamedTuple):
    sigma: float = 1e-4
    gamma: float = 1e-4
    background_color: Sequence = (1, 1, 1)


class FlatTexelShader(ShaderBase):

    def __init__(self,
                 device='cpu',
                 cameras=None,
                 lights=None,
                 materials=None,
                 blend_params=None):
        super().__init__(device, cameras, lights, materials, blend_params)

    def forward(self, fragments, meshes, **_kwargs):
        texels = meshes.sample_textures(fragments)
        texels[(fragments.pix_to_face == -1), :] = 0
        return texels.squeeze(-2)


def init_soft_phong_shader(camera, blend_params, device):
    lights = AmbientLights(device=device)
    shader = SoftPhongShader(
        cameras=camera,
        lights=lights,
        device=device,
        blend_params=blend_params)

    return shader


def init_flat_texel_shader(camera, device):
    shader = FlatTexelShader(cameras=camera, device=device)
    return shader


def init_renderer(camera, shader, image_size, faces_per_pixel):
    raster_settings = RasterizationSettings(
        image_size=image_size, faces_per_pixel=faces_per_pixel)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=camera, raster_settings=raster_settings),
        shader=shader)

    return renderer
@torch.no_grad()
def render(mesh, renderer, pad_value=10):

    def phong_normal_shading(meshes, fragments) -> torch.Tensor:
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        pixel_normals = interpolate_face_attributes(fragments.pix_to_face,
                                                    fragments.bary_coords,
                                                    faces_normals)

        return pixel_normals

    def similarity_shading(meshes, fragments):
        faces = meshes.faces_packed()  # (F, 3)
        vertex_normals = meshes.verts_normals_packed()  # (V, 3)
        faces_normals = vertex_normals[faces]
        vertices = meshes.verts_packed()  # (V, 3)
        face_positions = vertices[faces]
        view_directions = torch.nn.functional.normalize(
            (renderer.shader.cameras.get_camera_center().reshape(1, 1, 3)
             - face_positions),
            p=2,
            dim=2)
        cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals,
                                                             view_directions)
        pixel_similarity = interpolate_face_attributes(
            fragments.pix_to_face, fragments.bary_coords,
            cosine_similarity.unsqueeze(-1))

        return pixel_similarity

    def get_relative_depth_map(fragments, pad_value=pad_value):
        absolute_depth = fragments.zbuf[..., 0]  # B, H, W
        no_depth = -1

        depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(
        ), absolute_depth[absolute_depth != no_depth].max()
        target_min, target_max = 50, 255

        depth_value = absolute_depth[absolute_depth != no_depth]
        depth_value = depth_max - depth_value  # reverse values

        depth_value /= (depth_max - depth_min)
        depth_value = depth_value * (target_max - target_min) + target_min

        relative_depth = absolute_depth.clone()
        relative_depth[absolute_depth != no_depth] = depth_value
        relative_depth[absolute_depth == no_depth] = pad_value

        return relative_depth

    images, fragments = renderer(mesh)
    normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
    similarity_maps = similarity_shading(mesh, fragments).squeeze(-2)  # -1 - 1
    depth_maps = get_relative_depth_map(fragments)

    # normalize similarity mask to 0 - 1
    similarity_maps = torch.abs(similarity_maps)  # 0 - 1

    # HACK erode, eliminate isolated dots
    non_zero_similarity = (similarity_maps > 0).float()
    non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(
        np.uint8)[0]
    non_zero_similarity = cv2.erode(
        non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
    non_zero_similarity = torch.from_numpy(non_zero_similarity).to(
        similarity_maps.device).unsqueeze(0) / 255.
    similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps
    return images, normal_maps, similarity_maps, depth_maps, fragments
@torch.no_grad()
def check_visible_faces(mesh, fragments):
    pix_to_face = fragments.pix_to_face
    visible_map = pix_to_face.unique()  # (num_visible_faces)
    return visible_map


def get_all_4_locations(values_y, values_x):
    y_0 = torch.floor(values_y)
    y_1 = torch.ceil(values_y)
    x_0 = torch.floor(values_x)
    x_1 = torch.ceil(values_x)

    return torch.cat([y_0, y_0, y_1, y_1],
                     0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()
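get_all_4_locations splats one fractional UV coordinate onto the four surrounding integer texels (all floor/ceil combinations); a tiny check with an illustrative input:

ys, xs = get_all_4_locations(torch.tensor([1.3]), torch.tensor([4.7]))
print(ys.tolist(), xs.tolist())  # [1, 1, 2, 2] [4, 5, 4, 5]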
def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image,
                      device):
    """
    compose quad mask:
        -> 0: background
        -> 1: old
        -> 2: update
        -> 3: new
    """

    new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
    update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
    old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)

    all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor

    quad_mask_tensor = torch.zeros_like(all_mask_tensor)
    quad_mask_tensor[old_mask_tensor == 1] = 1
    quad_mask_tensor[update_mask_tensor == 1] = 2
    quad_mask_tensor[new_mask_tensor == 1] = 3

    return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
def compute_view_heat(similarity_tensor, quad_mask_tensor):
    num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
    heat = 0
    for idx in QUAD_WEIGHTS:
        heat += (quad_mask_tensor
                 == idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels

    return heat
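The view heat is a weighted pixel fraction using QUAD_WEIGHTS from init_view.py (background 0, old 0.1, update 0.5, new 1), so views that would paint many new texels score highest. A quad mask that is half background and half new scores 0.5 (illustrative; the similarity argument is unused in the body):

quad = torch.tensor([[0, 3], [0, 3]])
print(compute_view_heat(None, quad))  # tensor(0.5000)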
def select_viewpoint(selected_view_ids,
                     view_punishments,
                     mode,
                     dist_list,
                     elev_list,
                     azim_list,
                     sector_list,
                     view_idx,
                     similarity_texture_cache,
                     exist_texture,
                     mesh,
                     faces,
                     verts_uvs,
                     image_size,
                     faces_per_pixel,
                     init_image_dir,
                     mask_image_dir,
                     normal_map_dir,
                     depth_map_dir,
                     similarity_map_dir,
                     device,
                     use_principle=False):
    if mode == 'sequential':

        num_views = len(dist_list)

        dist = dist_list[view_idx % num_views]
        elev = elev_list[view_idx % num_views]
        azim = azim_list[view_idx % num_views]
        sector = sector_list[view_idx % num_views]

        selected_view_ids.append(view_idx % num_views)

    elif mode == 'heuristic':

        if use_principle and view_idx < 6:

            selected_view_idx = view_idx

        else:

            selected_view_idx = None
            max_heat = 0

            print('=> selecting next view...')
            view_heat_list = []
            for sample_idx in tqdm(range(len(dist_list))):

                view_heat, *_ = render_one_view_and_build_masks(
                    dist_list[sample_idx], elev_list[sample_idx],
                    azim_list[sample_idx], sample_idx, sample_idx,
                    view_punishments, similarity_texture_cache, exist_texture,
                    mesh, faces, verts_uvs, image_size, faces_per_pixel,
                    init_image_dir, mask_image_dir, normal_map_dir,
                    depth_map_dir, similarity_map_dir, device)

                if view_heat > max_heat:
                    selected_view_idx = sample_idx
                    max_heat = view_heat

                view_heat_list.append(view_heat.item())

            print(view_heat_list)
            print('select view {} with heat {}'.format(selected_view_idx,
                                                       max_heat))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

        view_punishments[selected_view_idx] *= 0.01

    elif mode == 'random':

        selected_view_idx = random.choice(range(len(dist_list)))

        dist = dist_list[selected_view_idx]
        elev = elev_list[selected_view_idx]
        azim = azim_list[selected_view_idx]
        sector = sector_list[selected_view_idx]

        selected_view_ids.append(selected_view_idx)

    else:
        raise NotImplementedError()

    return dist, elev, azim, sector, selected_view_ids, view_punishments
@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs, cameras, reference_image,
                           faces_per_pixel, image_size, uv_size, device):
    # construct pixel UVs
    renderer_scaled = init_renderer(
        cameras,
        shader=init_soft_phong_shader(
            camera=cameras, blend_params=BlendParams(), device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel)
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(fragments_scaled.pix_to_face,
                                            fragments_scaled.bary_coords,
                                            faces_verts_uvs)  # NxHsxWsxKx2
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs[:, 0].reshape(-1) * (uv_size - 1))

    K = faces_per_pixel

    texture_values = torch.from_numpy(
        np.array(reference_image.resize(
            (image_size, image_size)))).float() / 255.
    texture_values = texture_values.to(device).unsqueeze(0).expand(
        [4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    # texture
    texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
    texture_tensor[texture_locations_y,
                   texture_locations_x, :] = texture_values.reshape(-1, 3)

    return texture_tensor[:, :, 0]
@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
                         renderer,
                         exist_texture,
                         similarity_texture_cache,
                         target_value,
                         device,
                         image_size,
                         smooth_mask=False,
                         view_threshold=0.01):
    mesh, faces, verts_uvs = mesh_stuff
    mask_mesh = mesh.clone()  # NOTE in-place operation - DANGER!!!

    # visible mask => the whole region
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(
        -1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=torch.ones_like(exist_texture_expand),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode='nearest')
    # visible_mask_tensor, *_ = render(mask_mesh, renderer)
    visible_mask_tensor, _, similarity_map_tensor, *_ = render(
        mask_mesh, renderer)
    # faces that are too rotated away from the viewpoint will be treated as invisible
    valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
    visible_mask_tensor *= valid_mask_tensor

    # nonexist mask <=> new mask
    exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(
        -1, -1, -1, 3).to(device)
    mask_mesh.textures = TexturesUV(
        maps=1 - exist_texture_expand,
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode='nearest')
    new_mask_tensor, *_ = render(mask_mesh, renderer)
    new_mask_tensor *= valid_mask_tensor

    # exist mask => visible mask - new mask
    exist_mask_tensor = visible_mask_tensor - new_mask_tensor
    exist_mask_tensor[
        exist_mask_tensor < 0] = 0  # NOTE dilate can lead to overflow

    # all update mask
    mask_mesh.textures = TexturesUV(
        maps=(
            similarity_texture_cache.argmax(0) == target_value
            # # only consider the views that have already appeared before
            # similarity_texture_cache[0:target_value+1].argmax(0) == target_value
        ).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
        faces_uvs=faces.textures_idx[None, ...],
        verts_uvs=verts_uvs[None, ...],
        sampling_mode='nearest')
    all_update_mask_tensor, *_ = render(mask_mesh, renderer)

    # current update mask => intersection between all update mask and exist mask
    update_mask_tensor = exist_mask_tensor * all_update_mask_tensor

    # keep mask => exist mask - update mask
    old_mask_tensor = exist_mask_tensor - update_mask_tensor

    # convert
    new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
    new_mask = transforms.ToPILImage()(new_mask).convert('L')

    update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
    update_mask = transforms.ToPILImage()(update_mask).convert('L')

    old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
    old_mask = transforms.ToPILImage()(old_mask).convert('L')

    exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
    exist_mask = transforms.ToPILImage()(exist_mask).convert('L')

    return new_mask, update_mask, old_mask, exist_mask

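A toy check of the mask algebra above, with hand-picked binary values (illustrative only): visible pixels split into new (never textured), update (textured but best seen from this view) and old (textured, keep as-is), and the three partitions sum back to the visible mask.

import torch

visible = torch.tensor([1., 1., 1., 0.])
new = torch.tensor([1., 0., 0., 0.])          # not textured yet
exist = (visible - new).clamp(min=0)          # textured and visible
all_update = torch.tensor([0., 1., 0., 0.])   # best viewed from this view
update = exist * all_update                   # re-generate these pixels
old = exist - update                          # keep these pixels
assert torch.equal(new + update + old, visible)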
@torch.no_grad()
def render_one_view(mesh, dist, elev, azim, image_size, faces_per_pixel,
                    device):
    # render the view
    # print(image_size)
    cameras = init_camera(dist, elev, azim, image_size, device)
    renderer = init_renderer(
        cameras,
        shader=init_soft_phong_shader(
            camera=cameras, blend_params=BlendParams(), device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel)

    init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(
        mesh, renderer)
    # print(init_images_tensor.shape, torch.max(init_images_tensor), torch.min(init_images_tensor))
    # debug dump of the rendered view; note that cv2 interprets channels as
    # BGR, so 'img.png' is written with red and blue swapped
    cv2.imwrite('img.png',
                (np.array(init_images_tensor.squeeze(0)[:, :, :3].cpu())
                 * 255).astype(np.uint8))
    return (cameras, renderer, init_images_tensor, normal_maps_tensor,
            similarity_tensor, depth_maps_tensor, fragments)

@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
                                                 dist_list, elev_list,
                                                 azim_list, image_size,
                                                 image_size_scaled, uv_size,
                                                 faces_per_pixel, device):
    num_candidate_views = len(dist_list)
    similarity_texture_cache = torch.zeros(num_candidate_views, uv_size,
                                           uv_size).to(device)

    print('=> building similarity texture cache for all views...')
    for i in tqdm(range(num_candidate_views)):
        cameras, _, _, _, similarity_tensor, _, _ = render_one_view(
            mesh, dist_list[i], elev_list[i], azim_list[i], image_size,
            faces_per_pixel, device)

        similarity_texture_cache[i] = build_backproject_mask(
            mesh, faces, verts_uvs, cameras,
            transforms.ToPILImage()(similarity_tensor[0, :, :,
                                                      0]).convert('RGB'),
            faces_per_pixel, image_size_scaled, uv_size, device)

    return similarity_texture_cache

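The cache has shape `(num_candidate_views, uv_size, uv_size)`: entry `[i, y, x]` is the view-space similarity of texel `(y, x)` as seen from candidate view `i`. The per-texel best view consumed by `build_diffusion_mask` is then just an argmax over the first dimension:

# best_view[y, x] == index of the candidate view that sees texel (y, x)
# most frontally; build_diffusion_mask compares this against target_value.
best_view = similarity_texture_cache.argmax(0)  # (uv_size, uv_size)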
@torch.no_grad()
def render_one_view_and_build_masks(dist,
                                    elev,
                                    azim,
                                    selected_view_idx,
                                    view_idx,
                                    view_punishments,
                                    similarity_texture_cache,
                                    exist_texture,
                                    mesh,
                                    faces,
                                    verts_uvs,
                                    image_size,
                                    faces_per_pixel,
                                    init_image_dir,
                                    mask_image_dir,
                                    normal_map_dir,
                                    depth_map_dir,
                                    similarity_map_dir,
                                    device,
                                    save_intermediate=False,
                                    smooth_mask=False,
                                    view_threshold=0.01):
    # render the view
    (cameras, renderer, init_images_tensor, normal_maps_tensor,
     similarity_tensor, depth_maps_tensor,
     fragments) = render_one_view(mesh, dist, elev, azim, image_size,
                                  faces_per_pixel, device)

    init_image = init_images_tensor[0].cpu()
    init_image = init_image.permute(2, 0, 1)
    init_image = transforms.ToPILImage()(init_image).convert('RGB')

    normal_map = normal_maps_tensor[0].cpu()
    normal_map = normal_map.permute(2, 0, 1)
    normal_map = transforms.ToPILImage()(normal_map).convert('RGB')

    depth_map = depth_maps_tensor[0].cpu().numpy()
    depth_map = Image.fromarray(depth_map).convert('L')

    similarity_map = similarity_tensor[0, :, :, 0].cpu()
    similarity_map = transforms.ToPILImage()(similarity_map).convert('L')

    flat_renderer = init_renderer(
        cameras,
        shader=init_flat_texel_shader(camera=cameras, device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel)
    new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
        (mesh, faces, verts_uvs),
        flat_renderer,
        exist_texture,
        similarity_texture_cache,
        selected_view_idx,
        device,
        image_size,
        smooth_mask=smooth_mask,
        view_threshold=view_threshold)
    # NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
    # it should match with `similarity_texture_cache`

    (old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor,
     quad_mask_tensor) = compose_quad_mask(new_mask_image, update_mask_image,
                                           old_mask_image, device)

    view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
    view_heat *= view_punishments[selected_view_idx]

    # save intermediate results
    if save_intermediate:
        init_image.save(
            os.path.join(init_image_dir, '{}.png'.format(view_idx)))
        normal_map.save(
            os.path.join(normal_map_dir, '{}.png'.format(view_idx)))
        depth_map.save(os.path.join(depth_map_dir, '{}.png'.format(view_idx)))
        similarity_map.save(
            os.path.join(similarity_map_dir, '{}.png'.format(view_idx)))

        new_mask_image.save(
            os.path.join(mask_image_dir, '{}_new.png'.format(view_idx)))
        update_mask_image.save(
            os.path.join(mask_image_dir, '{}_update.png'.format(view_idx)))
        old_mask_image.save(
            os.path.join(mask_image_dir, '{}_old.png'.format(view_idx)))
        exist_mask_image.save(
            os.path.join(mask_image_dir, '{}_exist.png'.format(view_idx)))

        visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx,
                            view_heat, device)

    return (view_heat, renderer, cameras, fragments, init_image, normal_map,
            depth_map, init_images_tensor, normal_maps_tensor,
            depth_maps_tensor, similarity_tensor, old_mask_image,
            update_mask_image, new_mask_image, old_mask_tensor,
            update_mask_tensor, new_mask_tensor, all_mask_tensor,
            quad_mask_tensor)

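`compose_quad_mask` and `compute_view_heat` are defined elsewhere in this module and not shown in this hunk. A minimal sketch of the contract the call site above implies; the label scheme (2 = update, 3 = new) is an assumption for illustration.

import torch

def compute_view_heat_sketch(similarity_tensor, quad_mask_tensor):
    # Accumulate view-space similarity over the pixels that would be painted
    # from this view (assumed labels: 2 = update, 3 = new), so candidate
    # views facing a lot of unpainted surface score higher.
    gen_mask = (quad_mask_tensor >= 2).float()
    return (similarity_tensor[..., 0] * gen_mask).sum()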
def save_full_obj(output_dir, obj_name, verts, faces, verts_uvs, faces_uvs,
                  projected_texture, device):
    print('=> saving OBJ file...')
    texture_map = transforms.ToTensor()(projected_texture).to(device)
    texture_map = texture_map.permute(1, 2, 0)
    obj_path = os.path.join(output_dir, obj_name)

    save_obj(
        obj_path,
        verts=verts,
        faces=faces,
        decimal_places=5,
        verts_uvs=verts_uvs,
        faces_uvs=faces_uvs,
        texture_map=texture_map)

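A hypothetical call site (all names are placeholders): export the painted mesh once texturing is done, using the same pytorch3d conventions as the functions above.

save_full_obj(
    output_dir='outputs',
    obj_name='textured_mesh.obj',
    verts=mesh.verts_packed(),
    faces=faces.verts_idx,
    verts_uvs=verts_uvs,
    faces_uvs=faces.textures_idx,
    projected_texture=final_texture,  # PIL image of the UV texture
    device=device)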
@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras, reference_image,
                           new_mask_image, update_mask_image, init_texture,
                           exist_texture, image_size, uv_size, faces_per_pixel,
                           device):
    # construct pixel UVs
    renderer_scaled = init_renderer(
        cameras,
        shader=init_soft_phong_shader(
            camera=cameras, blend_params=BlendParams(), device=device),
        image_size=image_size,
        faces_per_pixel=faces_per_pixel)
    fragments_scaled = renderer_scaled.rasterizer(mesh)

    # get UV coordinates for each pixel
    faces_verts_uvs = verts_uvs[faces.textures_idx]

    pixel_uvs = interpolate_face_attributes(fragments_scaled.pix_to_face,
                                            fragments_scaled.bary_coords,
                                            faces_verts_uvs)  # NxHsxWsxKx2
    # reshape to (K, H, W, 2); the shape lookups use the pre-assignment
    # pixel_uvs, which assumes a batch size of 1
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2,
                                  4).reshape(pixel_uvs.shape[-2],
                                             pixel_uvs.shape[1],
                                             pixel_uvs.shape[2], 2)

    # the update mask has to be on top of the diffusion mask
    new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(
        device).unsqueeze(-1)
    update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(
        device).unsqueeze(-1)

    project_mask_image_tensor = torch.logical_or(
        update_mask_image_tensor, new_mask_image_tensor).float()
    project_mask_image = project_mask_image_tensor * 255.
    project_mask_image = Image.fromarray(
        project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))

    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size), )
    # Image.Resampling.NEAREST
    # )
    project_mask_image_tensor_scaled = transforms.ToTensor()(
        project_mask_image_scaled).to(device)

    pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]

    texture_locations_y, texture_locations_x = get_all_4_locations(
        (1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
        pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1))

    K = pixel_uvs.shape[0]
    project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[
        :, None, :, :, None].repeat(1, 4, 1, 1, 3)

    texture_values = torch.from_numpy(
        np.array(reference_image.resize((image_size, image_size))))
    texture_values = texture_values.to(device).unsqueeze(0).expand(
        [4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])

    texture_values_masked = texture_values.reshape(
        -1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(
            -1, 3)

    # texture
    texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
    texture_tensor[texture_locations_y,
                   texture_locations_x, :] = texture_values_masked

    init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(
        np.uint8))

    # update texture cache
    exist_texture[texture_locations_y, texture_locations_x] = 1

    return init_texture, project_mask_image, exist_texture
268  modelscope/models/cv/text_texture_generation/lib2/viusel.py  (new file)
@@ -0,0 +1,268 @@
import os
import sys

import imageio.v2 as imageio
# visualization
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from modelscope.models.cv.text_texture_generation.lib2.camera import \
    polar_to_xyz
from modelscope.models.cv.text_texture_generation.lib2.init_view import *

matplotlib.use('Agg')

sys.path.append('.')


def visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_score,
                        device):
    quad_mask_tensor = quad_mask_tensor.unsqueeze(-1).repeat(1, 1, 1, 3)
    quad_mask_image_tensor = torch.zeros_like(quad_mask_tensor)

    for idx in PALETTE:
        selected = quad_mask_tensor[quad_mask_tensor == idx].reshape(-1, 3)
        selected = torch.FloatTensor(
            PALETTE[idx]).to(device).unsqueeze(0).repeat(selected.shape[0], 1)

        quad_mask_image_tensor[quad_mask_tensor == idx] = selected.reshape(-1)

    quad_mask_image_np = quad_mask_image_tensor[0].cpu().numpy().astype(
        np.uint8)
    quad_mask_image = Image.fromarray(quad_mask_image_np).convert('RGB')
    quad_mask_image.save(
        os.path.join(mask_image_dir,
                     '{}_quad_{:.5f}.png'.format(view_idx, view_score)))

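`PALETTE` comes in through the star import from `init_view` and is not shown in this diff. A plausible shape for it, for illustration only; the actual labels and colours live in `init_view`.

# label -> RGB colour of each quad-mask region (values are hypothetical)
PALETTE_SKETCH = {
    0: [0, 0, 0],        # background
    1: [255, 255, 255],  # keep (old)
    2: [0, 255, 0],      # update
    3: [255, 0, 0],      # new
}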
def visualize_outputs(output_dir, init_image_dir, mask_image_dir,
                      inpainted_image_dir, num_views):
    # subplot settings
    num_col = 3
    num_row = 1
    sus = 4

    summary_image_dir = os.path.join(output_dir, 'summary')
    os.makedirs(summary_image_dir, exist_ok=True)

    # graph settings
    print('=> visualizing results...')
    for view_idx in range(num_views):
        plt.switch_backend('agg')
        fig = plt.figure(dpi=100)
        fig.set_size_inches(sus * num_col, sus * (num_row + 1))
        fig.set_facecolor('white')

        # rendering
        plt.subplot2grid((num_row, num_col), (0, 0))
        plt.imshow(
            Image.open(
                os.path.join(init_image_dir, '{}.png'.format(view_idx))))
        plt.text(
            0,
            0,
            'Rendering',
            fontsize=16,
            color='black',
            backgroundcolor='white')
        plt.axis('off')

        # mask
        plt.subplot2grid((num_row, num_col), (0, 1))
        plt.imshow(
            Image.open(
                os.path.join(mask_image_dir,
                             '{}_project.png'.format(view_idx))))
        plt.text(
            0,
            0,
            'Project Mask',
            fontsize=16,
            color='black',
            backgroundcolor='white')
        plt.set_cmap(cm.Greys_r)
        plt.axis('off')

        # inpainted
        plt.subplot2grid((num_row, num_col), (0, 2))
        plt.imshow(
            Image.open(
                os.path.join(inpainted_image_dir, '{}.png'.format(view_idx))))
        plt.text(
            0,
            0,
            'Inpainted',
            fontsize=16,
            color='black',
            backgroundcolor='white')
        plt.axis('off')

        plt.savefig(
            os.path.join(summary_image_dir, '{}.png'.format(view_idx)),
            bbox_inches='tight')
        fig.clf()

    # generate GIF
    images = [
        imageio.imread(
            os.path.join(summary_image_dir, '{}.png'.format(view_idx)))
        for view_idx in range(num_views)
    ]
    imageio.mimsave(
        os.path.join(summary_image_dir, 'output.gif'), images, duration=1)

    print('=> done!')

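A hypothetical call (directory names are placeholders): one summary panel per view, then an animated GIF over all views.

visualize_outputs(
    output_dir='outputs',
    init_image_dir='outputs/rendering',
    mask_image_dir='outputs/mask',
    inpainted_image_dir='outputs/inpainted',
    num_views=36)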
def visualize_principle_viewpoints(output_dir, dist_list, elev_list,
                                   azim_list):
    theta_list = [e for e in azim_list]
    phi_list = [90 - e for e in elev_list]
    DIST = dist_list[0]

    xyz_list = [
        polar_to_xyz(theta, phi, DIST)
        for theta, phi in zip(theta_list, phi_list)
    ]

    xyz_np = np.array(xyz_list)
    color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)

    ax = plt.axes(projection='3d')
    SCALE = 0.8
    ax.set_xlim((-DIST, DIST))
    ax.set_ylim((-DIST, DIST))
    ax.set_zlim((-SCALE * DIST, SCALE * DIST))

    ax.scatter(
        xyz_np[:, 0],
        xyz_np[:, 2],
        xyz_np[:, 1],
        s=100,
        c=color_np,
        depthshade=True,
        label='Principle views')
    ax.scatter([0], [0], [0],
               c=[[1, 0, 0]],
               s=100,
               depthshade=True,
               label='Object center')

    # draw hemisphere
    # theta inclination angle
    # phi azimuthal angle
    n_theta = 50  # number of values for theta
    n_phi = 200  # number of values for phi
    r = DIST  # radius of sphere

    # theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
    theta, phi = np.mgrid[0.0:1 * np.pi:n_theta * 1j,
                          0.0:2.0 * np.pi:n_phi * 1j]

    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)

    ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)

    # Make the grid
    ax.quiver(
        xyz_np[:, 0],
        xyz_np[:, 2],
        xyz_np[:, 1],
        -xyz_np[:, 0],
        -xyz_np[:, 2],
        -xyz_np[:, 1],
        normalize=True,
        length=0.3)

    ax.set_xlabel('X Label')
    ax.set_ylabel('Z Label')
    ax.set_zlabel('Y Label')

    ax.view_init(30, 35)
    ax.legend()

    plt.show()

    plt.savefig(os.path.join(output_dir, 'principle_viewpoints.png'))


def visualize_refinement_viewpoints(output_dir, selected_view_ids, dist_list,
                                    elev_list, azim_list):
    theta_list = [azim_list[i] for i in selected_view_ids]
    phi_list = [90 - elev_list[i] for i in selected_view_ids]
    DIST = dist_list[0]

    xyz_list = [
        polar_to_xyz(theta, phi, DIST)
        for theta, phi in zip(theta_list, phi_list)
    ]

    xyz_np = np.array(xyz_list)
    color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)

    fig = plt.figure()
    ax = plt.axes(projection='3d')
    SCALE = 0.8
    ax.set_xlim((-DIST, DIST))
    ax.set_ylim((-DIST, DIST))
    ax.set_zlim((-SCALE * DIST, SCALE * DIST))

    ax.scatter(
        xyz_np[:, 0],
        xyz_np[:, 2],
        xyz_np[:, 1],
        c=color_np,
        depthshade=True,
        label='Refinement views')
    ax.scatter([0], [0], [0],
               c=[[1, 0, 0]],
               s=100,
               depthshade=True,
               label='Object center')

    # draw hemisphere
    # theta inclination angle
    # phi azimuthal angle
    n_theta = 50  # number of values for theta
    n_phi = 200  # number of values for phi
    r = DIST  # radius of sphere

    # theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
    theta, phi = np.mgrid[0.0:1 * np.pi:n_theta * 1j,
                          0.0:2.0 * np.pi:n_phi * 1j]

    x = r * np.sin(theta) * np.cos(phi)
    y = r * np.sin(theta) * np.sin(phi)
    z = r * np.cos(theta)

    ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)

    # Make the grid
    ax.quiver(
        xyz_np[:, 0],
        xyz_np[:, 2],
        xyz_np[:, 1],
        -xyz_np[:, 0],
        -xyz_np[:, 2],
        -xyz_np[:, 1],
        normalize=True,
        length=0.3)

    ax.set_xlabel('X Label')
    ax.set_ylabel('Z Label')
    ax.set_zlabel('Y Label')

    ax.view_init(30, 35)
    ax.legend()

    plt.show()

    plt.savefig(os.path.join(output_dir, 'refinement_viewpoints.png'))

    fig.clear()

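`polar_to_xyz` is imported from `lib2/camera` and is not shown in this diff. Below is a minimal sketch consistent with the call sites above, where `phi = 90 - elev` is measured from the pole and `theta` is the azimuth; the axis ordering is an assumption.

import numpy as np

def polar_to_xyz_sketch(theta_deg, phi_deg, r):
    # Spherical -> Cartesian; phi from the +Y pole matches phi = 90 - elev.
    theta, phi = np.deg2rad(theta_deg), np.deg2rad(phi_deg)
    x = r * np.sin(phi) * np.cos(theta)
    y = r * np.cos(phi)
    z = r * np.sin(phi) * np.sin(theta)
    return [x, y, z]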
91  modelscope/models/cv/text_texture_generation/utils.py  (new file)
@@ -0,0 +1,91 @@
# common utils
import os

import imageio.v2 as imageio
import torch
# pytorch3d
from pytorch3d.io import load_obj, load_objs_as_meshes
from pytorch3d.renderer import (AmbientLights, MeshRasterizer,
                                MeshRendererWithFragments, PerspectiveCameras,
                                RasterizationSettings, SoftPhongShader,
                                look_at_view_transform)
from torchvision import transforms
from tqdm import tqdm

IMAGE_SIZE = 768


def init_mesh(model_path, device):
    verts, faces, aux = load_obj(model_path, device=device)
    mesh = load_objs_as_meshes([model_path], device=device)
    return mesh, verts, faces, aux


def init_camera(num_views, dist, elev, azim, view_idx, device):
    interval = 360 // num_views
    azim = (azim + interval * view_idx) % 360
    R, T = look_at_view_transform(dist, elev, azim)
    T[0][2] = dist
    image_size = torch.tensor([IMAGE_SIZE, IMAGE_SIZE]).unsqueeze(0)
    focal_length = torch.tensor(2.0)
    cameras = PerspectiveCameras(
        focal_length=focal_length,
        R=R,
        T=T,
        device=device,
        image_size=image_size)
    return cameras, dist, elev, azim


def init_renderer(camera, device):
    raster_settings = RasterizationSettings(image_size=IMAGE_SIZE)
    lights = AmbientLights(device=device)
    renderer = MeshRendererWithFragments(
        rasterizer=MeshRasterizer(
            cameras=camera, raster_settings=raster_settings),
        shader=SoftPhongShader(cameras=camera, lights=lights, device=device))

    return renderer

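For example, with `num_views=72` each step advances the azimuth by 5 degrees, so `view_idx` sweeps one full turn around the object:

cameras, dist, elev, azim = init_camera(
    num_views=72, dist=1.8, elev=15, azim=0, view_idx=9, device='cpu')
assert azim == 45  # (0 + (360 // 72) * 9) % 360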
def generation_gif(mesh_path):
    num_views = 72
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda:0')
        torch.cuda.set_device(DEVICE)
    else:
        print('no gpu available')
        exit()
    output_dir = 'GIF-{}'.format(num_views)
    os.makedirs(output_dir, exist_ok=True)

    mesh, verts, faces, aux = init_mesh(mesh_path, DEVICE)

    # rendering
    print('=> rendering...')
    for view_idx in tqdm(range(num_views)):
        init_image_path = os.path.join(output_dir, '{}.png'.format(view_idx))
        dist = 1.8
        elev = 15
        azim = 0

        cameras, dist, elev, azim = init_camera(num_views, dist, elev, azim,
                                                view_idx, DEVICE)
        renderer = init_renderer(cameras, DEVICE)
        init_images_tensor, fragments = renderer(mesh)

        # save images
        init_image = init_images_tensor[0].cpu()
        init_image = init_image.permute(2, 0, 1)
        init_image = transforms.ToPILImage()(init_image).convert('RGB')
        init_image.save(init_image_path)

    # generate GIF
    images = [
        imageio.imread(os.path.join(output_dir, '{}.png'.format(v_id)))
        for v_id in range(num_views)  # was `args.num_views`; `args` is undefined here
    ]
    imageio.mimsave(
        os.path.join(output_dir, 'output.gif'), images, duration=0.1)
    imageio.mimsave(os.path.join(output_dir, 'output.mp4'), images, fps=25)
    print('=> done!')

0  modelscope/models/cv/text_to_head/__init__.py  (new file, empty)

55  modelscope/models/cv/text_to_head/text_to_head_model.py  (new file)
@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict

import cv2
import numpy as np
import torch
from diffusers import (ControlNetModel, DDIMScheduler,
                       StableDiffusionControlNetPipeline)
from diffusers.utils import load_image

from modelscope.models import MODELS, TorchModel


@MODELS.register_module('text-to-head', 'text_to_head')
class TextToHeadModel(TorchModel):

    def __init__(self, model_dir, *args, **kwargs):
        """The TextToHeadModel is implemented based on HRN, publicly available at
        https://github.com/youngLBW/HRN

        Args:
            model_dir: the root directory of the model files
        """
        super().__init__(model_dir, *args, **kwargs)

        self.model_dir = model_dir

        base_model_path = os.path.join(model_dir, 'base_model')
        controlnet_path = os.path.join(model_dir, 'control_net')

        controlnet = ControlNetModel.from_pretrained(
            controlnet_path, torch_dtype=torch.float16)
        self.face_gen_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            base_model_path, controlnet=controlnet, torch_dtype=torch.float16)
        self.face_gen_pipeline.scheduler = DDIMScheduler.from_config(
            self.face_gen_pipeline.scheduler.config)
        self.face_gen_pipeline.enable_model_cpu_offload()

        self.add_prompt = ', 4K, good looking face, epic realistic, Sony a7, sharp, ' \
                          'skin detail pores, soft light, uniform illumination'
        self.neg_prompt = 'ugly, cross eye, bangs, teeth, glasses, hat, dark, shadow'

        control_pose_path = os.path.join(self.model_dir, 'control_pose.jpg')
        self.control_pose = load_image(control_pose_path)

    def forward(self, input):
        prompt = input['text'] + self.add_prompt
        image = self.face_gen_pipeline(
            prompt,
            negative_prompt=self.neg_prompt,
            image=self.control_pose,
            num_inference_steps=20).images[0]  # PIL.Image

        return image
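A hypothetical usage sketch (the directory layout follows the paths read in `__init__`; names and prompt are placeholders):

model = TextToHeadModel(model_dir='/path/to/text_to_head')
head_image = model.forward({'text': 'a young woman with curly hair'})
head_image.save('head.png')  # PIL.Image returned by the pipeline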
@@ -314,7 +314,7 @@ class CLIP4Clip(CLIP4ClipPreTrainedModel):
             if key in clip_state_dict:
                 del clip_state_dict[key]

-        convert_weights(self.clip)
+        # convert_weights(self.clip)
         # <=== End of CLIP Encoders

         self.sim_header = 'seqTransf'
@@ -421,8 +421,10 @@ class Frame_Layer(nn.Module):
         tgt = self.norm1(tgt)
         memory = self.norm2(memory)
         mask_new = adaptive_mask(tgt.shape[0], memory.shape[0], ada_para=0.2)
+        if torch.cuda.is_available():
+            mask_new = mask_new.cuda()
         tgt2, atten_weights = self.multihead_attn(
-            tgt, memory, memory, attn_mask=mask_new.cuda())
+            tgt, memory, memory, attn_mask=mask_new)
         tgt = tgt + self.dropout1(tgt2)

         tgt = self.norm3(tgt)
@@ -14,7 +14,8 @@ class GenUnifiedTransformer(UnifiedTransformer):
         super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
                                                     generator)
         self.understand = config.BPETextField.understand

+        if torch.cuda.is_available():
+            self.use_gpu = True
         if self.use_gpu:
             self.cuda()
         return
@@ -201,15 +202,21 @@
         mask = state['mask']

         # shape: [batch_size, 1, 1]
-        pred_token = state['pred_token']
-        pred_mask = state['pred_mask']
-        pred_pos = state['pred_pos']
-        pred_type = state['pred_type']
-        pred_turn = state['pred_turn']
+        if self.use_gpu:
+            pred_token = state['pred_token'].cuda()
+            pred_mask = state['pred_mask'].cuda()
+            pred_pos = state['pred_pos'].cuda()
+            pred_type = state['pred_type'].cuda()
+            pred_turn = state['pred_turn'].cuda()
+        else:
+            pred_token = state['pred_token']
+            pred_mask = state['pred_mask']
+            pred_pos = state['pred_pos']
+            pred_type = state['pred_type']
+            pred_turn = state['pred_turn']

         # list of shape(len: num_layers): [batch_size, seq_len, hidden_dim]
         cache = state['cache']

         pred_embed = self.embedder(pred_token, pred_pos, pred_type,
                                    pred_turn).squeeze(-2)
         pred_embed = self.embed_layer_norm(pred_embed)
@@ -67,6 +67,8 @@ class SpaceGenerator(object):
         self.min_gen_len = config.Generator.min_gen_len
         self.max_gen_len = config.Generator.max_gen_len
         self.use_gpu = config.use_gpu
+        if torch.cuda.is_available():
+            self.use_gpu = True
         assert 1 <= self.min_gen_len <= self.max_gen_len
         return
@@ -184,7 +186,6 @@ class BeamSearch(SpaceGenerator):
             unk_penalty = unk_penalty.cuda()
             eos_penalty = eos_penalty.cuda()
             scores_after_end = scores_after_end.cuda()

-        if self.ignore_unk:
-            scores = scores + unk_penalty
+        scores = scores + unk_penalty
         scores = scores + eos_penalty
11  modelscope/ops/image_control_3d_portrait/dnnlib/__init__.py  (new file)
@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from .util import EasyDict, make_cache_dir_path
52  modelscope/ops/image_control_3d_portrait/dnnlib/util.py  (new file)
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Miscellaneous utility classes and functions."""

import os
import sys
import tempfile
from typing import Any, List, Tuple, Union


class EasyDict(dict):
    """Convenience class that behaves like a dict but allows access with the attribute syntax."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value

    def __delattr__(self, name: str) -> None:
        del self[name]


_dnnlib_cache_dir = None


def set_cache_dir(path: str) -> None:
    global _dnnlib_cache_dir
    _dnnlib_cache_dir = path


def make_cache_dir_path(*paths: str) -> str:
    if _dnnlib_cache_dir is not None:
        return os.path.join(_dnnlib_cache_dir, *paths)
    if 'DNNLIB_CACHE_DIR' in os.environ:
        return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
    if 'HOME' in os.environ:
        return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
    if 'USERPROFILE' in os.environ:
        return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib',
                            *paths)
    return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
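A quick illustration of `EasyDict` and the cache-dir resolution order above (values are hypothetical):

cfg = EasyDict(lr=1e-4, betas=(0.9, 0.99))
cfg.weight_decay = 0.01            # same as cfg['weight_decay'] = 0.01
assert cfg.lr == cfg['lr']
del cfg.betas                      # same as del cfg['betas']

set_cache_dir('/tmp/my_cache')     # takes priority over DNNLIB_CACHE_DIR and HOME
print(make_cache_dir_path('downloads'))  # -> /tmp/my_cache/downloads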
@@ -0,0 +1,181 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid

import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton

# Global options.

verbosity = 'brief'  # Verbosity level: 'none', 'brief', 'full'

# Internal helper funcs.


def _find_compiler_bindir():
    patterns = [
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
        'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
    ]
    for pattern in patterns:
        matches = sorted(glob.glob(pattern))
        if len(matches):
            return matches[-1]
    return None


def _get_mangled_gpu_name():
    name = torch.cuda.get_device_name().lower()
    out = []
    for c in name:
        if re.match('[a-z0-9_-]+', c):
            out.append(c)
        else:
            out.append('-')
    return ''.join(out)

# Main entry point for compiling and loading C++/CUDA plugins.

_cached_plugins = dict()


def get_plugin(module_name,
               sources,
               headers=None,
               source_dir=None,
               **build_kwargs):
    assert verbosity in ['none', 'brief', 'full']
    if headers is None:
        headers = []
    if source_dir is not None:
        sources = [os.path.join(source_dir, fname) for fname in sources]
        headers = [os.path.join(source_dir, fname) for fname in headers]

    # Already cached?
    if module_name in _cached_plugins:
        return _cached_plugins[module_name]

    # Print status.
    if verbosity == 'full':
        print(f'Setting up PyTorch plugin "{module_name}"...')
    elif verbosity == 'brief':
        print(
            f'Setting up PyTorch plugin "{module_name}"... ',
            end='',
            flush=True)
    verbose_build = (verbosity == 'full')

    # Compile and load.
    try:
        if os.name == 'nt' and os.system('where cl.exe >nul 2>nul') != 0:
            compiler_bindir = _find_compiler_bindir()
            if compiler_bindir is None:
                raise RuntimeError(
                    f'Could not find MSVC/GCC/CLANG installation on this computer.'
                    f' Check _find_compiler_bindir() in "{__file__}".')
            os.environ['PATH'] += ';' + compiler_bindir

        # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
        # break the build or unnecessarily restrict what's available to nvcc.
        # Unset it to let nvcc decide based on what's available on the
        # machine.
        os.environ['TORCH_CUDA_ARCH_LIST'] = ''

        # Incremental build md5sum trickery. Copies all the input source files
        # into a cached build directory under a combined md5 digest of the input
        # source files. Copying is done only if the combined digest has changed.
        # This keeps input file timestamps and filenames the same as in previous
        # extension builds, allowing for fast incremental rebuilds.
        #
        # This optimization is done only in case all the source files reside in
        # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
        # environment variable is set (we take this as a signal that the user
        # actually cares about this.)
        #
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
        # around the *.cu dependency bug in ninja config.
        #
        all_source_files = sorted(sources + headers)
        all_source_dirs = set(
            os.path.dirname(fname) for fname in all_source_files)
        if len(all_source_dirs
               ) == 1:  # and ('TORCH_EXTENSIONS_DIR' in os.environ):

            # Compute combined hash digest for all source files.
            hash_md5 = hashlib.md5()
            for src in all_source_files:
                with open(src, 'rb') as f:
                    hash_md5.update(f.read())

            # Select cached build directory name.
            source_digest = hash_md5.hexdigest()
            build_top_dir = torch.utils.cpp_extension._get_build_directory(
                module_name, verbose=verbose_build)  # pylint: disable=protected-access
            cached_build_dir = os.path.join(
                build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')

            if not os.path.isdir(cached_build_dir):
                tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
                os.makedirs(tmpdir)
                for src in all_source_files:
                    shutil.copyfile(
                        src, os.path.join(tmpdir, os.path.basename(src)))
                try:
                    os.replace(tmpdir, cached_build_dir)  # atomic
                except OSError:
                    # source directory already exists, delete tmpdir and its contents.
                    shutil.rmtree(tmpdir)
                    if not os.path.isdir(cached_build_dir):
                        raise

            # Compile.
            cached_sources = [
                os.path.join(cached_build_dir, os.path.basename(fname))
                for fname in sources
            ]
            torch.utils.cpp_extension.load(
                name=module_name,
                build_directory=cached_build_dir,
                verbose=verbose_build,
                sources=cached_sources,
                **build_kwargs)
        else:
            torch.utils.cpp_extension.load(
                name=module_name,
                verbose=verbose_build,
                sources=sources,
                **build_kwargs)

        # Load.
        module = importlib.import_module(module_name)

    except Exception:
        if verbosity == 'brief':
            print('Failed!')
        raise

    # Print status and add to cache dict.
    if verbosity == 'full':
        print(f'Done setting up PyTorch plugin "{module_name}".')
    elif verbosity == 'brief':
        print('Done.')
    _cached_plugins[module_name] = module
    return module
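A hypothetical call (file names and flags are placeholders): extra keyword arguments are forwarded to `torch.utils.cpp_extension.load`, and repeat calls with the same `module_name` return the in-process cached module.

plugin = get_plugin(
    module_name='bias_act_plugin',
    sources=['bias_act.cpp', 'bias_act.cu'],
    headers=['bias_act.h'],
    source_dir='/path/to/ops',
    extra_cuda_cflags=['--use_fast_math'])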
325  modelscope/ops/image_control_3d_portrait/torch_utils/misc.py  (new file)
@@ -0,0 +1,325 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import contextlib
import re
import warnings

import numpy as np
import torch

from .. import dnnlib

# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.

_constant_cache = dict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
    value = np.asarray(value)
    if shape is not None:
        shape = tuple(shape)
    if dtype is None:
        dtype = torch.get_default_dtype()
    if device is None:
        device = torch.device('cpu')
    if memory_format is None:
        memory_format = torch.contiguous_format

    key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device,
           memory_format)
    tensor = _constant_cache.get(key, None)
    if tensor is None:
        tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
        if shape is not None:
            tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
        tensor = tensor.contiguous(memory_format=memory_format)
        _constant_cache[key] = tensor
    return tensor

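Repeated calls with identical arguments hit the cache and return the very same tensor object, which is what saves the per-step CPU=>GPU copy:

ones_a = constant(1.0, shape=[4])
ones_b = constant(1.0, shape=[4])
assert ones_a is ones_b  # same cached object, not merely equal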
# Replace NaN/Inf with specified numerical values.

try:
    nan_to_num = torch.nan_to_num  # 1.8.0a0
except AttributeError:

    def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None):  # pylint: disable=redefined-builtin
        assert isinstance(input, torch.Tensor)
        if posinf is None:
            posinf = torch.finfo(input.dtype).max
        if neginf is None:
            neginf = torch.finfo(input.dtype).min
        assert nan == 0
        return torch.clamp(
            input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)


# Symbolic assert.

try:
    symbolic_assert = torch._assert  # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
    symbolic_assert = torch.Assert  # 1.7.0

# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672


@contextlib.contextmanager
def suppress_tracer_warnings():
    flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
    warnings.filters.insert(0, flt)
    yield
    warnings.filters.remove(flt)


# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().


def assert_shape(tensor, ref_shape):
    if tensor.ndim != len(ref_shape):
        raise AssertionError(
            f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}'
        )
    for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
        if ref_size is None:
            pass
        elif isinstance(ref_size, torch.Tensor):
            with suppress_tracer_warnings(
            ):  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(torch.as_tensor(size), ref_size),
                    f'Wrong size for dimension {idx}')
        elif isinstance(size, torch.Tensor):
            with suppress_tracer_warnings(
            ):  # as_tensor results are registered as constants
                symbolic_assert(
                    torch.equal(size, torch.as_tensor(ref_size)),
                    f'Wrong size for dimension {idx}: expected {ref_size}')
        elif size != ref_size:
            raise AssertionError(
                f'Wrong size for dimension {idx}: got {size}, expected {ref_size}'
            )

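In use, `None` wildcards a dimension (commonly the batch axis):

import torch

x = torch.randn(8, 3, 64, 64)
assert_shape(x, [None, 3, 64, 64])  # passes; batch size may vary
# assert_shape(x, [8, 1, 64, 64])   # would raise AssertionError on dim 1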
# Function decorator that calls torch.autograd.profiler.record_function().


def profiled_function(fn):

    def decorator(*args, **kwargs):
        with torch.autograd.profiler.record_function(fn.__name__):
            return fn(*args, **kwargs)

    decorator.__name__ = fn.__name__
    return decorator


# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.


class InfiniteSampler(torch.utils.data.Sampler):

    def __init__(self,
                 dataset,
                 rank=0,
                 num_replicas=1,
                 shuffle=True,
                 seed=0,
                 window_size=0.5):
        assert len(dataset) > 0
        assert num_replicas > 0
        assert 0 <= rank < num_replicas
        assert 0 <= window_size <= 1
        super().__init__(dataset)
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.window_size = window_size

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = None
        window = 0
        if self.shuffle:
            rnd = np.random.RandomState(self.seed)
            rnd.shuffle(order)
            window = int(np.rint(order.size * self.window_size))

        idx = 0
        while True:
            i = idx % order.size
            if idx % self.num_replicas == self.rank:
                yield order[i]
            if window >= 2:
                j = (i - rnd.randint(window)) % order.size
                order[i], order[j] = order[j], order[i]
            idx += 1

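A minimal sketch of driving an endless `DataLoader` with the sampler; the loop, not the loader, decides when to stop:

import torch
import torch.utils.data as data

dataset = data.TensorDataset(torch.arange(10))
sampler = InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
loader = iter(data.DataLoader(dataset, sampler=sampler, batch_size=4))
for _ in range(3):
    (batch,) = next(loader)  # never raises StopIteration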
# Utilities for operating with torch.nn.Module parameters and buffers.


def params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.parameters()) + list(module.buffers())


def named_params_and_buffers(module):
    assert isinstance(module, torch.nn.Module)
    return list(module.named_parameters()) + list(module.named_buffers())


def copy_params_and_buffers(src_module, dst_module, require_all=False):
    assert isinstance(src_module, torch.nn.Module)
    assert isinstance(dst_module, torch.nn.Module)
    src_tensors = dict(named_params_and_buffers(src_module))
    for name, tensor in named_params_and_buffers(dst_module):
        assert (name in src_tensors) or (not require_all)
        if name in src_tensors:
            tensor.copy_(src_tensors[name].detach()).requires_grad_(
                tensor.requires_grad)


# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.


@contextlib.contextmanager
def ddp_sync(module, sync):
    assert isinstance(module, torch.nn.Module)
    if sync or not isinstance(module,
                              torch.nn.parallel.DistributedDataParallel):
        yield
    else:
        with module.no_sync():
            yield


# Check DistributedDataParallel consistency across processes.


def check_ddp_consistency(module, ignore_regex=None):
    assert isinstance(module, torch.nn.Module)
    for name, tensor in named_params_and_buffers(module):
        fullname = type(module).__name__ + '.' + name
        if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
            continue
        tensor = tensor.detach()
        if tensor.is_floating_point():
            tensor = nan_to_num(tensor)
        other = tensor.clone()
        torch.distributed.broadcast(tensor=other, src=0)
        assert (tensor == other).all(), fullname

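Typical gradient-accumulation pattern with `ddp_sync` (`ddp_module` and `microbatches` are placeholders): skip the all-reduce on intermediate micro-batches and synchronize only on the last one.

for i, microbatch in enumerate(microbatches):
    is_last = (i == len(microbatches) - 1)
    with ddp_sync(ddp_module, sync=is_last):
        ddp_module(microbatch).sum().backward()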
# Print summary table of module hierarchy.


def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
    assert isinstance(module, torch.nn.Module)
    assert not isinstance(module, torch.jit.ScriptModule)
    assert isinstance(inputs, (tuple, list))

    # Register hooks.
    entries = []
    nesting = [0]

    def pre_hook(_mod, _inputs):
        nesting[0] += 1

    def post_hook(mod, _inputs, outputs):
        nesting[0] -= 1
        if nesting[0] <= max_nesting:
            outputs = list(outputs) if isinstance(outputs,
                                                  (tuple,
                                                   list)) else [outputs]
            outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
            entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))

    hooks = [
        mod.register_forward_pre_hook(pre_hook) for mod in module.modules()
    ]
    hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]

    # Run module.
    outputs = module(*inputs)
    for hook in hooks:
        hook.remove()

    # Identify unique outputs, parameters, and buffers.
    tensors_seen = set()
    for e in entries:
        e.unique_params = [
            t for t in e.mod.parameters() if id(t) not in tensors_seen
        ]
        e.unique_buffers = [
            t for t in e.mod.buffers() if id(t) not in tensors_seen
        ]
        e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
        tensors_seen |= {
            id(t)
            for t in e.unique_params + e.unique_buffers + e.unique_outputs
        }

    # Filter out redundant entries.
    if skip_redundant:
        entries = [
            e for e in entries if len(e.unique_params) or len(e.unique_buffers)
            or len(e.unique_outputs)
        ]

    # Construct table.
    rows = [[
        type(module).__name__, 'Parameters', 'Buffers', 'Output shape',
        'Datatype'
    ]]
    rows += [['---'] * len(rows[0])]
    param_total = 0
    buffer_total = 0
    submodule_names = {mod: name for name, mod in module.named_modules()}
    for e in entries:
        name = '<top-level>' if e.mod is module else submodule_names[e.mod]
        param_size = sum(t.numel() for t in e.unique_params)
        buffer_size = sum(t.numel() for t in e.unique_buffers)
        output_shapes = [str(list(t.shape)) for t in e.outputs]
        output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
        rows += [[
            name + (':0' if len(e.outputs) >= 2 else ''),
            str(param_size) if param_size else '-',
            str(buffer_size) if buffer_size else '-',
            (output_shapes + ['-'])[0],
            (output_dtypes + ['-'])[0],
        ]]
        for idx in range(1, len(e.outputs)):
            rows += [[
                name + f':{idx}', '-', '-', output_shapes[idx],
                output_dtypes[idx]
            ]]
        param_total += param_size
        buffer_total += buffer_size
    rows += [['---'] * len(rows[0])]
    rows += [['Total', str(param_total), str(buffer_total), '-', '-']]

    # Print table.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    print()
    for row in rows:
        print(' '.join(cell + ' ' * (width - len(cell))
                       for cell, width in zip(row, widths)))
    print()
    return outputs
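A hypothetical call: run one forward pass through a small module and print the per-submodule table; the forward outputs are returned unchanged.

import torch

net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
out = print_module_summary(net, [torch.randn(1, 3, 32, 32)])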
Some files were not shown because too many files have changed in this diff.