Merge pull request #574 from modelscope/master-merge-internal20231007

Master merge internal20231007
Authored by liuyhwangyh on 2023-10-11 10:31:10 +08:00; committed via GitHub.
170 changed files with 31931 additions and 288 deletions

View File

@@ -1,2 +1,3 @@
recursive-include modelscope/configs *.py *.cu *.h *.cpp
recursive-include modelscope/cli/template *.tpl
recursive-include modelscope/utils *.json

View File

@@ -29,9 +29,10 @@ RUN pip install --no-cache-dir text2sql_lgesql==1.3.0 \
detectron2==0.3 -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html --force --no-deps
RUN pip install --no-cache-dir mpi4py paint_ldm \
mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 pai-easycv \
mmcls>=0.21.0 mmdet>=2.25.0 decord>=0.6.0 pai-easycv ms_swift \
ipykernel fasttext fairseq deepspeed -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
ARG USE_GPU
# for cpu install cpu version faiss, faiss depends on blas lib, we install libopenblas TODO rename gpu or cpu version faiss
RUN if [ "$USE_GPU" = "True" ] ; then \
pip install --no-cache-dir funtextprocessing kwsbp==0.0.6 faiss==1.7.2 safetensors typeguard==2.13.3 scikit-learn librosa==0.9.2 funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html; \
@@ -45,10 +46,14 @@ COPY examples /modelscope/examples
# for pai-easycv setup compatiblity issue
ENV SETUPTOOLS_USE_DISTUTILS=stdlib
RUN CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'
RUN if [ "$USE_GPU" = "True" ] ; then \
CUDA_HOME=/usr/local/cuda TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6" pip install --no-cache-dir 'git+https://github.com/facebookresearch/detectron2.git'; \
else \
echo 'cpu unsupport detectron2'; \
fi
# torchmetrics==0.11.4 for ofa
RUN pip install --no-cache-dir tiktoken torchmetrics==0.11.4 https://modelscope.oss-cn-beijing.aliyuncs.com/releases/v/ms_swift-1.1.0-py3-none-any.whl transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
RUN pip install --no-cache-dir jupyterlab torchmetrics==0.11.4 tiktoken transformers_stream_generator 'protobuf<=3.20.0' bitsandbytes basicsr
COPY docker/scripts/install_flash_attension.sh /tmp/install_flash_attension.sh
RUN if [ "$USE_GPU" = "True" ] ; then \
bash /tmp/install_flash_attension.sh; \

View File

@@ -4,36 +4,39 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .version import __release_datetime__, __version__
from .trainers import EpochBasedTrainer, TrainingArgs, build_dataset_from_file
from .trainers import Hook, Priority
from .exporters import Exporter
from .exporters import TfModelExporter
from .exporters import TorchModelExporter
from .exporters import Exporter, TfModelExporter, TorchModelExporter
from .hub.api import HubApi
from .hub.snapshot_download import snapshot_download
from .hub.check_model import check_local_model_is_latest, check_model_is_id
from .hub.push_to_hub import push_to_hub, push_to_hub_async
from .hub.check_model import check_model_is_id, check_local_model_is_latest
from .metrics import AudioNoiseMetric, Metric, task_default_metrics, ImageColorEnhanceMetric, ImageDenoiseMetric, \
ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric, SequenceClassificationMetric, \
TextGenerationMetric, TokenClassificationMetric, VideoSummarizationMetric, MovieSceneSegmentationMetric, \
AccuracyMetric, BleuMetric, ImageInpaintingMetric, ReferringVideoObjectSegmentationMetric, \
VideoFrameInterpolationMetric, VideoStabilizationMetric, VideoSuperResolutionMetric, PplMetric, \
ImageQualityAssessmentDegradationMetric, ImageQualityAssessmentMosMetric, TextRankingMetric, \
LossMetric, ImageColorizationMetric, OCRRecognitionMetric
from .hub.snapshot_download import snapshot_download
from .metrics import (
AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric,
ImageQualityAssessmentDegradationMetric,
ImageQualityAssessmentMosMetric, LossMetric, Metric,
MovieSceneSegmentationMetric, OCRRecognitionMetric, PplMetric,
ReferringVideoObjectSegmentationMetric, SequenceClassificationMetric,
TextGenerationMetric, TextRankingMetric, TokenClassificationMetric,
VideoFrameInterpolationMetric, VideoStabilizationMetric,
VideoSummarizationMetric, VideoSuperResolutionMetric,
task_default_metrics)
from .models import Model, TorchModel
from .preprocessors import Preprocessor
from .msdatasets import MsDataset
from .pipelines import Pipeline, pipeline
from .utils.hub import read_config, create_model_if_not_exist
from .utils.logger import get_logger
from .preprocessors import Preprocessor
from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
build_dataset_from_file)
from .utils.constant import Tasks
from .utils.hf_util import AutoConfig, GenerationConfig, GPTQConfig, BitsAndBytesConfig
from .utils.hf_util import AutoConfig, GPTQConfig, BitsAndBytesConfig
from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification)
from .utils.hf_util import AutoTokenizer
from .msdatasets import MsDataset
AutoModelForTokenClassification, AutoTokenizer,
GenerationConfig)
from .utils.hub import create_model_if_not_exist, read_config
from .utils.logger import get_logger
from .version import __release_datetime__, __version__
else:
_import_structure = {

View File

@@ -7,13 +7,13 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .base import Exporter
from .builder import build_exporter
from .cv import CartoonTranslationExporter
from .nlp import CsanmtForTranslationExporter
from .tf_model_exporter import TfModelExporter
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter
from .cv import FaceDetectionSCRFDExporter
from .cv import CartoonTranslationExporter, FaceDetectionSCRFDExporter
from .multi_modal import StableDiffuisonExporter
from .nlp import (CsanmtForTranslationExporter,
SbertForSequenceClassificationExporter,
SbertForZeroShotClassificationExporter)
from .tf_model_exporter import TfModelExporter
from .torch_model_exporter import TorchModelExporter
else:
_import_structure = {
'base': ['Exporter'],

View File

@@ -6,8 +6,9 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .cartoon_translation_exporter import CartoonTranslationExporter
from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
from .object_detection_damoyolo_exporter import \
ObjectDetectionDamoyoloExporter
else:
_import_structure = {
'cartoon_translation_exporter': ['CartoonTranslationExporter'],

View File

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping
import numpy as np
import onnx
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
@EXPORTERS.register_module(
Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetectionDBExporter(TorchModelExporter):
def export_onnx(self,
output_dir: str,
opset=11,
input_shape=(1, 3, 800, 800)):
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
dummy_input = torch.randn(*input_shape)
self.model.onnx_export = True
self.model.eval()
_ = self.model(dummy_input)
torch.onnx._export(
self.model,
dummy_input,
onnx_file,
input_names=[
'images',
],
output_names=[
'pred',
],
opset_version=opset)
return {'model', onnx_file}
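A minimal usage sketch, not part of this commit: the new exporter would normally be reached through the generic Exporter entry point, assuming the usual Exporter.from_model API; the model id and output directory below are placeholders.

import os

from modelscope.exporters import Exporter
from modelscope.models import Model

# hypothetical DB-based OCR-detection checkpoint; any model registered under
# Tasks.ocr_detection with module_name Models.ocr_detection resolves to this exporter
model = Model.from_pretrained('damo/cv_resnet18_ocr-detection-db-line-level_damo')
output_dir = '/tmp/ocr_detection_onnx'
os.makedirs(output_dir, exist_ok=True)
# exports with a (1, 3, 800, 800) dummy input and writes ModelFile.ONNX_MODEL_FILE
result = Exporter.from_model(model).export_onnx(output_dir=output_dir, opset=11)
print(result)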

View File

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping
import numpy as np
import onnx
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
@EXPORTERS.register_module(
Tasks.ocr_recognition, module_name=Models.ocr_recognition)
class OCRRecognitionExporter(TorchModelExporter):
def export_onnx(self,
output_dir: str,
opset=11,
input_shape=(1, 3, 32, 640)):
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
dummy_input = torch.randn(*input_shape)
self.model.onnx_export = True
self.model.eval()
_ = self.model(dummy_input)
torch.onnx._export(
self.model,
dummy_input,
onnx_file,
input_names=[
'images',
],
output_names=[
'pred',
],
opset_version=opset)
return {'model', onnx_file}

View File

@@ -6,7 +6,8 @@ from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
from .model_for_token_classification_exporter import \
ModelForSequenceClassificationExporter
from .sbert_for_sequence_classification_exporter import \
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \

View File

@@ -28,7 +28,8 @@ class CsanmtForTranslationExporter(TfModelExporter):
tf.disable_eager_execution()
super().__init__(model)
from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
from modelscope.pipelines.nlp.translation_pipeline import \
TranslationPipeline
self.pipeline = TranslationPipeline(self.model)
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:

View File

@@ -77,7 +77,9 @@ class ModelForSequenceClassificationExporter(TorchModelExporter):
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
ort_session = ort.InferenceSession(
output,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
with torch.no_grad():
model.eval()
outputs_origin = model.forward(

View File

@@ -102,7 +102,9 @@ class TfModelExporter(Exporter):
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model, full_check=True)
ort_session = ort.InferenceSession(output)
ort_session = ort.InferenceSession(
output,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
outputs_origin = call_func(
dummy_inputs) if call_func is not None else model(dummy_inputs)
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
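Both exporter self-checks above now pass an explicit providers list when opening the ONNX session. A standalone sketch of the same pattern with a placeholder model path: recent onnxruntime releases require the providers to be spelled out, and listing CPUExecutionProvider after CUDAExecutionProvider keeps a CPU fallback when no GPU build is available.

import onnxruntime as ort

providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = ort.InferenceSession('model.onnx', providers=providers)  # placeholder path
print(session.get_providers())  # reports which providers were actually loaded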

View File

@@ -30,7 +30,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
DEFAULT_CREDENTIALS_PATH,
MODELSCOPE_CLOUD_ENVIRONMENT,
MODELSCOPE_CLOUD_USERNAME,
ONE_YEAR_SECONDS,
MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
REQUESTS_API_HTTP_METHOD, Licenses,
ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
@@ -105,7 +105,9 @@ class HubApi:
"""
path = f'{self.endpoint}/api/v1/login'
r = self.session.post(
path, json={'AccessToken': access_token}, headers=self.headers)
path,
json={'AccessToken': access_token},
headers=self.builder_headers(self.headers))
raise_for_http_status(r)
d = r.json()
raise_on_error(d)
@@ -166,7 +168,10 @@ class HubApi:
'TrainId': os.environ.get('MODELSCOPE_TRAIN_ID', ''),
}
r = self.session.post(
path, json=body, cookies=cookies, headers=self.headers)
path,
json=body,
cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_post_error(r, path, body)
raise_on_error(r.json())
model_repo_url = f'{get_endpoint()}/{model_id}'
@@ -189,7 +194,9 @@ class HubApi:
raise ValueError('Token does not exist, please login first.')
path = f'{self.endpoint}/api/v1/models/{model_id}'
r = self.session.delete(path, cookies=cookies, headers=self.headers)
r = self.session.delete(path,
cookies=cookies,
headers=self.builder_headers(self.headers))
raise_for_http_status(r)
raise_on_error(r.json())
@@ -223,7 +230,8 @@ class HubApi:
else:
path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'
r = self.session.get(path, cookies=cookies, headers=self.headers)
r = self.session.get(path, cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, model_id)
if r.status_code == HTTPStatus.OK:
if is_ok(r.json()):
@@ -390,7 +398,7 @@ class HubApi:
data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
(owner_or_group, page_number, page_size),
cookies=cookies,
headers=self.headers)
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, 'list_model')
if r.status_code == HTTPStatus.OK:
if is_ok(r.json()):
@@ -435,7 +443,8 @@ class HubApi:
if cutoff_timestamp is None:
cutoff_timestamp = get_release_datetime()
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
r = self.session.get(path, cookies=cookies, headers=self.headers)
r = self.session.get(path, cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, model_id)
d = r.json()
raise_on_error(d)
@@ -472,13 +481,15 @@ class HubApi:
cutoff_timestamp=release_timestamp,
use_cookies=False if cookies is None else cookies)
if len(revisions) == 0:
raise NoValidRevisionError(
'The model: %s has no valid revision!' % model_id)
# tags (revisions) returned from backend are guaranteed to be ordered by create-time
# we shall obtain the latest revision created earlier than release version of this branch
revision = revisions[0]
logger.warning(('There is no version specified and there is no version in the model repository,'
'use the master branch, which is fragile, please use it with caution!'))
revision = MASTER_MODEL_BRANCH
else:
# tags (revisions) returned from backend are guaranteed to be ordered by create-time
# we shall obtain the latest revision created earlier than release version of this branch
revision = revisions[0]
logger.info(
'Model revision not specified, use the latest revision: %s'
'Model revision not specified, use revision: %s'
% revision)
else:
# use user-specified revision
@@ -487,8 +498,11 @@ class HubApi:
cutoff_timestamp=current_timestamp,
use_cookies=False if cookies is None else cookies)
if revision not in revisions:
raise NotExistError('The model: %s has no revision: %s !' %
(model_id, revision))
if revision == MASTER_MODEL_BRANCH:
logger.warning('Using the master branch is fragile, please use it with caution!')
else:
raise NotExistError('The model: %s has no revision: %s !' %
(model_id, revision))
logger.info('Use user-specified model revision: %s' % revision)
return revision
@@ -510,7 +524,8 @@ class HubApi:
cookies = self._check_cookie(use_cookies)
path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
r = self.session.get(path, cookies=cookies, headers=self.headers)
r = self.session.get(path, cookies=cookies,
headers=self.builder_headers(self.headers))
handle_http_response(r, logger, cookies, model_id)
d = r.json()
raise_on_error(d)
@@ -552,6 +567,7 @@ class HubApi:
if root is not None:
path = path + f'&Root={root}'
headers = self.headers if headers is None else headers
headers['X-Request-ID'] = str(uuid.uuid4().hex)
r = self.session.get(
path, cookies=cookies, headers=headers)
@@ -570,7 +586,8 @@ class HubApi:
def list_datasets(self):
path = f'{self.endpoint}/api/v1/datasets'
params = {}
r = self.session.get(path, params=params, headers=self.headers)
r = self.session.get(path, params=params,
headers=self.builder_headers(self.headers))
raise_for_http_status(r)
dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
return [x['Name'] for x in dataset_list]
@@ -590,7 +607,9 @@ class HubApi:
""" Get the meta file-list of the dataset. """
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(datahub_url, cookies=cookies, headers=self.headers)
r = self.session.get(datahub_url,
cookies=cookies,
headers=self.builder_headers(self.headers))
resp = r.json()
datahub_raise_on_error(datahub_url, resp)
file_list = resp['Data']
@@ -736,7 +755,9 @@ class HubApi:
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(
url=datahub_url, cookies=cookies, headers=self.headers)
url=datahub_url,
cookies=cookies,
headers=self.builder_headers(self.headers))
resp = r.json()
raise_on_error(resp)
return resp['Data']
@@ -759,7 +780,11 @@ class HubApi:
data = dict(
data=dataset_info,
)
r = self.session.post(url=virgo_dataset_url, json=data, cookies=cookies, headers=self.headers, timeout=900)
r = self.session.post(url=virgo_dataset_url,
json=data,
cookies=cookies,
headers=self.builder_headers(self.headers),
timeout=900)
resp = r.json()
if resp['code'] != 0:
raise RuntimeError(f'Failed to get virgo dataset: {resp}')
@@ -773,7 +798,8 @@ class HubApi:
zip_file_name: str):
datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
cookies = ModelScopeConfig.get_cookies()
r = self.session.get(url=datahub_url, cookies=cookies, headers=self.headers)
r = self.session.get(url=datahub_url, cookies=cookies,
headers=self.builder_headers(self.headers))
resp = r.json()
# get visibility of the dataset
raise_on_error(resp)
@@ -781,7 +807,8 @@ class HubApi:
visibility = DatasetVisibilityMap.get(data['Visibility'])
datahub_sts_url = f'{datahub_url}/ststoken?Revision={revision}'
r_sts = self.session.get(url=datahub_sts_url, cookies=cookies, headers=self.headers)
r_sts = self.session.get(url=datahub_sts_url, cookies=cookies,
headers=self.builder_headers(self.headers))
resp_sts = r_sts.json()
raise_on_error(resp_sts)
data_sts = resp_sts['Data']
@@ -848,7 +875,8 @@ class HubApi:
# Download count
download_count_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
download_count_resp = self.session.post(download_count_url, cookies=cookies, headers=self.headers)
download_count_resp = self.session.post(download_count_url, cookies=cookies,
headers=self.builder_headers(self.headers))
raise_for_http_status(download_count_resp)
# Download uv
@@ -860,13 +888,18 @@ class HubApi:
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
download_uv_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
f'{channel}?user={user_name}'
download_uv_resp = self.session.post(download_uv_url, cookies=cookies, headers=self.headers)
download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
headers=self.builder_headers(self.headers))
download_uv_resp = download_uv_resp.json()
raise_on_error(download_uv_resp)
except Exception as e:
logger.error(e)
def builder_headers(self, headers):
return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
**headers}
class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
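The new builder_headers helper stamps every API call with a fresh X-Request-ID so client and server logs can be correlated. A self-contained sketch of the pattern, not ModelScope's actual client code; the endpoint is shown for illustration only.

import uuid

import requests

MODELSCOPE_REQUEST_ID = 'X-Request-ID'

def builder_headers(headers):
    # merge a fresh request id with the caller's headers (caller values win on conflict)
    return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex), **headers}

headers = builder_headers({'user-agent': 'modelscope-sdk'})
resp = requests.get('https://www.modelscope.cn/api/v1/datasets', headers=headers)
print(headers[MODELSCOPE_REQUEST_ID], resp.status_code)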

View File

@@ -3,7 +3,7 @@
import os
from pathlib import Path
MODELSCOPE_URL_SCHEME = 'http://'
MODELSCOPE_URL_SCHEME = 'https://'
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN
MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB = int(
@@ -31,6 +31,7 @@ MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
MODEL_META_FILE_NAME = '.mdl'
MODEL_META_MODEL_ID = 'id'
MODELSCOPE_REQUEST_ID = 'X-Request-ID'
class Licenses(object):

View File

@@ -5,6 +5,7 @@ from http import HTTPStatus
import requests
from requests.exceptions import HTTPError
from modelscope.hub.constants import MODELSCOPE_REQUEST_ID
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -46,6 +47,13 @@ class FileDownloadError(Exception):
pass
def get_request_id(response: requests.Response):
if MODELSCOPE_REQUEST_ID in response.request.headers:
return response.request.headers[MODELSCOPE_REQUEST_ID]
else:
return ''
def is_ok(rsp):
""" Check the request is ok
@@ -71,12 +79,14 @@ def handle_http_post_error(response, url, request_body):
response.raise_for_status()
except HTTPError as error:
message = _decode_response_error(response)
raise HTTPError('Request %s with body: %s exception, '
'Response details: %s' %
(url, request_body, message)) from error
raise HTTPError(
'Request %s with body: %s exception, '
'Response details: %s, request id: %s' %
(url, request_body, message, get_request_id(response))) from error
def handle_http_response(response, logger, cookies, model_id):
def handle_http_response(response: requests.Response, logger, cookies,
model_id):
try:
response.raise_for_status()
except HTTPError as error:
@@ -85,7 +95,8 @@ def handle_http_response(response, logger, cookies, model_id):
f'Authentication token does not exist, failed to access model {model_id} which may not exist or may be \
private. Please login first.')
message = _decode_response_error(response)
raise HTTPError('Response details: %s' % message) from error
raise HTTPError('Response details: %s, Request id: %s' %
(message, get_request_id(response))) from error
def raise_on_error(rsp):
@@ -122,9 +133,10 @@ def datahub_raise_on_error(url, rsp):
if rsp.get('Code') == HTTPStatus.OK:
return True
else:
request_id = get_request_id(rsp)
raise RequestError(
f"Url = {url}, Message = {rsp.get('Message')}, Please specify correct dataset_name and namespace."
)
f"Url = {url}, Request id={request_id} Message = {rsp.get('Message')},\
Please specify correct dataset_name and namespace.")
def raise_for_http_status(rsp):
@@ -146,14 +158,14 @@ def raise_for_http_status(rsp):
reason = rsp.reason.decode('iso-8859-1')
else:
reason = rsp.reason
request_id = get_request_id(rsp)
if 400 <= rsp.status_code < 500:
http_error_msg = u'%s Client Error: %s for url: %s' % (rsp.status_code,
reason, rsp.url)
http_error_msg = u'%s Client Error: %s, Request id: %s for url: %s' % (
rsp.status_code, reason, request_id, rsp.url)
elif 500 <= rsp.status_code < 600:
http_error_msg = u'%s Server Error: %s for url: %s' % (rsp.status_code,
reason, rsp.url)
http_error_msg = u'%s Server Error: %s, Request id: %s, for url: %s' % (
rsp.status_code, reason, request_id, rsp.url)
if http_error_msg:
req = rsp.request

View File

@@ -4,6 +4,7 @@ import copy
import os
import tempfile
import threading
import uuid
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from http.cookiejar import CookieJar
@@ -192,6 +193,7 @@ def download_part_with_retry(params):
progress, start, end, url, file_name, cookies, headers = params
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['Range'] = 'bytes=%s-%s' % (start, end)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
retry = Retry(
total=API_FILE_DOWNLOAD_RETRY_TIMES,
backoff_factor=1,
@@ -289,6 +291,7 @@ def http_get_file(
temp_file_manager = partial(
tempfile.NamedTemporaryFile, mode='wb', dir=local_dir, delete=False)
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
with temp_file_manager() as temp_file:
logger.debug('downloading %s to %s', url, temp_file.name)
# retry sleep 0.5s, 1s, 2s, 4s
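Both the ranged part download and the whole-file download above now attach a fresh X-Request-ID to each HTTP GET. A simplified sketch of one ranged part request, with the retry/adapter setup omitted and a placeholder url:

import copy
import uuid

import requests

def fetch_part(url, start, end, headers=None):
    get_headers = {} if headers is None else copy.deepcopy(headers)
    get_headers['Range'] = 'bytes=%s-%s' % (start, end)    # this part's byte span
    get_headers['X-Request-ID'] = str(uuid.uuid4().hex)    # unique id for this request
    # stream=True so the chunk is written out without buffering the whole body
    return requests.get(url, headers=get_headers, stream=True, timeout=60)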

View File

@@ -82,6 +82,7 @@ class Models(object):
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
text_texture_generation = 'text-texture-generation'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
@@ -124,6 +125,8 @@ class Models(object):
pedestrian_attribute_recognition = 'pedestrian-attribute-recognition'
image_try_on = 'image-try-on'
human_image_generation = 'human-image-generation'
image_view_transform = 'image-view-transform'
image_control_3d_portrait = 'image-control-3d-portrait'
# nlp models
bert = 'bert'
@@ -293,6 +296,7 @@ class Pipelines(object):
table_recognition = 'dla34-table-recognition'
lineless_table_recognition = 'lore-lineless-table-recognition'
license_plate_detection = 'resnet18-license-plate-detection'
card_detection_correction = 'resnet18-card-detection-correction'
action_recognition = 'TAdaConv_action-recognition'
animal_recognition = 'resnet101-animal-recognition'
general_recognition = 'resnet101-general-recognition'
@@ -365,6 +369,8 @@ class Pipelines(object):
hand_detection = 'yolox-pai_hand-detection'
skin_retouching = 'unet-skin-retouching'
face_reconstruction = 'resnet50-face-reconstruction'
head_reconstruction = 'HRN-head-reconstruction'
text_to_head = 'HRN-text-to-head'
tinynas_classification = 'tinynas-classification'
easyrobust_classification = 'easyrobust-classification'
tinynas_detection = 'tinynas-detection'
@@ -401,6 +407,7 @@ class Pipelines(object):
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
human_reconstruction = 'human-reconstruction'
text_texture_generation = 'text-texture-generation'
vision_middleware_multi_task = 'vision-middleware-multi-task'
vidt = 'vidt'
video_frame_interpolation = 'video-frame-interpolation'
@@ -442,6 +449,10 @@ class Pipelines(object):
text_to_360panorama_image = 'text-to-360panorama-image'
image_try_on = 'image-try-on'
human_image_generation = 'human-image-generation'
human3d_render = 'human3d-render'
human3d_animation = 'human3d-animation'
image_view_transform = 'image-view-transform'
image_control_3d_portrait = 'image-control-3d-portrait'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
@@ -677,6 +688,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.license_plate_detection:
(Pipelines.license_plate_detection,
'damo/cv_resnet18_license-plate-detection_damo'),
Tasks.card_detection_correction: (Pipelines.card_detection_correction,
'damo/cv_resnet18_card_correction'),
Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
Tasks.feature_extraction: (Pipelines.feature_extraction,
'damo/pert_feature-extraction_base-test'),
@@ -830,6 +843,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_effnetv2_video-human-matting'),
Tasks.human_reconstruction: (Pipelines.human_reconstruction,
'damo/cv_hrnet_image-human-reconstruction'),
Tasks.text_texture_generation: (
Pipelines.text_texture_generation,
'damo/cv_diffuser_text-texture-generation'),
Tasks.video_frame_interpolation: (
Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),
@@ -908,7 +924,16 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_try_on: (Pipelines.image_try_on,
'damo/cv_SAL-VTON_virtual-try-on'),
Tasks.human_image_generation: (Pipelines.human_image_generation,
'damo/cv_FreqHPT_human-image-generation')
'damo/cv_FreqHPT_human-image-generation'),
Tasks.human3d_render: (Pipelines.human3d_render,
'damo/cv_3d-human-synthesis-library'),
Tasks.human3d_animation: (Pipelines.human3d_animation,
'damo/cv_3d-human-animation'),
Tasks.image_view_transform: (Pipelines.image_view_transform,
'damo/cv_image-view-transform'),
Tasks.image_control_3d_portrait: (
Pipelines.image_control_3d_portrait,
'damo/cv_vit_image-control-3d-portrait-synthesis')
}
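Usage sketch, not from this commit: with the entries added above, a pipeline for one of the newly wired tasks can be built from the task name alone, and the default checkpoint is resolved through DEFAULT_MODEL_FOR_PIPELINE. The input path below is a placeholder, and the exact input format depends on the pipeline.

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# resolves to Pipelines.card_detection_correction with
# 'damo/cv_resnet18_card_correction' as the default model
card_corrector = pipeline(Tasks.card_detection_correction)
result = card_corrector('path/to/card_image.jpg')  # placeholder input
print(result)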

View File

@@ -4,34 +4,39 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .accuracy_metric import AccuracyMetric
from .audio_noise_metric import AudioNoiseMetric
from .base import Metric
from .bleu_metric import BleuMetric
from .builder import METRICS, build_metric, task_default_metrics
from .image_color_enhance_metric import ImageColorEnhanceMetric
from .image_colorization_metric import ImageColorizationMetric
from .image_denoise_metric import ImageDenoiseMetric
from .image_inpainting_metric import ImageInpaintingMetric
from .image_instance_segmentation_metric import \
ImageInstanceSegmentationCOCOMetric
from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric
from .image_portrait_enhancement_metric import \
ImagePortraitEnhancementMetric
from .image_quality_assessment_degradation_metric import \
ImageQualityAssessmentDegradationMetric
from .image_quality_assessment_mos_metric import \
ImageQualityAssessmentMosMetric
from .loss_metric import LossMetric
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
from .ocr_recognition_metric import OCRRecognitionMetric
from .ppl_metric import PplMetric
from .referring_video_object_segmentation_metric import \
ReferringVideoObjectSegmentationMetric
from .sequence_classification_metric import SequenceClassificationMetric
from .text_generation_metric import TextGenerationMetric
from .text_ranking_metric import TextRankingMetric
from .token_classification_metric import TokenClassificationMetric
from .video_summarization_metric import VideoSummarizationMetric
from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
from .accuracy_metric import AccuracyMetric
from .bleu_metric import BleuMetric
from .image_inpainting_metric import ImageInpaintingMetric
from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
from .translation_evaluation_metric import TranslationEvaluationMetric
from .video_frame_interpolation_metric import VideoFrameInterpolationMetric
from .video_stabilization_metric import VideoStabilizationMetric
from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
from .ppl_metric import PplMetric
from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric
from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
from .text_ranking_metric import TextRankingMetric
from .loss_metric import LossMetric
from .image_colorization_metric import ImageColorizationMetric
from .ocr_recognition_metric import OCRRecognitionMetric
from .translation_evaluation_metric import TranslationEvaluationMetric
from .video_summarization_metric import VideoSummarizationMetric
from .video_super_resolution_metric.video_super_resolution_metric import \
VideoSuperResolutionMetric
else:
_import_structure = {
'audio_noise_metric': ['AudioNoiseMetric'],

View File

@@ -143,10 +143,15 @@ class Model(ABC):
task_name = getattr(cfg, 'task', None)
if 'task' in kwargs:
task_name = kwargs.pop('task')
model_cfg = getattr(cfg, 'model', None)
if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_type = getattr(model_cfg, 'type', None)
try:
model_cfg = cfg.model
if hasattr(model_cfg,
'model_type') and not hasattr(model_cfg, 'type'):
model_cfg.type = model_cfg.model_type
model_type = model_cfg.type
except Exception:
model_cfg = {}
model_type = ''
if isinstance(device, str) and device.startswith('gpu'):
device = 'cuda' + device[3:]
use_hf = kwargs.pop('use_hf', None)

View File

@@ -5,10 +5,10 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
body_2d_keypoints, body_3d_keypoints, cartoon,
cmdssl_video_embedding, controllable_image_generation,
crowd_counting, face_detection, face_generation,
face_reconstruction, human_reconstruction, image_classification,
image_color_enhance, image_colorization, image_defrcn_fewshot,
image_denoise, image_editing, image_inpainting,
image_instance_segmentation, image_matching,
face_reconstruction, human3d_animation, human_reconstruction,
image_classification, image_color_enhance, image_colorization,
image_defrcn_fewshot, image_denoise, image_editing,
image_inpainting, image_instance_segmentation, image_matching,
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_probing_model,
image_quality_assessment_degradation,
@@ -23,7 +23,8 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
shop_segmentation, stream_yolo, super_resolution,
surface_recon_common, table_recognition, video_deinterlace,
surface_recon_common, table_recognition,
text_texture_generation, video_deinterlace,
video_frame_interpolation, video_object_segmentation,
video_panoptic_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,

View File

@@ -0,0 +1,673 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch
import os
import numpy as np
import torch
import torch.nn.functional as F
from scipy.io import loadmat
from modelscope.models.cv.face_reconstruction.utils import read_obj
def perspective_projection(focal, center):
# return p.T (N, 3) @ (3, 3)
return np.array([focal, 0, center, 0, focal, center, 0, 0,
1]).reshape([3, 3]).astype(np.float32).transpose()
class SH:
def __init__(self):
self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
self.c = [
1 / np.sqrt(4 * np.pi),
np.sqrt(3.) / np.sqrt(4 * np.pi),
3 * np.sqrt(5.) / np.sqrt(12 * np.pi)
]
class ParametricFaceModel:
def __init__(self,
assets_root='assets',
recenter=True,
camera_distance=10.,
init_lit=np.array([0.8, 0, 0, 0, 0, 0, 0, 0, 0]),
focal=1015.,
center=112.,
is_train=True,
default_name='BFM_model_front.mat'):
model = loadmat(os.path.join(assets_root, '3dmm/BFM', default_name))
model_bfm_front = loadmat(
os.path.join(assets_root, '3dmm/BFM/BFM_model_front.mat'))
self.mean_shape_ori = model_bfm_front['meanshape'].astype(np.float32)
# mean face shape. [3*N,1]
self.mean_shape = model['meanshape'].astype(np.float32) # (1, 107127)
# identity basis. [3*N,80]
self.id_base = model['idBase'].astype(np.float32) # (107127, 80)
# expression basis. [3*N,64]
self.exp_base = model['exBase'].astype(np.float32) # (107127, 64)
# mean face texture. [3*N,1] (0-255)
self.mean_tex = model['meantex'].astype(np.float32) # (1, 107127)
# texture basis. [3*N,80]
self.tex_base = model['texBase'].astype(np.float32) # (107127, 80)
self.bfm_keep_inds = np.load(
os.path.join(assets_root, '3dmm/inds/bfm_keep_inds.npy'))
self.ours_hair_area_inds = np.load(
os.path.join(assets_root, '3dmm/inds/ours_hair_area_inds.npy'))
if default_name == 'ourRefineFull_model.mat':
self.mean_tex = self.mean_tex.reshape(1, -1, 3)
mean_tex_keep = self.mean_tex[:, self.bfm_keep_inds]
self.mean_tex[:, :len(self.bfm_keep_inds)] = mean_tex_keep
self.mean_tex[:,
len(self.bfm_keep_inds):] = np.array([200, 146,
118])[None,
None]
self.mean_tex[:, self.ours_hair_area_inds] = 40.0
self.mean_tex = self.mean_tex.reshape(1, -1)
self.mean_tex = np.ascontiguousarray(self.mean_tex)
self.tex_base = self.tex_base.reshape(-1, 3, 80)
tex_base_keep = self.tex_base[self.bfm_keep_inds]
self.tex_base[:len(self.bfm_keep_inds)] = tex_base_keep
self.tex_base[len(self.bfm_keep_inds):] = 0.0
self.tex_base = self.tex_base.reshape(-1, 80)
self.tex_base = np.ascontiguousarray(self.tex_base)
# face indices for each vertex that lies in. starts from 0. [N,8]
self.point_buf = model['point_buf'].astype(np.int64) - 1 # (35709, 8)
# vertex indices for each face. starts from 0. [F,3]
self.face_buf = model['tri'].astype(np.int64) - 1 # (70789, 3)
# vertex indices for 68 landmarks. starts from 0. [68,1]
self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
if default_name == 'ourRefineFull_model.mat':
self.keypoints = np.load(
os.path.join(
assets_root,
'3dmm/inds/our_refine0223_basis_withoutEyes_withUV_keypoints_inds.npy'
)).astype(np.int64)
self.point_buf = self.point_buf[:, :8] + 1
if is_train:
# vertex indices for small face region to compute photometric error. starts from 0.
self.front_mask = np.squeeze(model['frontmask2_idx']).astype(
np.int64) - 1
# vertex indices for each face from small face region. starts from 0. [f,3]
self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
# vertex indices for pre-defined skin region to compute reflectance loss
self.skin_mask = np.squeeze(model['skinmask'])
if default_name == 'ourRefineFull_model.mat':
nose_reduced_mesh = read_obj(
os.path.join(assets_root,
'3dmm/adjust_part/our_full/145_nose.obj'))
self.nose_reduced_part = nose_reduced_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
neck_mesh = read_obj(
os.path.join(assets_root,
'3dmm/adjust_part/our_full/154_neck.obj'))
self.neck_adjust_part = neck_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
eyes_mesh = read_obj(
os.path.join(
assets_root,
'3dmm/adjust_part/our_full/our_mean_adjust_eyes.obj'))
self.eyes_adjust_part = eyes_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
self.neck_slim_part = None
self.neck_stretch_part = None
elif default_name == 'ourRefineBFMEye0504_model.mat':
nose_reduced_mesh = read_obj(
os.path.join(assets_root,
'3dmm/adjust_part/our_full_bfmEyes/145_nose.obj'))
self.nose_reduced_part = nose_reduced_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
neck_mesh = read_obj(
os.path.join(assets_root,
'3dmm/adjust_part/our_full_bfmEyes/146_neck.obj'))
self.neck_adjust_part = neck_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
self.eyes_adjust_part = None
neck_slim_mesh = read_obj(
os.path.join(
assets_root,
'3dmm/adjust_part/our_full_bfmEyes/147_neckSlim2.obj'))
self.neck_slim_part = neck_slim_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
neck_stretch_mesh = read_obj(
os.path.join(
assets_root,
'3dmm/adjust_part/our_full_bfmEyes/148_neckLength.obj'))
self.neck_stretch_part = neck_stretch_mesh['vertices'].reshape(
(1, -1)) - self.mean_shape
else:
self.nose_reduced_part = None
self.neck_adjust_part = None
self.eyes_adjust_part = None
self.neck_slim_part = None
self.neck_stretch_part = None
if recenter:
mean_shape = self.mean_shape.reshape([-1, 3])
mean_shape_ori = self.mean_shape_ori.reshape([-1, 3])
mean_shape = mean_shape - np.mean(
mean_shape_ori[:35709, ...], axis=0, keepdims=True)
self.mean_shape = mean_shape.reshape([-1, 1])
eye_corner_inds = np.load(
os.path.join(assets_root, '3dmm/inds/eye_corner_inds.npy'))
self.eye_corner_inds = torch.from_numpy(eye_corner_inds).long()
eye_lines = np.load(
os.path.join(assets_root, '3dmm/inds/eye_corner_lines.npy'))
self.eye_lines = torch.from_numpy(eye_lines).long()
self.center = center
self.persc_proj = perspective_projection(focal, self.center)
self.camera_distance = camera_distance
self.SH = SH()
self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
def to(self, device):
self.device = device
for key, value in self.__dict__.items():
if type(value).__module__ == np.__name__:
setattr(self, key, torch.tensor(value).to(device))
def compute_shape(self,
id_coeff,
exp_coeff,
nose_coeff=0.0,
neck_coeff=0.0,
eyes_coeff=0.0,
neckSlim_coeff=0.0,
neckStretch_coeff=0.0):
"""
Return:
face_shape -- torch.tensor, size (B, N, 3)
Parameters:
id_coeff -- torch.tensor, size (B, 80), identity coeffs
exp_coeff -- torch.tensor, size (B, 64), expression coeffs
"""
batch_size = id_coeff.shape[0]
id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
if nose_coeff != 0:
face_shape = face_shape + nose_coeff * self.nose_reduced_part
if neck_coeff != 0:
face_shape = face_shape + neck_coeff * self.neck_adjust_part
if eyes_coeff != 0 and self.eyes_adjust_part is not None:
face_shape = face_shape + eyes_coeff * self.eyes_adjust_part
if neckSlim_coeff != 0 and self.neck_slim_part is not None:
face_shape = face_shape + neckSlim_coeff * self.neck_slim_part
if neckStretch_coeff != 0 and self.neck_stretch_part is not None:
neck_stretch_part = self.neck_stretch_part.reshape(1, -1, 3)
neck_stretch_part_top = neck_stretch_part[0, 37476, 1]
neck_stretch_part_bottom = neck_stretch_part[0, 37357, 1]
neck_stretch_height = neck_stretch_part_top - neck_stretch_part_bottom
face_shape_ = face_shape.reshape(1, -1, 3)
face_shape_top = face_shape_[0, 37476, 1]
face_shape_bottom = face_shape_[0, 37357, 1]
face_shape_height = face_shape_top - face_shape_bottom
target_neck_height = 0.72 # top ind 37476, bottom ind 37357
neckStretch_coeff = (target_neck_height
- face_shape_height) / neck_stretch_height
face_shape = face_shape + neckStretch_coeff * self.neck_stretch_part
return face_shape.reshape([batch_size, -1, 3])
def compute_texture(self, tex_coeff, normalize=True):
"""
Return:
face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
Parameters:
tex_coeff -- torch.tensor, size (B, 80)
"""
batch_size = tex_coeff.shape[0]
face_texture = torch.einsum('ij,aj->ai', self.tex_base,
tex_coeff) + self.mean_tex
if normalize:
face_texture = face_texture / 255.
return face_texture.reshape([batch_size, -1, 3])
def compute_norm(self, face_shape):
"""
Return:
vertex_norm -- torch.tensor, size (B, N, 3)
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
"""
v1 = face_shape[:, self.face_buf[:, 0]]
v2 = face_shape[:, self.face_buf[:, 1]]
v3 = face_shape[:, self.face_buf[:, 2]]
e1 = v1 - v2
e2 = v2 - v3
face_norm = torch.cross(e1, e2, dim=-1)
face_norm = F.normalize(face_norm, dim=-1, p=2)
face_norm = torch.cat(
[face_norm,
torch.zeros(face_norm.shape[0], 1, 3).to(self.device)],
dim=1)
vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
return vertex_norm
def compute_color(self, face_texture, face_norm, gamma):
"""
Return:
face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
Parameters:
face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
face_norm -- torch.tensor, size (B, N, 3), rotated face normal
gamma -- torch.tensor, size (B, 27), SH coeffs
"""
batch_size = gamma.shape[0]
a, c = self.SH.a, self.SH.c
gamma = gamma.reshape([batch_size, 3, 9])
gamma = gamma + self.init_lit
gamma = gamma.permute(0, 2, 1)
y1 = a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device)
y2 = -a[1] * c[1] * face_norm[..., 1:2]
y3 = a[1] * c[1] * face_norm[..., 2:]
y4 = -a[1] * c[1] * face_norm[..., :1]
y5 = a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2]
y6 = -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:]
y7 = 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:]**2 - 1)
y8 = -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:]
y9 = 0.5 * a[2] * c[2] * (
face_norm[..., :1]**2 - face_norm[..., 1:2]**2)
Y = torch.cat([y1, y2, y3, y4, y5, y6, y7, y8, y9], dim=-1)
r = Y @ gamma[..., :1]
g = Y @ gamma[..., 1:2]
b = Y @ gamma[..., 2:]
face_color = torch.cat([r, g, b], dim=-1) * face_texture
return face_color
def compute_rotation(self, angles):
"""
Return:
rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
Parameters:
angles -- torch.tensor, size (B, 3), radian
"""
batch_size = angles.shape[0]
ones = torch.ones([batch_size, 1]).to(self.device)
zeros = torch.zeros([batch_size, 1]).to(self.device)
x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
value_list = [
ones, zeros, zeros, zeros,
torch.cos(x), -torch.sin(x), zeros,
torch.sin(x),
torch.cos(x)
]
rot_x = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])
value_list = [
torch.cos(y), zeros,
torch.sin(y), zeros, ones, zeros, -torch.sin(y), zeros,
torch.cos(y)
]
rot_y = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])
value_list = [
torch.cos(z), -torch.sin(z), zeros,
torch.sin(z),
torch.cos(z), zeros, zeros, zeros, ones
]
rot_z = torch.cat(value_list, dim=1).reshape([batch_size, 3, 3])
rot = rot_z @ rot_y @ rot_x
return rot.permute(0, 2, 1)
def to_camera(self, face_shape):
face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
return face_shape
def to_image(self, face_shape):
"""
Return:
face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
"""
# to image_plane
face_proj = face_shape @ self.persc_proj
face_proj = face_proj[..., :2] / face_proj[..., 2:]
return face_proj
def transform(self, face_shape, rot, trans):
"""
Return:
face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
rot -- torch.tensor, size (B, 3, 3)
trans -- torch.tensor, size (B, 3)
"""
return face_shape @ rot + trans.unsqueeze(1)
def get_landmarks(self, face_proj):
"""
Return:
face_lms -- torch.tensor, size (B, 68, 2)
Parameters:
face_proj -- torch.tensor, size (B, N, 2)
"""
return face_proj[:, self.keypoints]
def split_coeff(self, coeffs):
"""
Return:
coeffs_dict -- a dict of torch.tensors
Parameters:
coeffs -- torch.tensor, size (B, 256)
"""
if type(coeffs) == dict and 'id' in coeffs:
return coeffs
id_coeffs = coeffs[:, :80]
exp_coeffs = coeffs[:, 80:144]
tex_coeffs = coeffs[:, 144:224]
angles = coeffs[:, 224:227]
gammas = coeffs[:, 227:254]
translations = coeffs[:, 254:]
return {
'id': id_coeffs,
'exp': exp_coeffs,
'tex': tex_coeffs,
'angle': angles,
'gamma': gammas,
'trans': translations
}
def merge_coeff(self, coeffs):
"""
Return:
coeffs_dict -- a dict of torch.tensors
Parameters:
coeffs -- torch.tensor, size (B, 256)
"""
names = ['id', 'exp', 'tex', 'angle', 'gamma', 'trans']
coeffs_merge = []
for name in names:
coeffs_merge.append(coeffs[name].detach())
coeffs_merge = torch.cat(coeffs_merge, dim=1)
return coeffs_merge
def reverse_recenter(self, face_shape):
batch_size = face_shape.shape[0]
face_shape = face_shape.reshape([-1, 3])
mean_shape_ori = self.mean_shape_ori.reshape([-1, 3])
face_shape = face_shape + torch.mean(
mean_shape_ori[:35709, ...], dim=0, keepdim=True)
face_shape = face_shape.reshape([batch_size, -1, 3])
return face_shape
def add_nonlinear_offset_eyes(self, face_shape, shape_offset):
assert face_shape.shape[0] == 1 and shape_offset.shape[0] == 1
face_shape = face_shape[0]
shape_offset = shape_offset[0]
corner_shape = face_shape[-625:, :]
corner_offset = shape_offset[self.eye_corner_inds]
for i in range(len(self.eye_lines)):
corner_shape[self.eye_lines[i]] += corner_offset[i][None, ...]
face_shape[-625:, :] = corner_shape
l_eye_landmarks = [11540, 11541]
r_eye_landmarks = [4271, 4272]
l_eye_offset = torch.mean(
shape_offset[l_eye_landmarks], dim=0, keepdim=True)
face_shape[37082:37082 + 609] += l_eye_offset
r_eye_offset = torch.mean(
shape_offset[r_eye_landmarks], dim=0, keepdim=True)
face_shape[37082 + 609:37082 + 609 + 608] += r_eye_offset
face_shape = face_shape[None, ...]
return face_shape
def add_nonlinear_offset(self, face_shape, shape_offset_uv, UVs):
"""
Args:
face_shape: torch.tensor, size (1, N, 3)
shape_offset_uv: torch.tensor, size (1, h, w, 3)
UVs: torch.tensor, size (N, 2)
Returns:
"""
assert face_shape.shape[0] == 1 and shape_offset_uv.shape[0] == 1
face_shape = face_shape[0]
shape_offset_uv = shape_offset_uv[0]
h, w = shape_offset_uv.shape[:2]
UVs_coords = UVs.clone()
UVs_coords[:, 0] *= w
UVs_coords[:, 1] *= h
UVs_coords_int = torch.floor(UVs_coords)
UVs_coords_float = UVs_coords - UVs_coords_int
UVs_coords_int = UVs_coords_int.long()
shape_lt = shape_offset_uv[(h - 1 - UVs_coords_int[:, 1]).clamp(
0, h - 1), UVs_coords_int[:, 0].clamp(0, w - 1)] # (N, 3)
shape_lb = shape_offset_uv[(h - UVs_coords_int[:, 1]).clamp(0, h - 1),
UVs_coords_int[:, 0].clamp(0, w - 1)]
shape_rt = shape_offset_uv[(h - 1
- UVs_coords_int[:, 1]).clamp(0, h - 1),
(UVs_coords_int[:, 0] + 1).clamp(0, w - 1)]
shape_rb = shape_offset_uv[(h - UVs_coords_int[:, 1]).clamp(0, h - 1),
(UVs_coords_int[:, 0] + 1).clamp(0, w - 1)]
value_1 = shape_lt * (
1 - UVs_coords_float[:, :1]) * UVs_coords_float[:, 1:]
value_2 = shape_lb * (1 - UVs_coords_float[:, :1]) * (
1 - UVs_coords_float[:, 1:])
value_3 = shape_rt * UVs_coords_float[:, :1] * UVs_coords_float[:, 1:]
value_4 = shape_rb * UVs_coords_float[:, :1] * (
1 - UVs_coords_float[:, 1:])
offset_shape = value_1 + value_2 + value_3 + value_4 # (B, N, 3)
face_shape = (face_shape + offset_shape)[None, ...]
return face_shape, offset_shape[None, ...]
def compute_for_render_head_fitting(self,
coeffs,
shape_offset_uv,
texture_offset_uv,
shape_offset_uv_head,
texture_offset_uv_head,
UVs,
reverse_recenter=True,
get_eyes=False,
get_neck=False,
nose_coeff=0.0,
neck_coeff=0.0,
eyes_coeff=0.0):
if type(coeffs) == dict:
coef_dict = coeffs
elif type(coeffs) == torch.Tensor:
coef_dict = self.split_coeff(coeffs)
face_shape = self.compute_shape(
coef_dict['id'],
coef_dict['exp'],
nose_coeff=nose_coeff,
neck_coeff=neck_coeff,
eyes_coeff=eyes_coeff) # (1, n, 3)
if reverse_recenter:
face_shape_ori_noRecenter = self.reverse_recenter(
face_shape.clone())
else:
face_shape_ori_noRecenter = face_shape.clone()
face_vertex_ori = self.to_camera(face_shape_ori_noRecenter)
face_shape[:, :35241, :], shape_offset = self.add_nonlinear_offset(
face_shape[:, :35241, :], shape_offset_uv,
UVs[:35709, ...][self.bfm_keep_inds]) # (1, n, 3)
if get_eyes:
face_shape = self.add_nonlinear_offset_eyes(
face_shape, shape_offset)
if get_neck:
face_shape[:, 35241:37082, ...], _ = self.add_nonlinear_offset(
face_shape[:, 35241:37082, ...], shape_offset_uv_head,
UVs[35709:, ...]) # (1, n, 3)
else:
face_shape[:, self.ours_hair_area_inds,
...], _ = self.add_nonlinear_offset(
face_shape[:, self.ours_hair_area_inds,
...], shape_offset_uv_head,
UVs[self.ours_hair_area_inds + (35709 - 35241),
...]) # (1, n, 3)
if reverse_recenter:
face_shape_offset_noRecenter = self.reverse_recenter(
face_shape.clone())
else:
face_shape_offset_noRecenter = face_shape.clone()
face_vertex_offset = self.to_camera(face_shape_offset_noRecenter)
rotation = self.compute_rotation(coef_dict['angle'])
face_shape_transformed = self.transform(face_shape, rotation,
coef_dict['trans'])
face_vertex = self.to_camera(face_shape_transformed)
face_proj = self.to_image(face_vertex)
landmark = self.get_landmarks(face_proj)
face_texture = self.compute_texture(coef_dict['tex']) # (1, n, 3)
face_texture[:, :35241, :], texture_offset = self.add_nonlinear_offset(
face_texture[:, :35241, :], texture_offset_uv,
UVs[:35709, ...][self.bfm_keep_inds])
face_texture[:, 35241:37082, :], _ = self.add_nonlinear_offset(
face_texture[:, 35241:37082, :], texture_offset_uv_head,
UVs[35709:, ...])
face_norm = self.compute_norm(face_shape)
face_norm_roted = face_norm @ rotation
face_color = self.compute_color(face_texture, face_norm_roted,
coef_dict['gamma'])
return face_vertex, face_texture, face_color, landmark, face_vertex_ori, face_vertex_offset, face_proj
def compute_for_render_head(self,
coeffs,
shape_offset_uv,
texture_offset_uv,
shape_offset_uv_head,
texture_offset_uv_head,
UVs,
reverse_recenter=True,
nose_coeff=0.0,
neck_coeff=0.0,
eyes_coeff=0.0,
neckSlim_coeff=0.0,
neckStretch_coeff=0.0):
if type(coeffs) == dict:
coef_dict = coeffs
elif type(coeffs) == torch.Tensor:
coef_dict = self.split_coeff(coeffs)
face_shape = self.compute_shape(
coef_dict['id'],
coef_dict['exp'],
nose_coeff=nose_coeff,
neck_coeff=neck_coeff,
eyes_coeff=eyes_coeff,
neckSlim_coeff=neckSlim_coeff,
neckStretch_coeff=neckStretch_coeff) # (1, n, 3)
if reverse_recenter:
face_shape_ori_noRecenter = self.reverse_recenter(
face_shape.clone())
else:
face_shape_ori_noRecenter = face_shape.clone()
face_vertex_ori = self.to_camera(face_shape_ori_noRecenter)
face_shape[:, :35709, :], shape_offset = self.add_nonlinear_offset(
face_shape[:, :35709, :], shape_offset_uv, UVs[:35709,
...]) # (1, n, 3)
face_shape[:, 35709:,
...], _ = self.add_nonlinear_offset(face_shape[:, 35709:,
...],
shape_offset_uv_head,
UVs[35709:,
...]) # (1, n, 3)
if reverse_recenter:
face_shape_offset_noRecenter = self.reverse_recenter(
face_shape.clone())
else:
face_shape_offset_noRecenter = face_shape.clone()
face_vertex_offset = self.to_camera(face_shape_offset_noRecenter)
rotation = self.compute_rotation(coef_dict['angle'])
face_shape_transformed = self.transform(face_shape, rotation,
coef_dict['trans'])
face_vertex = self.to_camera(face_shape_transformed)
face_proj = self.to_image(face_vertex)
landmark = self.get_landmarks(face_proj)
face_texture = self.compute_texture(coef_dict['tex']) # (1, n, 3)
face_texture[:, :35709, :], texture_offset = self.add_nonlinear_offset(
face_texture[:, :35709, :], texture_offset_uv, UVs[:35709, ...])
face_texture[:, 35709:, :], _ = self.add_nonlinear_offset(
face_texture[:, 35709:, :], texture_offset_uv_head, UVs[35709:,
...])
face_norm = self.compute_norm(face_shape)
face_norm_roted = face_norm @ rotation
face_color = self.compute_color(face_texture, face_norm_roted,
coef_dict['gamma'])
return face_vertex, face_texture, face_color, landmark, face_vertex_ori, face_vertex_offset, face_proj
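For orientation, a standalone illustration of the coefficient layout implied by split_coeff above: 80 identity, 64 expression, 80 texture, 3 angle and 27 SH lighting coefficients, with the remainder used as translation (3 values if the full vector has 257 entries, which is assumed here).

import torch

coeffs = torch.randn(1, 257)  # assumed total: 80 + 64 + 80 + 3 + 27 + 3
parts = {
    'id': coeffs[:, :80],
    'exp': coeffs[:, 80:144],
    'tex': coeffs[:, 144:224],
    'angle': coeffs[:, 224:227],
    'gamma': coeffs[:, 227:254],
    'trans': coeffs[:, 254:],
}
print({k: tuple(v.shape) for k, v in parts.items()})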

View File

@@ -0,0 +1,196 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import json
import numpy as np
import tensorflow as tf
if tf.__version__ >= '2.0':
tf = tf.compat.v1
tf.disable_eager_execution()
class HeadSegmentor():
def __init__(self, model_root):
"""The HeadSegmentor is implemented based on https://arxiv.org/abs/2004.04955
Args:
model_root: the root directory of the model files
"""
self.sess = self.load_sess(
os.path.join(model_root, 'head_segmentation',
'Matting_headparser_6_18.pb'))
self.sess_detect = self.load_sess(
os.path.join(model_root, 'head_segmentation', 'face_detect.pb'))
self.sess_face = self.load_sess(
os.path.join(model_root, 'head_segmentation', 'segment_face.pb'))
def load_sess(self, model_path):
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
with tf.gfile.FastGFile(model_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
return sess
def process(self, image):
""" image: bgr
"""
h, w, c = image.shape
faceRects = self.detect_face(image)
face_num = len(faceRects)
all_head_alpha = []
all_face_mask = []
for i in range(face_num):
y1 = faceRects[i][0]
y2 = faceRects[i][1]
x1 = faceRects[i][2]
x2 = faceRects[i][3]
pad_y1, pad_y2, pad_x1, pad_x2 = self.pad_box(
y1, y2, x1, x2, 0.15, 0.15, 0.15, 0.15, h, w)
temp_img = image.copy()
roi_img = temp_img[pad_y1:pad_y2, pad_x1:pad_x2]
output_alpha = self.sess_face.run(
self.sess_face.graph.get_tensor_by_name('output_alpha_face:0'),
feed_dict={'input_image_face:0': roi_img[:, :, ::-1]})
face_mask = np.zeros((h, w, 3))
face_mask[pad_y1:pad_y2, pad_x1:pad_x2] = output_alpha
all_face_mask.append(face_mask)
cv2.imwrite(str(i) + 'face.jpg', face_mask)
cv2.imwrite(str(i) + 'face_roi.jpg', roi_img)
for i in range(face_num):
y1 = faceRects[i][0]
y2 = faceRects[i][1]
x1 = faceRects[i][2]
x2 = faceRects[i][3]
pad_y1, pad_y2, pad_x1, pad_x2 = self.pad_box(
y1, y2, x1, x2, 1.47, 1.47, 1.3, 2.0, h, w)
temp_img = image.copy()
for j in range(face_num):
y1 = faceRects[j][0]
y2 = faceRects[j][1]
x1 = faceRects[j][2]
x2 = faceRects[j][3]
small_y1, small_y2, small_x1, small_x2 = self.pad_box(
y1, y2, x1, x2, -0.1, -0.1, -0.1, -0.1, h, w)
small_width = small_x2 - small_x1
small_height = small_y2 - small_y1
if (small_x1 < 0 or small_y1 < 0 or small_width < 3
or small_height < 3 or small_x2 > w or small_y2 > h):
continue
# if(i!=j):
# temp_img[small_y1:small_y2,small_x1:small_x2]=0
if (i != j):
temp_img = temp_img * (1.0 - all_face_mask[j] / 255.0)
roi_img = temp_img[pad_y1:pad_y2, pad_x1:pad_x2]
output_alpha = self.sess.run(
self.sess.graph.get_tensor_by_name('output_alpha:0'),
feed_dict={'input_image:0': roi_img[:, :, ::-1]})
head_alpha = np.zeros((h, w))
head_alpha[pad_y1:pad_y2, pad_x1:pad_x2] = output_alpha[:, :, 0]
if np.sum(head_alpha) > 255 * w * h * 0.01 * 0.01:
all_head_alpha.append(head_alpha)
head_num = len(all_head_alpha)
head_elements = []
if head_num == 0:
return head_elements
for i in range(head_num):
head_alpha = all_head_alpha[i]
head_elements.append(head_alpha)
return head_elements
def pad_box(self, y1, y2, x1, x2, left_ratio, right_ratio, top_ratio,
bottom_ratio, h, w):
box_w = x2 - x1
box_h = y2 - y1
pad_y1 = np.maximum(np.int32(y1 - top_ratio * box_h), 0)
pad_y2 = np.minimum(np.int32(y2 + bottom_ratio * box_h), h - 1)
pad_x1 = np.maximum(np.int32(x1 - left_ratio * box_w), 0)
pad_x2 = np.minimum(np.int32(x2 + right_ratio * box_w), w - 1)
return pad_y1, pad_y2, pad_x1, pad_x2
def detect_face(self, img):
h, w, c = img.shape
input_img = cv2.resize(img[:, :, ::-1], (512, 512))
boxes, scores, num_detections = self.sess_detect.run(
[
self.sess_detect.graph.get_tensor_by_name('tower_0/boxes:0'),
self.sess_detect.graph.get_tensor_by_name('tower_0/scores:0'),
self.sess_detect.graph.get_tensor_by_name(
'tower_0/num_detections:0')
],
feed_dict={
'tower_0/images:0': input_img[np.newaxis],
'training_flag:0': False
})
faceRects = []
for i in range(num_detections[0]):
if scores[0, i] < 0.5:
continue
y1 = np.int32(boxes[0, i, 0] * h)
x1 = np.int32(boxes[0, i, 1] * w)
y2 = np.int32(boxes[0, i, 2] * h)
x2 = np.int32(boxes[0, i, 3] * w)
if x2 <= x1 + 3 or y2 <= y1 + 3:
continue
faceRects.append((y1, y2, x1, x2, y2 - y1, x2 - x1))
sorted(faceRects, key=lambda x: x[4] * x[5], reverse=True)
return faceRects
def generate_json(self, status_code, status_msg, ori_url, result_element,
track_id):
data = {}
data['originUri'] = ori_url
data['elements'] = result_element
data['statusCode'] = status_code
data['statusMessage'] = status_msg
data['requestId'] = track_id
return json.dumps(data)
def get_box(self, alpha):
h, w = alpha.shape
start_h = 0
end_h = 0
start_w = 0
end_w = 0
for i in range(0, h, 3):
line = alpha[i, :]
if np.max(line) >= 1:
start_h = i
break
for i in range(0, w, 3):
line = alpha[:, i]
if np.max(line) >= 1:
start_w = i
break
for i in range(0, h, 3):
i = h - 1 - i
line = alpha[i, :]
if np.max(line) >= 1:
end_h = i
if end_h < h - 1:
end_h = end_h + 1
break
for i in range(0, w, 3):
i = w - 1 - i
line = alpha[:, i]
if np.max(line) >= 1:
end_w = i
if end_w < w - 1:
end_w = end_w + 1
break
return start_h, start_w, end_h, end_w

View File

@@ -0,0 +1,564 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict
import cv2
import numpy as np
import torch
from modelscope.models import MODELS, TorchModel
from modelscope.models.cv.face_reconstruction.utils import (estimate_normals,
read_obj)
from . import networks, opt
from .bfm import ParametricFaceModel
from .losses import (BinaryDiceLoss, TVLoss, TVLoss_std, landmark_loss,
perceptual_loss, photo_loss, points_loss_horizontal,
reflectance_loss, reg_loss)
from .nv_diffrast import MeshRenderer
@MODELS.register_module('head-reconstruction', 'head_reconstruction')
class HeadReconModel(TorchModel):
def __init__(self, model_dir, *args, **kwargs):
"""The HeadReconModel is implemented based on HRN, publicly available at
https://github.com/youngLBW/HRN
Args:
model_dir: the root directory of the model files
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
opt.bfm_folder = os.path.join(model_dir, 'assets')
self.opt = opt
self.isTrain = opt.isTrain
self.visual_names = ['output_vis']
self.model_names = ['net_recon']
self.parallel_names = self.model_names + [
'renderer', 'renderer_fitting'
]
# networks
self.net_recon = networks.define_net_recon(
net_recon=opt.net_recon,
use_last_fc=opt.use_last_fc,
init_path=None)
# assets
self.headmodel = ParametricFaceModel(
assets_root=opt.bfm_folder,
camera_distance=opt.camera_d,
focal=opt.focal,
center=opt.center,
is_train=self.isTrain,
default_name='ourRefineBFMEye0504_model.mat')
self.headmodel_for_fitting = ParametricFaceModel(
assets_root=opt.bfm_folder,
camera_distance=opt.camera_d,
focal=opt.focal,
center=opt.center,
is_train=self.isTrain,
default_name='ourRefineFull_model.mat')
# renderer
fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
self.renderer = MeshRenderer(
rasterize_fov=fov,
znear=opt.z_near,
zfar=opt.z_far,
rasterize_size=int(2 * opt.center))
self.renderer_fitting = MeshRenderer(
rasterize_fov=fov,
znear=opt.z_near,
zfar=opt.z_far,
rasterize_size=int(2 * opt.center))
template_obj_path = os.path.join(
model_dir,
'assets/3dmm/template_mesh/template_ourFull_bfmEyes.obj')
self.template_output_mesh = read_obj(template_obj_path)
self.nonlinear_UVs = self.template_output_mesh['uvs']
self.nonlinear_UVs = torch.from_numpy(self.nonlinear_UVs)
self.jaw_edge_mask = cv2.imread(
os.path.join(model_dir,
'assets/texture/jaw_edge_mask2.png'))[..., 0].astype(
np.float32) / 255.0
self.jaw_edge_mask = cv2.resize(self.jaw_edge_mask, (300, 300))[..., None]
self.input_imgs = []
self.input_img_hds = []
self.input_fat_img_hds = []
self.atten_masks = []
self.gt_lms = []
self.gt_lm_hds = []
self.trans_ms = []
self.img_names = []
self.face_masks = []
self.head_masks = []
self.input_imgs_coeff = []
self.gt_lms_coeff = []
self.loss_names = [
'all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc'
]
self.compute_feat_loss = perceptual_loss
self.comupte_color_loss = photo_loss
self.compute_lm_loss = landmark_loss
self.compute_reg_loss = reg_loss
self.compute_reflc_loss = reflectance_loss
if opt.isTrain:
self.optimizer = torch.optim.Adam(
self.net_recon.parameters(), lr=opt.lr)
self.optimizers = [self.optimizer]
self.parallel_names += ['net_recog']
def set_device(self, device):
self.device = device
self.net_recon = self.net_recon.to(self.device)
self.headmodel.to(self.device)
self.headmodel_for_fitting.to(self.device)
self.nonlinear_UVs = self.nonlinear_UVs.to(self.device)
def load_networks(self, load_path):
state_dict = torch.load(load_path, map_location=self.device)
print('loading the model from %s' % load_path)
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
net.load_state_dict(state_dict[name], strict=False)
def setup(self, checkpoint_path):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.load_networks(checkpoint_path)
def parallelize(self, convert_sync_batchnorm=True):
if not self.opt.use_ddp:
for name in self.parallel_names:
if isinstance(name, str):
module = getattr(self, name)
setattr(self, name, module.to(self.device))
else:
for name in self.model_names:
if isinstance(name, str):
module = getattr(self, name)
if convert_sync_batchnorm:
module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
module)
setattr(
self, name,
torch.nn.parallel.DistributedDataParallel(
module.to(self.device),
device_ids=[self.device.index],
find_unused_parameters=True,
broadcast_buffers=True))
# DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
for name in self.parallel_names:
if isinstance(name, str) and name not in self.model_names:
module = getattr(self, name)
setattr(self, name, module.to(self.device))
# put state_dict of optimizer to gpu device
if self.opt.phase != 'test':
if self.opt.continue_train:
for optim in self.optimizers:
for state in optim.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
def eval(self):
"""Make models eval mode"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, name)
net.eval()
def set_render(self, image_res=1024):
fov = 2 * np.arctan(self.opt.center / self.opt.focal) * 180 / np.pi
if image_res is None:
image_res = int(2 * self.opt.center)
self.renderer = MeshRenderer(
rasterize_fov=fov,
znear=self.opt.z_near,
zfar=self.opt.z_far,
rasterize_size=image_res)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
"""
self.input_img = input['imgs'].to(self.device)
self.input_img_hd = input['imgs_hd'].to(
self.device) if 'imgs_hd' in input else None
if 'imgs_fat_hd' not in input or input['imgs_fat_hd'] is None:
self.input_fat_img_hd = self.input_img_hd
else:
self.input_fat_img_hd = input['imgs_fat_hd'].to(self.device)
self.atten_mask = input['msks'].to(
self.device) if 'msks' in input else None
self.gt_lm = input['lms'].to(self.device) if 'lms' in input else None
self.gt_lm_hd = input['lms_hd'].to(
self.device) if 'lms_hd' in input else None
self.trans_m = input['M'].to(self.device) if 'M' in input else None
self.image_paths = input['im_paths'] if 'im_paths' in input else None
self.img_name = input['img_name'] if 'img_name' in input else None
self.face_mask = input['face_mask'].to(
self.device) if 'face_mask' in input else None
self.head_mask = input['head_mask'].to(
self.device) if 'head_mask' in input else None
self.gt_normals = input['normals'].to(
self.device) if 'normals' in input else None
self.input_img_coeff = input['imgs_coeff'].to(
self.device) if 'imgs_coeff' in input else None
self.gt_lm_coeff = input['lms_coeff'].to(
self.device) if 'lms_coeff' in input else None
def check_head_pose(self, coeffs):
pi = 3.14
if coeffs[0, 225] > pi / 6 or coeffs[0, 225] < -pi / 6:
return False
elif coeffs[0, 224] > pi / 6 or coeffs[0, 224] < -pi / 6:
return False
elif coeffs[0, 226] > pi / 6 or coeffs[0, 226] < -pi / 6:
return False
else:
return True
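# Note: assuming the Deep3DFaceRecon-style 257-dim coefficient layout produced by
# ReconNetWrapper (80 id + 64 exp + 80 tex + 3 angle + 27 gamma + 3 trans), indices
# 224-226 hold the three rotation angles, so the check above rejects heads rotated
# by more than pi / 6 (30 degrees) around any axis.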
def get_fusion_mask(self, keep_forehead=True):
self.without_forehead_inds = torch.from_numpy(
np.load(
os.path.join(self.model_dir,
'assets/3dmm/inds/bfm_withou_forehead_inds.npy'))
).long().to(self.device)
h, w = self.shape_offset_uv.shape[1:3]
self.fusion_mask = torch.zeros((h, w)).to(self.device).float()
if keep_forehead:
UVs_coords = self.nonlinear_UVs.clone()[:35709][
self.without_forehead_inds]
else:
UVs_coords = self.nonlinear_UVs.clone()[:35709]
UVs_coords[:, 0] *= w
UVs_coords[:, 1] *= h
UVs_coords_int = torch.floor(UVs_coords)
UVs_coords_int = UVs_coords_int.long()
self.fusion_mask[h - 1 - UVs_coords_int[:, 1], UVs_coords_int[:, 0]] = 1
# blur mask
self.fusion_mask = self.fusion_mask.cpu().numpy()
new_kernel1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
new_kernel2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
self.fusion_mask = cv2.dilate(self.fusion_mask, new_kernel1, 1)
self.fusion_mask = cv2.erode(self.fusion_mask, new_kernel2, 1)
self.fusion_mask = cv2.blur(self.fusion_mask, (17, 17))
self.fusion_mask = torch.from_numpy(self.fusion_mask).float().to(
self.device)
def get_edge_mask(self):
h, w = self.shape_offset_uv.shape[1:3]
self.edge_mask = torch.zeros((h, w)).to(self.device).float()
UVs_coords = self.nonlinear_UVs.clone()[self.edge_points_inds]
UVs_coords[:, 0] *= w
UVs_coords[:, 1] *= h
UVs_coords_int = torch.floor(UVs_coords)
UVs_coords_int = UVs_coords_int.long()
self.edge_mask[h - 1 - UVs_coords_int[:, 1], UVs_coords_int[:, 0]] = 1
# blur mask
self.edge_mask = self.edge_mask.cpu().numpy()
new_kernel1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
self.edge_mask = cv2.dilate(self.edge_mask, new_kernel1, 1)
self.edge_mask = cv2.blur(self.edge_mask, (5, 5))
self.edge_mask = torch.from_numpy(self.edge_mask).float().to(
self.device)
def blur_shape_offset_uv(self, global_blur=False, blur_size=3):
if self.edge_mask is not None:
shape_offset_uv_blur = self.shape_offset_uv[0].detach().cpu(
).numpy()
shape_offset_uv_blur = cv2.blur(shape_offset_uv_blur, (15, 15))
shape_offset_uv_blur = torch.from_numpy(
shape_offset_uv_blur).float().to(self.device)[None, ...]
self.shape_offset_uv = shape_offset_uv_blur * self.edge_mask[
None, ..., None] + self.shape_offset_uv * (
1 - self.edge_mask[None, ..., None])
self.shape_offset_uv = self.shape_offset_uv * self.fusion_mask[None, ..., None]
if global_blur and blur_size > 0:
shape_offset_uv_blur = self.shape_offset_uv[0].detach().cpu(
).numpy()
shape_offset_uv_blur = cv2.blur(shape_offset_uv_blur,
(blur_size, blur_size))
shape_offset_uv_blur = torch.from_numpy(
shape_offset_uv_blur).float().to(self.device)[None, ...]
self.shape_offset_uv = shape_offset_uv_blur
def blur_offset_edge(self):
shape_offset_uv = self.shape_offset_uv[0].detach().cpu().numpy()
shape_offset_uv_head = self.shape_offset_uv_head[0].detach().cpu(
).numpy()
shape_offset_uv_head = cv2.resize(shape_offset_uv_head, (300, 300))
shape_offset_uv_head = shape_offset_uv_head * (
1 - self.jaw_edge_mask) + shape_offset_uv * self.jaw_edge_mask
shape_offset_uv_head = cv2.resize(shape_offset_uv_head, (100, 100))
self.shape_offset_uv_head = torch.from_numpy(
shape_offset_uv_head).float().to(self.device)[None, ...]
def fitting_nonlinear(self, coeff, n_iters=250):
output_coeff = coeff.detach().clone()
output_coeff = self.headmodel_for_fitting.split_coeff(output_coeff)
output_coeff['id'].requires_grad = True
output_coeff['exp'].requires_grad = True
output_coeff['tex'].requires_grad = True
output_coeff['angle'].requires_grad = True
output_coeff['gamma'].requires_grad = True
output_coeff['trans'].requires_grad = True
self.shape_offset_uv = torch.zeros((1, 300, 300, 3),
dtype=torch.float32).to(self.device)
self.shape_offset_uv.requires_grad = True
self.texture_offset_uv = torch.zeros(
(1, 300, 300, 3), dtype=torch.float32).to(self.device)
self.texture_offset_uv.requires_grad = True
self.shape_offset_uv_head = torch.zeros(
(1, 100, 100, 3), dtype=torch.float32).to(self.device)
self.shape_offset_uv_head.requires_grad = True
self.texture_offset_uv_head = torch.zeros(
(1, 100, 100, 3), dtype=torch.float32).to(self.device)
self.texture_offset_uv_head.requires_grad = True
head_face_inds = np.load(
os.path.join(self.model_dir,
'assets/3dmm/inds/ours_head_face_inds.npy'))
head_face_inds = torch.from_numpy(head_face_inds).to(self.device)
head_faces = self.headmodel_for_fitting.face_buf[head_face_inds]
# print('before fitting', output_coeff)
opt_parameters = [
self.shape_offset_uv, self.texture_offset_uv,
self.shape_offset_uv_head, self.texture_offset_uv_head,
output_coeff['id'], output_coeff['exp'], output_coeff['tex'],
output_coeff['gamma']
]
optim = torch.optim.Adam(opt_parameters, lr=1e-3)
optim_pose = torch.optim.Adam([output_coeff['trans']], lr=1e-1)
self.get_edge_points_horizontal()
for i in range(n_iters): # 500
self.pred_vertex_head, self.pred_tex, self.pred_color_head, self.pred_lm, face_shape, \
face_shape_offset, self.verts_proj_head = \
self.headmodel_for_fitting.compute_for_render_head_fitting(output_coeff, self.shape_offset_uv,
self.texture_offset_uv,
self.shape_offset_uv_head,
self.texture_offset_uv_head,
self.nonlinear_UVs)
self.pred_vertex = self.pred_vertex_head[:, :35241]
self.pred_color = self.pred_color_head[:, :35241]
self.verts_proj = self.verts_proj_head[:, :35241]
self.pred_mask_head, _, self.pred_head, self.occ_head = self.renderer_fitting(
self.pred_vertex_head, head_faces, feat=self.pred_color_head)
self.pred_mask, _, self.pred_face, self.occ_face = self.renderer_fitting(
self.pred_vertex,
self.headmodel_for_fitting.face_buf[:69732],
feat=self.pred_color)
self.pred_coeffs_dict = self.headmodel_for_fitting.split_coeff(
output_coeff)
self.compute_losses_fitting()
if i < 150:
optim_pose.zero_grad()
(self.loss_lm + self.loss_color * 0.1).backward()
optim_pose.step()
else:
optim.zero_grad()
self.loss_all.backward()
optim.step()
output_coeff = self.headmodel_for_fitting.merge_coeff(output_coeff)
self.get_edge_mask()
self.get_fusion_mask(keep_forehead=False)
self.blur_shape_offset_uv(global_blur=True)
self.blur_offset_edge()
return output_coeff
def forward(self):
with torch.no_grad():
output_coeff = self.net_recon(self.input_img_coeff)
if not self.check_head_pose(output_coeff):
return None
with torch.enable_grad():
output_coeff = self.fitting_nonlinear(output_coeff)
output_coeff = self.headmodel.split_coeff(output_coeff)
eye_coeffs = output_coeff['exp'][0, 16] + output_coeff['exp'][
0, 17] + output_coeff['exp'][0, 19]
if eye_coeffs > 1.0:
degree = 0.5
else:
degree = 1.0
# degree = 0.5
output_coeff['exp'][0, 16] += 1 * degree
output_coeff['exp'][0, 17] += 1 * degree
output_coeff['exp'][0, 19] += 1.5 * degree
output_coeff = self.headmodel.merge_coeff(output_coeff)
self.pred_vertex, _, _, _, face_shape_ori, face_shape, _ = \
self.headmodel.compute_for_render_head(output_coeff,
self.shape_offset_uv.detach(),
self.texture_offset_uv.detach(),
self.shape_offset_uv_head.detach() * 0,
self.texture_offset_uv_head.detach(),
self.nonlinear_UVs,
nose_coeff=0.1,
neck_coeff=0.3,
neckSlim_coeff=0.5,
neckStretch_coeff=0.5)
UVs = np.array(self.template_output_mesh['uvs'])
UVs_tensor = torch.tensor(UVs, dtype=torch.float32)
UVs_tensor = torch.unsqueeze(UVs_tensor, 0).to(self.pred_vertex.device)
target_img = self.input_fat_img_hd
target_img = target_img.permute(0, 2, 3, 1)
face_buf = self.headmodel.face_buf
# get texture map
with torch.enable_grad():
pred_mask, _, pred_face, texture_map, texture_mask = self.renderer.pred_shape_and_texture(
self.pred_vertex, face_buf, UVs_tensor, target_img, None)
self.pred_coeffs_dict = self.headmodel.split_coeff(output_coeff)
recon_shape = face_shape # get reconstructed shape, [1, 35709, 3]
recon_shape[
...,
-1] = 10 - recon_shape[..., -1] # from camera space to world space
recon_shape = recon_shape.cpu().numpy()[0]
tri = self.headmodel.face_buf.cpu().numpy()
output = {}
output['flag'] = 0
output['tex_map'] = texture_map
output['tex_mask'] = texture_mask * 255.0
'''
coeffs
{
'id': id_coeffs,
'exp': exp_coeffs,
'tex': tex_coeffs,
'angle': angles,
'gamma': gammas,
'trans': translations
}
'''
output['coeffs'] = self.pred_coeffs_dict
normals = estimate_normals(recon_shape, tri)
output['vertices'] = recon_shape
output['triangles'] = tri
output['uvs'] = UVs
output['faces_uv'] = self.template_output_mesh['faces_uv']
output['normals'] = normals
return output
def get_edge_points_horizontal(self):
left_points = []
right_points = []
for i in range(self.face_mask.shape[2]):
inds = torch.where(self.face_mask[0, 0, i, :] > 0.5) # 0.9
if len(inds[0]) > 0: # i > 112 and len(inds[0]) > 0
left_points.append(int(inds[0][0]) + 1)
right_points.append(int(inds[0][-1]))
else:
left_points.append(0)
right_points.append(self.face_mask.shape[3] - 1)
self.left_points = torch.tensor(left_points).long().to(self.device)
self.right_points = torch.tensor(right_points).long().to(self.device)
def compute_losses_fitting(self):
face_mask = self.pred_mask
face_mask = face_mask.detach()
self.loss_color = self.opt.w_color * self.comupte_color_loss(
self.pred_face, self.input_img, face_mask) # 1.0
loss_reg, loss_gamma = self.compute_reg_loss(
self.pred_coeffs_dict,
w_id=self.opt.w_id,
w_exp=self.opt.w_exp,
w_tex=self.opt.w_tex)
self.loss_reg = self.opt.w_reg * loss_reg # 1.0
self.loss_gamma = self.opt.w_gamma * loss_gamma # 1.0
self.loss_lm = self.opt.w_lm * self.compute_lm_loss(
self.pred_lm, self.gt_lm) * 0.1 # 0.1
self.loss_smooth_offset = TVLoss()(self.shape_offset_uv.permute(
0, 3, 1, 2)) * 10000 # 10000
self.loss_reg_textureOff = torch.mean(
torch.abs(self.texture_offset_uv)) * 10 # 10
self.loss_smooth_offset_std = TVLoss_std()(
self.shape_offset_uv.permute(0, 3, 1, 2)) * 50000 # 50000
self.loss_points_horizontal, self.edge_points_inds = points_loss_horizontal(
self.verts_proj, self.left_points, self.right_points) # 20
self.loss_points_horizontal *= 20
self.loss_all = self.loss_color + self.loss_lm + self.loss_reg + self.loss_gamma
self.loss_all += self.loss_smooth_offset + self.loss_smooth_offset_std + self.loss_reg_textureOff
self.loss_all += self.loss_points_horizontal
head_mask = self.pred_mask_head
head_mask = head_mask.detach()
self.loss_color_head = self.opt.w_color * self.comupte_color_loss(
self.pred_head, self.input_img, head_mask) # 1.0
self.loss_smooth_offset_head = TVLoss()(
self.shape_offset_uv_head.permute(0, 3, 1, 2)) * 100 # 10000
self.loss_smooth_offset_std_head = TVLoss_std()(
self.shape_offset_uv_head.permute(0, 3, 1, 2)) * 500 # 50000
self.loss_mask = BinaryDiceLoss()(self.occ_head, self.head_mask) * 20
self.loss_all += self.loss_mask + self.loss_color_head
self.loss_all += self.loss_smooth_offset_head + self.loss_smooth_offset_std_head

View File

@@ -0,0 +1,367 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from kornia.geometry import warp_affine
def resize_n_crop(image, M, dsize=112):
# image: (b, c, h, w)
# M : (b, 2, 3)
return warp_affine(image, M, dsize=(dsize, dsize))
# perceptual level loss
class PerceptualLoss(nn.Module):
def __init__(self, recog_net, input_size=112):
super(PerceptualLoss, self).__init__()
self.recog_net = recog_net
self.preprocess = lambda x: 2 * x - 1
self.input_size = input_size
def forward(self, imageA, imageB, M):
"""
1 - cosine distance
Parameters:
imageA --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
imageB --same as imageA
"""
imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
# freeze bn
self.recog_net.eval()
id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)
cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
return torch.sum(1 - cosine_d) / cosine_d.shape[0]
def perceptual_loss(id_featureA, id_featureB):
cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
return torch.sum(1 - cosine_d) / cosine_d.shape[0]
# image level loss
def photo_loss(imageA, imageB, mask, eps=1e-6):
"""
l2 norm (with sqrt; eps is added for backward stability, otherwise NaN may occur)
Parameters:
imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
imageB --same as imageA
"""
loss = torch.sqrt(eps + torch.sum(
(imageA - imageB)**2, dim=1, keepdims=True)) * mask
loss = torch.sum(loss) / torch.max(
torch.sum(mask),
torch.tensor(1.0).to(mask.device))
return loss
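# Illustrative usage sketch (assumed shapes, not taken from the callers):
#   imageA = torch.rand(1, 3, 224, 224)   # prediction, range (0, 1)
#   imageB = torch.rand(1, 3, 224, 224)   # target
#   mask = torch.ones(1, 1, 224, 224)     # 1 where the loss should be counted
#   loss = photo_loss(imageA, imageB, mask)   # average masked per-pixel L2 distance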
def landmark_loss(predict_lm, gt_lm, weight=None):
"""
weighted mse loss
Parameters:
predict_lm --torch.tensor (B, 68, 2)
gt_lm --torch.tensor (B, 68, 2)
weight --numpy.array (1, 68)
"""
if not weight:
weight = np.ones([68])
weight[28:31] = 20
weight[-8:] = 20
weight = np.expand_dims(weight, 0)
weight = torch.tensor(weight).to(predict_lm.device)
loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
return loss
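# In effect this is a weighted mean squared error over the 68 landmarks:
#   loss = (1 / (B * 68)) * sum_{b,i} w_i * ||pred_{b,i} - gt_{b,i}||^2
# with w_i = 20 for indices 28:31 and the last 8 points (in the usual 68-point
# convention these are nose-bridge and inner-mouth landmarks) and w_i = 1 elsewhere.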
# regularization
def reg_loss(coeffs_dict, w_id=1, w_exp=1, w_tex=1):
"""
l2 norm without the sqrt, from yu's implementation (mse)
tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
Parameters:
coeffs_dict -- a dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
"""
# coefficient regularization to ensure plausible 3d faces
value_1 = w_id * torch.sum(coeffs_dict['id']**2)
value_2 = w_exp * torch.sum(coeffs_dict['exp']**2)
value_3 = w_tex * torch.sum(coeffs_dict['tex']**2)
creg_loss = value_1 + value_2 + value_3
creg_loss = creg_loss / coeffs_dict['id'].shape[0]
# gamma regularization to ensure a nearly-monochromatic light
gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
gamma_loss = torch.mean((gamma - gamma_mean)**2)
return creg_loss, gamma_loss
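# Written out, the coefficient term is
#   creg = (w_id * ||id||^2 + w_exp * ||exp||^2 + w_tex * ||tex||^2) / B
# while the gamma term penalizes the deviation of each color channel's (3, 9) lighting
# coefficients from their mean across channels, encouraging nearly monochromatic lighting.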
def reflectance_loss(texture, mask):
"""
minimize texture variance (mse); albedo regularization to ensure a uniform skin albedo
Parameters:
texture --torch.tensor, (B, N, 3)
mask --torch.tensor, (N), 1 or 0
"""
mask = mask.reshape([1, mask.shape[0], 1])
texture_mean = torch.sum(
mask * texture, dim=1, keepdims=True) / torch.sum(mask)
loss = torch.sum(((texture - texture_mean) * mask)**2) / (
texture.shape[0] * torch.sum(mask))
return loss
def lm_3d_loss(pred_lm_3d, gt_lm_3d, mask):
loss = torch.abs(pred_lm_3d - gt_lm_3d)[mask, :]
loss = torch.mean(loss)
return loss
class TVLoss(nn.Module):
def __init__(self, TVLoss_weight=1):
super(TVLoss, self).__init__()
self.TVLoss_weight = TVLoss_weight
def forward(self, x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = self._tensor_size(x[:, :, 1:, :])
count_w = self._tensor_size(x[:, :, :, 1:])
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
return self.TVLoss_weight * 2 * (h_tv / count_h
+ w_tv / count_w) / batch_size
def _tensor_size(self, t):
return t.size()[1] * t.size()[2] * t.size()[3]
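# Illustrative usage sketch (assumed shapes): TVLoss expects an NCHW tensor, e.g. the UV
# offset maps in the head model are passed as shape_offset_uv.permute(0, 3, 1, 2):
#   tv = TVLoss()(torch.rand(1, 3, 300, 300))
# It penalizes squared differences between neighbouring pixels (normalized by element
# count and batch size), so smoother offset maps give smaller values.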
class TVLoss_std(nn.Module):
def __init__(self, TVLoss_weight=1):
super(TVLoss_std, self).__init__()
self.TVLoss_weight = TVLoss_weight
def forward(self, x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2)
h_tv = ((h_tv - torch.mean(h_tv))**2).sum()
w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2)
w_tv = ((w_tv - torch.mean(w_tv))**2).sum()
return self.TVLoss_weight * 2 * (h_tv + w_tv) / batch_size
def _tensor_size(self, t):
return t.size()[1] * t.size()[2] * t.size()[3]
def photo_loss_sum(imageA, imageB, mask, eps=1e-6):
"""
l2 norm (with sqrt; eps is added for backward stability, otherwise NaN may occur)
Parameters:
imageA --torch.tensor (B, 3, H, W), range (0, 1), RGB order
imageB --same as imageA
"""
loss = torch.sqrt(eps + torch.sum(
(imageA - imageB)**2, dim=1, keepdims=True)) * mask
loss = torch.sum(loss) / (
imageA.shape[0] * imageA.shape[2] * imageA.shape[3])
return loss
def points_loss_horizontal(verts, left_points, right_points, width=224):
verts_int = torch.ceil(verts[0]).long().clamp(0, width - 1) # (n, 2)
verts_left = left_points[width - 1 - verts_int[:, 1]].float()
verts_right = right_points[width - 1 - verts_int[:, 1]].float()
verts_x = verts[0, :, 0]
dist = (verts_left - verts_x) / width * (verts_right - verts_x) / width
dist /= torch.max(
torch.abs((verts_left - verts_x) / width),
torch.abs((verts_right - verts_x) / width))
edge_inds = torch.where(dist > 0)[0]
dist += 0.01
dist = torch.nn.functional.relu(dist).clone()
dist -= 0.01
dist = torch.abs(dist)
loss = torch.mean(dist)
return loss, edge_inds
class LaplacianLoss(nn.Module):
def __init__(self):
super(LaplacianLoss, self).__init__()
def forward(self, x):
batch_size, slice_num = x.size()[:2]
z_x = x.size()[2]
h_x = x.size()[3]
w_x = x.size()[4]
count_z = self._tensor_size(x[:, :, 1:, :, :])
count_h = self._tensor_size(x[:, :, :, 1:, :])
count_w = self._tensor_size(x[:, :, :, :, 1:])
z_tv = torch.pow((x[:, :, 1:, :, :] - x[:, :, :z_x - 1, :, :]),
2).sum()
h_tv = torch.pow((x[:, :, :, 1:, :] - x[:, :, :, :h_x - 1, :]),
2).sum()
w_tv = torch.pow((x[:, :, :, :, 1:] - x[:, :, :, :, :w_x - 1]),
2).sum()
return 2 * (z_tv / count_z + h_tv / count_h + w_tv / count_w) / (
batch_size * slice_num)
def _tensor_size(self, t):
return t.size()[2] * t.size()[3] * t.size()[4]
class LaplacianLoss_L1(nn.Module):
def __init__(self):
super(LaplacianLoss_L1, self).__init__()
def forward(self, x):
batch_size, slice_num = x.size()[:2]
z_x = x.size()[2]
h_x = x.size()[3]
w_x = x.size()[4]
count_z = self._tensor_size(x[:, :, 1:, :, :])
count_h = self._tensor_size(x[:, :, :, 1:, :])
count_w = self._tensor_size(x[:, :, :, :, 1:])
z_tv = torch.abs((x[:, :, 1:, :, :] - x[:, :, :z_x - 1, :, :])).sum()
h_tv = torch.abs((x[:, :, :, 1:, :] - x[:, :, :, :h_x - 1, :])).sum()
w_tv = torch.abs((x[:, :, :, :, 1:] - x[:, :, :, :, :w_x - 1])).sum()
return 2 * (z_tv / count_z + h_tv / count_h + w_tv / count_w) / (
batch_size * slice_num)
def _tensor_size(self, t):
return t.size()[2] * t.size()[3] * t.size()[4]
class GANLoss(nn.Module):
def __init__(self,
gan_mode,
target_real_label=1.0,
target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_tensor = None
self.fake_label_tensor = None
self.zero_tensor = None
self.Tensor = tensor
self.gan_mode = gan_mode
if gan_mode == 'ls':
pass
elif gan_mode == 'original':
pass
elif gan_mode == 'w':
pass
elif gan_mode == 'hinge':
pass
else:
raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
def get_target_tensor(self, input, target_is_real):
if target_is_real:
if self.real_label_tensor is None:
self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
self.real_label_tensor.requires_grad_(False)
return self.real_label_tensor.expand_as(input)
else:
if self.fake_label_tensor is None:
self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
self.fake_label_tensor.requires_grad_(False)
return self.fake_label_tensor.expand_as(input)
def get_zero_tensor(self, input):
if self.zero_tensor is None:
self.zero_tensor = self.Tensor(1).fill_(0)
self.zero_tensor.requires_grad_(False)
return self.zero_tensor.expand_as(input)
def loss(self, input, target_is_real, for_discriminator=True):
if self.gan_mode == 'original': # cross entropy loss
target_tensor = self.get_target_tensor(input, target_is_real)
loss = F.binary_cross_entropy_with_logits(input, target_tensor)
return loss
elif self.gan_mode == 'ls':
target_tensor = self.get_target_tensor(input, target_is_real)
return F.mse_loss(input, target_tensor)
elif self.gan_mode == 'hinge':
if for_discriminator:
if target_is_real:
minval = torch.min(input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
minval = torch.min(-input - 1, self.get_zero_tensor(input))
loss = -torch.mean(minval)
else:
assert target_is_real, "The generator's hinge loss must be aiming for real"
loss = -torch.mean(input)
return loss
else:
# wgan
if target_is_real:
return -input.mean()
else:
return input.mean()
def __call__(self, input, target_is_real, for_discriminator=True):
# computing the loss is a bit involved because |input| may not be a single
# tensor but a list of tensors (multiscale discriminator)
if isinstance(input, list):
loss = 0
for pred_i in input:
if isinstance(pred_i, list):
pred_i = pred_i[-1]
loss_tensor = self.loss(pred_i, target_is_real,
for_discriminator)
bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
loss += new_loss
return loss / len(input)
else:
return self.loss(input, target_is_real, for_discriminator)
class BinaryDiceLoss(nn.Module):
def __init__(self, smooth=1, p=1, reduction='mean'):
super(BinaryDiceLoss, self).__init__()
self.smooth = smooth
self.p = p
self.reduction = reduction
def forward(self, predict, target):
assert predict.shape[0] == target.shape[
0], "predict & target batch size don't match"
predict = predict.contiguous().view(predict.shape[0], -1)
target = target.contiguous().view(target.shape[0], -1)
num = torch.sum(torch.mul(predict, target), dim=1)
den = torch.sum(predict + target, dim=1)
loss = 1 - (2 * num + self.smooth) / (den + self.smooth)
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
elif self.reduction == 'none':
return loss
else:
raise Exception('Unexpected reduction {}'.format(self.reduction))
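# Illustrative usage sketch (assumed shapes): both inputs are flattened per sample, so any
# (B, ...) soft masks in [0, 1] work, e.g. a rendered occupancy map against a GT head mask:
#   dice = BinaryDiceLoss()(torch.rand(1, 1, 224, 224), torch.ones(1, 1, 224, 224))
# Perfect overlap gives a loss near 0; disjoint masks give a loss near 1.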

View File

@@ -0,0 +1,577 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch
import os
from typing import Any, Callable, List, Optional, Type, Union
import torch
import torch.nn as nn
from kornia.geometry import warp_affine
from torch import Tensor
from torch.optim import lr_scheduler
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
def resize_n_crop(image, M, dsize=112):
# image: (b, c, h, w)
# M : (b, 2, 3)
return warp_affine(image, M, dsize=(dsize, dsize))
def filter_state_dict(state_dict, remove_name='fc'):
new_state_dict = {}
for key in state_dict:
if remove_name in key:
continue
new_state_dict[key] = state_dict[key]
return new_state_dict
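# Illustrative usage sketch (the checkpoint path is hypothetical): drop the classifier
# head before loading an ImageNet backbone checkpoint, mirroring ReconNetWrapper below:
#   state_dict = filter_state_dict(torch.load('resnet50_backbone.pth', map_location='cpu'), remove_name='fc')
#   backbone.load_state_dict(state_dict)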
def define_net_recon(net_recon, use_last_fc=False, init_path=None):
return ReconNetWrapper(
net_recon, use_last_fc=use_last_fc, init_path=init_path)
def define_net_recon2(net_recon, use_last_fc=False, init_path=None):
return ReconNetWrapper2(
net_recon, use_last_fc=use_last_fc, init_path=init_path)
class ReconNetWrapper(nn.Module):
fc_dim = 257
def __init__(self, net_recon, use_last_fc=False, init_path=None):
super(ReconNetWrapper, self).__init__()
self.use_last_fc = use_last_fc
if net_recon not in func_dict:
raise NotImplementedError('network [%s] is not implemented' % net_recon)
func, last_dim = func_dict[net_recon]
backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
if init_path and os.path.isfile(init_path):
state_dict = filter_state_dict(
torch.load(init_path, map_location='cpu'))
backbone.load_state_dict(state_dict)
print('loading init net_recon %s from %s' % (net_recon, init_path))
self.backbone = backbone
if not use_last_fc:
self.final_layers = nn.ModuleList([
conv1x1(last_dim, 80, bias=True), # id layer
conv1x1(last_dim, 64, bias=True), # exp layer
conv1x1(last_dim, 80, bias=True), # tex layer
conv1x1(last_dim, 3, bias=True), # angle layer
conv1x1(last_dim, 27, bias=True), # gamma layer
conv1x1(last_dim, 2, bias=True), # tx, ty
conv1x1(last_dim, 1, bias=True) # tz
])
for m in self.final_layers:
nn.init.constant_(m.weight, 0.)
nn.init.constant_(m.bias, 0.)
def forward(self, x):
x = self.backbone(x)
if not self.use_last_fc:
output = []
for layer in self.final_layers:
output.append(layer(x))
x = torch.flatten(torch.cat(output, dim=1), 1)
return x
class ReconNetWrapper2(nn.Module):
fc_dim = 264
def __init__(self, net_recon, use_last_fc=False, init_path=None):
super(ReconNetWrapper2, self).__init__()
self.use_last_fc = use_last_fc
if net_recon not in func_dict:
raise NotImplementedError('network [%s] is not implemented' % net_recon)
func, last_dim = func_dict[net_recon]
backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
if init_path and os.path.isfile(init_path):
state_dict = filter_state_dict(
torch.load(init_path, map_location='cpu'))
backbone.load_state_dict(state_dict)
print('loading init net_recon %s from %s' % (net_recon, init_path))
self.backbone = backbone
if not use_last_fc:
self.final_layers2 = nn.ModuleList([
conv1x1(last_dim, 80, bias=True), # id layer
conv1x1(last_dim, 51, bias=True), # exp layer
conv1x1(last_dim, 100, bias=True), # tex layer
conv1x1(last_dim, 3, bias=True), # angle layer
conv1x1(last_dim, 27, bias=True), # gamma layer
conv1x1(last_dim, 2, bias=True), # tx, ty
conv1x1(last_dim, 1, bias=True) # tz
])
for m in self.final_layers2:
nn.init.constant_(m.weight, 0.)
nn.init.constant_(m.bias, 0.)
def forward(self, x):
x = self.backbone(x)
if not self.use_last_fc:
output = []
for layer in self.final_layers2:
output.append(layer(x))
x = torch.flatten(torch.cat(output, dim=1), 1)
return x
# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
__all__ = [
'ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2',
'wide_resnet101_2'
]
model_urls = {
'resnet18':
'https://download.pytorch.org/models/resnet18-f37072fd.pth',
'resnet34':
'https://download.pytorch.org/models/resnet34-b627a593.pth',
'resnet50':
'https://download.pytorch.org/models/resnet50-0676ba61.pth',
'resnet101':
'https://download.pytorch.org/models/resnet101-63fe2227.pth',
'resnet152':
'https://download.pytorch.org/models/resnet152-394f9c45.pth',
'resnext50_32x4d':
'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d':
'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2':
'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2':
'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
groups: int = 1,
dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=dilation,
groups=groups,
bias=False,
dilation=dilation)
def conv1x1(in_planes: int,
out_planes: int,
stride: int = 1,
bias: bool = False) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(
in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
class BasicBlock(nn.Module):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
'Dilation > 1 not supported in BasicBlock')
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
use_last_fc: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError('replace_stride_with_dilation should be None '
'or a 3-element tuple, got {}'.format(
replace_stride_with_dilation))
self.use_last_fc = use_last_fc
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
block,
128,
layers[1],
stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(
block,
256,
layers[2],
stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(
block,
512,
layers[3],
stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
if self.use_last_fc:
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight,
0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight,
0) # type: ignore[arg-type]
def _make_layer(self,
block: Type[Union[BasicBlock, Bottleneck]],
planes: int,
blocks: int,
stride: int = 1,
dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(
self.inplanes,
planes,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
if self.use_last_fc:
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _resnet(arch: str, block: Type[Union[BasicBlock,
Bottleneck]], layers: List[int],
pretrained: bool, progress: bool, **kwargs: Any) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(
model_urls[arch], progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)
def resnet152(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained,
progress, **kwargs)
def resnext50_32x4d(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained,
progress, **kwargs)
def resnext101_32x8d(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)
def wide_resnet50_2(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained,
progress, **kwargs)
def wide_resnet101_2(pretrained: bool = False,
progress: bool = True,
**kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained,
progress, **kwargs)
func_dict = {'resnet18': (resnet18, 512), 'resnet50': (resnet50, 2048)}
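# Illustrative usage sketch (the input size is an assumption; any size accepted by ResNet works):
#   net = define_net_recon('resnet50', use_last_fc=False, init_path=None)
#   coeffs = net(torch.rand(1, 3, 224, 224))   # shape (1, 257)
# where 257 = 80 (id) + 64 (exp) + 80 (tex) + 3 (angle) + 27 (gamma) + 2 (tx, ty) + 1 (tz).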

View File

@@ -0,0 +1,414 @@
# Part of the implementation is borrowed and modified from Deep3DFaceRecon_pytorch,
# publicly available at https://github.com/sicxu/Deep3DFaceRecon_pytorch
import warnings
from typing import List
import numpy as np
import nvdiffrast
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F
from torch import nn
from .losses import TVLoss, TVLoss_std
warnings.filterwarnings('ignore')
def ndc_projection(x=0.1, n=1.0, f=50.0):
return np.array([[n / x, 0, 0, 0], [0, n / -x, 0, 0],
[0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
[0, 0, -1, 0]]).astype(np.float32)
def to_image(face_shape):
"""
Return:
face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
Parameters:
face_shape -- torch.tensor, size (B, N, 3)
"""
focal = 1015.
center = 112.
persc_proj = np.array([focal, 0, center, 0, focal, center, 0, 0,
1]).reshape([3, 3]).astype(np.float32).transpose()
persc_proj = torch.tensor(persc_proj).to(face_shape.device)
face_proj = face_shape @ persc_proj
face_proj = face_proj[..., :2] / face_proj[..., 2:]
return face_proj
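# Worked example: with focal = 1015 and center = 112 this is the pinhole projection
#   (x, y, z) -> (1015 * x / z + 112, 1015 * y / z + 112),
# so a vertex on the optical axis at z = 10 maps to the image center (112, 112).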
class MeshRenderer(nn.Module):
def __init__(self, rasterize_fov, znear=0.1, zfar=10, rasterize_size=224):
super(MeshRenderer, self).__init__()
x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
self.ndc_proj = torch.tensor(ndc_projection(
x=x, n=znear,
f=zfar)).matmul(torch.diag(torch.tensor([1., -1, -1, 1])))
self.rasterize_size = rasterize_size
self.glctx = None
def forward(self, vertex, tri, feat=None):
"""
Return:
mask -- torch.tensor, size (B, 1, H, W)
depth -- torch.tensor, size (B, 1, H, W)
image (optional) -- torch.tensor, size (B, C, H, W) if feat is not None
occ -- torch.tensor, size (B, 1, H, W), per-pixel map interpolated from vertex coverage
Parameters:
vertex -- torch.tensor, size (B, N, 3)
tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
feat (optional) -- torch.tensor, size (B, N, C), per-vertex features
"""
device = vertex.device
rsize = int(self.rasterize_size)
ndc_proj = self.ndc_proj.to(device)
verts_proj = to_image(vertex)
# trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
if vertex.shape[-1] == 3:
vertex = torch.cat(
[vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
dim=-1)
vertex[..., 1] = -vertex[..., 1]
vertex_ndc = vertex @ ndc_proj.t()
if self.glctx is None:
if nvdiffrast.__version__ == '0.2.7':
self.glctx = dr.RasterizeGLContext(device=device)
else:
self.glctx = dr.RasterizeCudaContext(device=device)
ranges = None
if isinstance(tri, List) or len(tri.shape) == 3:
vum = vertex_ndc.shape[1]
fnum = torch.tensor([f.shape[0]
for f in tri]).unsqueeze(1).to(device)
print('fnum shape:{}'.format(fnum.shape))
fstartidx = torch.cumsum(fnum, dim=0) - fnum
ranges = torch.cat([fstartidx, fnum],
axis=1).type(torch.int32).cpu()
for i in range(tri.shape[0]):
tri[i] = tri[i] + i * vum
vertex_ndc = torch.cat(vertex_ndc, dim=0)
tri = torch.cat(tri, dim=0)
# for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
tri = tri.type(torch.int32).contiguous()
rast_out, _ = dr.rasterize(
self.glctx,
vertex_ndc.contiguous(),
tri,
resolution=[rsize, rsize],
ranges=ranges)
depth, _ = dr.interpolate(
vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
rast_out, tri)
depth = depth.permute(0, 3, 1, 2)
mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
depth = mask * depth
image = None
verts_x = verts_proj[0, :, 0]
verts_y = 224 - verts_proj[0, :, 1]
verts_int = torch.ceil(verts_proj[0]).long() # (n, 2)
verts_xr_int = verts_int[:, 0].clamp(1, 224 - 1)
verts_yt_int = 224 - verts_int[:, 1].clamp(2, 224)
verts_right_float = verts_xr_int - verts_x
verts_left_float = 1 - verts_right_float
verts_top_float = verts_y - verts_yt_int
verts_bottom_float = 1 - verts_top_float
rast_lt = rast_out[0, verts_yt_int, verts_xr_int - 1, 3]
rast_lb = rast_out[0, verts_yt_int + 1, verts_xr_int - 1, 3]
rast_rt = rast_out[0, verts_yt_int, verts_xr_int, 3]
rast_rb = rast_out[0, verts_yt_int + 1, verts_xr_int, 3]
occ_feat = (rast_lt > 0) * 1.0 * (verts_left_float + verts_top_float) + \
(rast_lb > 0) * 1.0 * (verts_left_float + verts_bottom_float) + \
(rast_rt > 0) * 1.0 * (verts_right_float + verts_top_float) + \
(rast_rb > 0) * 1.0 * (verts_right_float + verts_bottom_float)
occ_feat = occ_feat[None, :, None] / 4.0
occ, _ = dr.interpolate(occ_feat, rast_out, tri)
occ = occ.permute(0, 3, 1, 2)
if feat is not None:
image, _ = dr.interpolate(feat, rast_out, tri)
image = image.permute(0, 3, 1, 2)
image = mask * image
return mask, depth, image, occ
def render_uv_texture(self, vertex, tri, uv, uv_texture):
"""
Return:
mask -- torch.tensor, size (B, 1, H, W)
depth -- torch.tensor, size (B, 1, H, W)
image -- torch.tensor, size (B, C, H, W), rendering of uv_texture
Parameters:
vertex -- torch.tensor, size (B, N, 3)
tri -- torch.tensor, size (M, 3), triangles
uv -- torch.tensor, size (B,N, 2), uv mapping
uv_texture -- torch.tensor, size (B, C, H, W), texture map
"""
device = vertex.device
rsize = int(self.rasterize_size)
ndc_proj = self.ndc_proj.to(device)
# trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
if vertex.shape[-1] == 3:
vertex = torch.cat(
[vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
dim=-1)
vertex[..., 1] = -vertex[..., 1]
vertex_ndc = vertex @ ndc_proj.t()
if self.glctx is None:
if nvdiffrast.__version__ == '0.2.7':
self.glctx = dr.RasterizeGLContext(device=device)
else:
self.glctx = dr.RasterizeCudaContext(device=device)
ranges = None
if isinstance(tri, List) or len(tri.shape) == 3:
vum = vertex_ndc.shape[1]
fnum = torch.tensor([f.shape[0]
for f in tri]).unsqueeze(1).to(device)
print('fnum shape:{}'.format(fnum.shape))
fstartidx = torch.cumsum(fnum, dim=0) - fnum
ranges = torch.cat([fstartidx, fnum],
axis=1).type(torch.int32).cpu()
for i in range(tri.shape[0]):
tri[i] = tri[i] + i * vum
vertex_ndc = torch.cat(vertex_ndc, dim=0)
tri = torch.cat(tri, dim=0)
# for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
tri = tri.type(torch.int32).contiguous()
rast_out, _ = dr.rasterize(
self.glctx,
vertex_ndc.contiguous(),
tri,
resolution=[rsize, rsize],
ranges=ranges)
depth, _ = dr.interpolate(
vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
rast_out, tri)
depth = depth.permute(0, 3, 1, 2)
mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
depth = mask * depth
uv[..., -1] = 1.0 - uv[..., -1]
rast_out, rast_db = dr.rasterize(
self.glctx,
vertex_ndc.contiguous(),
tri,
resolution=[rsize, rsize],
ranges=ranges)
interp_out, uv_da = dr.interpolate(
uv, rast_out, tri, rast_db, diff_attrs='all')
uv_texture = uv_texture.permute(0, 2, 3, 1).contiguous()
img = dr.texture(
uv_texture, interp_out, filter_mode='linear') # , uv_da)
img = img * torch.clamp(rast_out[..., -1:], 0,
1) # Mask out background.
image = img.permute(0, 3, 1, 2)
return mask, depth, image
def pred_shape_and_texture(self,
vertex,
tri,
uv,
target_img,
base_tex=None):
"""
Return:
mask -- torch.tensor, size (B, 1, H, W)
depth -- torch.tensor, size (B, 1, H, W)
image -- torch.tensor, size (B, C, H, W), rendering of the fitted texture
tex_map -- numpy.ndarray, size (H, W, 3), optimized BGR texture map scaled to [0, 255]
tex_mask -- numpy.ndarray, size (H, W), mask of texels updated during texture fitting
Parameters:
vertex -- torch.tensor, size (B, N, 3)
tri -- torch.tensor, size (B, M, 3) or (M, 3), triangles
uv -- torch.tensor, size (B,N, 2), uv mapping
target_img -- torch.tensor, size (B, H, W, C), image the texture is fitted to
base_tex (optional) -- torch.tensor, size (B, H, W, C), initial texture added at each resolution
"""
uv = uv.clone()
device = vertex.device
rsize = int(self.rasterize_size)
ndc_proj = self.ndc_proj.to(device)
# trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
if vertex.shape[-1] == 3:
vertex = torch.cat(
[vertex, torch.ones([*vertex.shape[:2], 1]).to(device)],
dim=-1)
vertex[..., 1] = -vertex[..., 1]
vertex_ndc = vertex @ ndc_proj.t()
if self.glctx is None:
if nvdiffrast.__version__ == '0.2.7':
self.glctx = dr.RasterizeGLContext(device=device)
else:
self.glctx = dr.RasterizeCudaContext(device=device)
# print("create glctx on device cuda:%d" % device.index)
# print('vertex_ndc shape:{}'.format(vertex_ndc.shape)) # Size([1, 35709, 4])
# print('tri shape:{}'.format(tri.shape)) # Size([70789, 3])
ranges = None
if isinstance(tri, List) or len(tri.shape) == 3:
vum = vertex_ndc.shape[1]
fnum = torch.tensor([f.shape[0]
for f in tri]).unsqueeze(1).to(device)
# print('fnum shape:{}'.format(fnum.shape))
fstartidx = torch.cumsum(fnum, dim=0) - fnum
ranges = torch.cat([fstartidx, fnum],
axis=1).type(torch.int32).cpu()
for i in range(tri.shape[0]):
tri[i] = tri[i] + i * vum
vertex_ndc = torch.cat(vertex_ndc, dim=0)
tri = torch.cat(tri, dim=0)
# for range_mode vertex: [B*N, 4], tri: [B*M, 3], for instance_mode vertex: [B, N, 4], tri: [M, 3]
tri = tri.type(torch.int32).contiguous()
rast_out, _ = dr.rasterize(
self.glctx,
vertex_ndc.contiguous(),
tri,
resolution=[rsize, rsize],
ranges=ranges)
depth, _ = dr.interpolate(
vertex.reshape([-1, 4])[..., 2].unsqueeze(1).contiguous(),
rast_out, tri)
depth = depth.permute(0, 3, 1, 2)
mask = (rast_out[..., 3] > 0).float().unsqueeze(1)
depth = mask * depth
uv[..., -1] = 1.0 - uv[..., -1]
rast_out, rast_db = dr.rasterize(
self.glctx,
vertex_ndc.contiguous(),
tri,
resolution=[rsize, rsize],
ranges=ranges)
interp_out, uv_da = dr.interpolate(
uv, rast_out, tri, rast_db, diff_attrs='all')
mask_3c = mask.permute(0, 2, 3, 1)
mask_3c = torch.cat((mask_3c, mask_3c, mask_3c), dim=-1)
maskout_img = mask_3c * target_img
mean_color = torch.sum(maskout_img, dim=(1, 2))
valid_pixel_count = torch.sum(mask)
mean_color = mean_color / valid_pixel_count
tex = torch.zeros((1, int(128), 128, 3), dtype=torch.float32)
# tex = torch.zeros((1, 128, 128, 3), dtype=torch.float32)
tex[:, :, :, 0] = mean_color[0, 0]
tex[:, :, :, 1] = mean_color[0, 1]
tex[:, :, :, 2] = mean_color[0, 2]
tex = tex.cuda()
tex_mask = torch.zeros((1, int(2048), 2048, 3), dtype=torch.float32)
# tex_mask = torch.zeros((1, 2048, 2048, 3), dtype=torch.float32)
tex_mask[:, :, :, 1] = 1.0
tex_mask = tex_mask.cuda()
tex_mask.requires_grad = True
tex_mask = tex_mask.contiguous()
criterionTV = TVLoss()
if base_tex is not None:
base_tex = base_tex.cuda()
for tex_resolution in [64, 128, 256, 512, 1024, 2048]:
tex = tex.detach()
tex = tex.permute(0, 3, 1, 2)
tex = F.interpolate(tex, (int(tex_resolution), tex_resolution))
# tex = F.interpolate(tex, (tex_resolution, tex_resolution))
tex = tex.permute(0, 2, 3, 1).contiguous()
if base_tex is not None:
_base_tex = base_tex.permute(0, 3, 1, 2)
_base_tex = F.interpolate(
_base_tex, (int(tex_resolution), tex_resolution))
# _base_tex = F.interpolate(_base_tex, (tex_resolution, tex_resolution))
_base_tex = _base_tex.permute(0, 2, 3, 1).contiguous()
tex += _base_tex
tex.requires_grad = True
optim = torch.optim.Adam([tex], lr=1e-2)
texture_opt_iters = 200
if tex_resolution == 2048:
optim_mask = torch.optim.Adam([tex_mask], lr=1e-2)
for i in range(int(texture_opt_iters)):
if tex_resolution == 2048:
optim_mask.zero_grad()
rendered = dr.texture(
tex_mask, interp_out, filter_mode='linear') # , uv_da)
rendered = rendered * torch.clamp(
rast_out[..., -1:], 0, 1) # Mask out background.
tex_loss = torch.mean((target_img - rendered)**2)
tex_loss.backward()
optim_mask.step()
optim.zero_grad()
img = dr.texture(
tex, interp_out, filter_mode='linear') # , uv_da)
img = img * torch.clamp(rast_out[..., -1:], 0,
1) # Mask out background.
recon_loss = torch.mean((target_img - img)**2)
if tex_resolution < 2048:
tv_loss = criterionTV(tex.permute(0, 3, 1, 2))
total_loss = recon_loss + tv_loss * 0.01
else:
total_loss = recon_loss
total_loss.backward()
optim.step()
tex_map = tex[0].detach().cpu().numpy()[..., ::-1] * 255.0
image = img.permute(0, 3, 1, 2)
tex_mask = tex_mask[0].detach().cpu().numpy() * 255.0
tex_mask = np.where(tex_mask[..., 1] > 250, 1.0, 0.0) * np.where(
tex_mask[..., 0] < 10, 1.0, 0) * np.where(tex_mask[..., 2] < 10,
1.0, 0)
tex_mask = 1.0 - tex_mask
return mask, depth, image, tex_map, tex_mask

View File

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
bfm_folder = ''
bfm_model = 'head_model_for_maas.mat'
camera_d = 10.0
center = 112.0
focal = 1015.0
isTrain = False
net_recon = 'resnet50'
phase = 'test'
use_ddp = False
use_last_fc = False
z_far = 15.0
z_near = 5.0
lr = 0.001
w_color = 1.92
w_reg = 3.0e-4
w_gamma = 10.0
w_lm = 1.6e-3
w_id = 1.0
w_exp = 0.8
w_tex = 1.7e-2

View File

@@ -0,0 +1,145 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import numpy as np
def get_fade_out_mask(length, start_value, end_value, fade_start_ratio,
fade_end_ratio):
fade_start_ind = int(length * fade_start_ratio)
fade_end_ind = int(length * fade_end_ratio)
left_part = np.array([start_value] * fade_start_ind)
fade_part = np.linspace(start_value, end_value,
fade_end_ind - fade_start_ind)
len_right = length - len(left_part) - len(fade_part)
right_part = np.array([end_value] * len_right)
fade_out_mask = np.concatenate([left_part, fade_part, right_part], axis=0)
return fade_out_mask
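# Illustrative usage sketch (not part of the original file): get_fade_out_mask
# returns a 1-D ramp that holds start_value up to fade_start_ratio of the length,
# blends linearly to end_value by fade_end_ratio, then holds end_value. The
# numbers below are arbitrary examples.
def _example_fade_out_mask():
    mask = get_fade_out_mask(
        length=10, start_value=1.0, end_value=0.0,
        fade_start_ratio=0.2, fade_end_ratio=0.8)
    assert mask.shape == (10, )  # 2 ones, a 6-step linear ramp, 2 zeros
    return mask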
class TexProcesser():
def __init__(self, model_root):
self.tex_size = 4096
self.bald_tex_bg = cv2.imread(
os.path.join(model_root,
'assets/texture/template_bald_tex_2.jpg')).astype(
np.float32)
self.hair_tex_bg = cv2.imread(
os.path.join(model_root,
'assets/texture/template_withHair_tex.jpg')).astype(
np.float32)
self.hair_mask = cv2.imread(
os.path.join(model_root,
'assets/texture/hair_mask_male.png'))[..., 0].astype(
np.float32) / 255.0
self.hair_mask = cv2.resize(self.hair_mask, (4096, 4096 + 1024))
front_mask = cv2.imread(
os.path.join(model_root,
'assets/texture/face_mask_singleview.jpg')).astype(
np.float32) / 255
front_mask = cv2.resize(front_mask, (1024, 1024))
front_mask = cv2.resize(front_mask, (0, 0), fx=0.1, fy=0.1)
front_mask = cv2.erode(front_mask,
np.ones(shape=(7, 7), dtype=np.float32))
front_mask = cv2.GaussianBlur(front_mask, (13, 13), 0)
self.front_mask = cv2.resize(front_mask,
(self.tex_size, self.tex_size))
self.binary_front_mask = self.front_mask.copy()
self.binary_front_mask[(self.front_mask < 0.3)
+ (self.front_mask > 0.7)] = 0
self.binary_front_mask[self.binary_front_mask != 0] = 1.0
self.binary_front_mask_ = self.binary_front_mask.copy()
self.binary_front_mask_[:int(4096 * 375 / 950)] = 0
self.binary_front_mask_[int(4096 * 600 / 950):] = 0
self.binary_front_mask = np.zeros((4096 + 1024, 4096, 3),
dtype=np.float32)
self.binary_front_mask[:4096, :] = self.binary_front_mask_
self.front_mask_ = self.front_mask.copy()
self.front_mask = np.zeros((4096 + 1024, 4096, 3), dtype=np.float32)
self.front_mask[:4096, :] = self.front_mask_
self.fg_mask = cv2.imread(
os.path.join(model_root,
'assets/texture/fg_mask.png'))[..., 0].astype(
np.float32) / 255.0
self.fg_mask = cv2.resize(self.fg_mask, (256, 256))
self.fg_mask = cv2.dilate(self.fg_mask,
np.ones(shape=(13, 13), dtype=np.float32))
self.fg_mask = cv2.blur(self.fg_mask, (27, 27), 0)
self.fg_mask = cv2.resize(self.fg_mask, (4096, 4096 + 1024))
self.fg_mask = self.fg_mask[..., None]
self.cheek_mask = cv2.imread(
os.path.join(model_root,
'assets/texture/cheek_area_mask.png'))[..., 0].astype(
np.float32) / 255.0
self.cheek_mask = cv2.resize(self.cheek_mask, (4096, 4096 + 1024))
self.cheek_mask = self.cheek_mask[..., None]
self.bald_tex_bg = self.bald_tex_bg[:4096]
self.hair_tex_bg = self.hair_tex_bg[:4096]
self.fg_mask = self.fg_mask[:4096]
self.hair_mask = self.hair_mask[:4096]
self.front_mask = self.front_mask[:4096]
self.binary_front_mask = self.binary_front_mask[:4096]
self.front_mask_ = self.front_mask_[:4096]
self.cheek_mask_left = self.cheek_mask[:4096]
self.cheek_mask_right = self.cheek_mask[:4096].copy()[:, ::-1]
def post_process_texture(self, tex_map, hair_tex=True):
tex_map = cv2.resize(tex_map, (self.tex_size, self.tex_size))
        # if hair_tex is True and one side of the face is much darker, mirror the brighter half
if hair_tex:
left_cheek_light_mean = np.mean(
tex_map[self.cheek_mask_left[..., 0] == 1.0])
right_cheek_light_mean = np.mean(
tex_map[self.cheek_mask_right[..., 0] == 1.0])
tex_map_flip = tex_map[:, ::-1, :]
w = tex_map.shape[1]
half_w = w // 2
if left_cheek_light_mean > right_cheek_light_mean * 1.5:
tex_map[:, half_w:, :] = tex_map_flip[:, half_w:, :]
elif right_cheek_light_mean > left_cheek_light_mean * 2:
tex_map[:, :half_w, :] = tex_map_flip[:, :half_w, :]
# change the color of template texture
bg_mean_rgb = np.mean(
self.bald_tex_bg[self.binary_front_mask[..., 0] == 1.0],
axis=0)[None, None]
pred_tex_mean_rgb = np.mean(
tex_map[self.binary_front_mask[..., 0] == 1.0], axis=0)[None,
None] * 1.1
_bald_tex_bg = self.bald_tex_bg.copy()
_bald_tex_bg = self.bald_tex_bg + (pred_tex_mean_rgb - bg_mean_rgb)
if hair_tex:
# inpaint hair
tex_gray = cv2.cvtColor(
tex_map.astype(np.uint8),
cv2.COLOR_BGR2GRAY).astype(np.float32)
hair_mask = (self.hair_mask == 1.0) * (tex_gray < 120)
hair_bgr = np.mean(tex_map[hair_mask, :], axis=0) * 0.5
            if np.any(np.isnan(hair_bgr)):
hair_bgr = 20.0
_bald_tex_bg[self.hair_mask == 1.0] = hair_bgr
# fuse
tex_map = _bald_tex_bg * (1.
- self.fg_mask) + tex_map * self.fg_mask
else:
# fuse
tex_map = _bald_tex_bg * (
1. - self.front_mask) + tex_map * self.front_mask
return tex_map
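# Illustrative usage sketch (not part of the original file): `model_root` is
# assumed to contain the assets/texture/* files loaded in __init__, and `tex_map`
# is the BGR texture map produced by the fitting stage (any HxWx3 float array).
def _example_post_process(model_root, tex_map):
    processer = TexProcesser(model_root)
    # hair_tex=True additionally inpaints the hair region of the template texture
    return processer.post_process_texture(tex_map, hair_tex=True)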

View File

@@ -0,0 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .generate_skeleton import gen_skeleton_bvh
from .utils import (read_obj, write_obj, render, rotate_x, rotate_y,
translate, projection)
else:
_import_structure = {
'generate_skeleton': ['gen_skeleton_bvh'],
'utils': [
'read_obj', 'write_obj', 'render', 'rotate_x', 'rotate_y',
'translate', 'projection'
],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,184 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from .transforms import aa2quat, batch_rodrigues, mat2aa, quat2euler
def write_bvh(parent,
offset,
rotation,
position,
names,
frametime,
order,
path,
endsite=None):
file = open(path, 'w')
frame = rotation.shape[0]
joint_num = rotation.shape[1]
order = order.upper()
file_string = 'HIERARCHY\n'
seq = []
def write_static(idx, prefix):
nonlocal parent, offset, rotation, names
nonlocal order, endsite, file_string, seq
seq.append(idx)
if idx == 0:
name_label = 'ROOT ' + names[idx]
channel_label = 'CHANNELS 6 Xposition Yposition Zposition \
{}rotation {}rotation {}rotation'.format(*order)
else:
name_label = 'JOINT ' + names[idx]
channel_label = 'CHANNELS 3 {}rotation {}rotation \
{}rotation'.format(*order)
offset_label = 'OFFSET %.6f %.6f %.6f' % (
offset[idx][0], offset[idx][1], offset[idx][2])
file_string += prefix + name_label + '\n'
file_string += prefix + '{\n'
file_string += prefix + '\t' + offset_label + '\n'
file_string += prefix + '\t' + channel_label + '\n'
has_child = False
for y in range(idx + 1, rotation.shape[1]):
if parent[y] == idx:
has_child = True
write_static(y, prefix + '\t')
if not has_child:
file_string += prefix + '\t' + 'End Site\n'
file_string += prefix + '\t' + '{\n'
file_string += prefix + '\t\t' + 'OFFSET 0 0 0\n'
file_string += prefix + '\t' + '}\n'
file_string += prefix + '}\n'
write_static(0, '')
file_string += 'MOTION\n' + 'Frames: {}\n'.format(
frame) + 'Frame Time: %.8f\n' % frametime
for i in range(frame):
file_string += '%.6f %.6f %.6f ' % (position[i][0], position[i][1],
position[i][2])
for j in range(joint_num):
idx = seq[j]
file_string += '%.6f %.6f %.6f ' % (
rotation[i][idx][0], rotation[i][idx][1], rotation[i][idx][2])
file_string += '\n'
file.write(file_string)
return file_string
class WriterWrapper:
def __init__(self, parents):
self.parents = parents
def axis2euler(self, rot):
rot = rot.reshape(rot.shape[0], -1, 3) # 45, 24, 3
quat = aa2quat(rot)
euler = quat2euler(quat, order='xyz')
rot = euler
return rot
def mapper_rot_mixamo(self, rot, n_bone):
rot = rot.reshape(rot.shape[0], -1, 3)
smpl_mapper = [
0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 17, 21, 15, 18, 22, 19,
23, 20, 24
]
if n_bone > 24:
hand_mapper = list(range(25, 65))
smpl_mapper += hand_mapper
new_rot = torch.zeros((rot.shape[0], n_bone, 3)) # n, 24, 3
new_rot[:, :len(smpl_mapper), :] = rot[:, smpl_mapper, :]
return new_rot
def transform_rot_with_restpose(self, rot, rest_pose, node_list, n_bone):
rest_pose = batch_rodrigues(rest_pose.reshape(-1, 3)).reshape(
1, n_bone, 3, 3) # N*3-> N*3*3
frame_num = rot.shape[0]
rot = rot.reshape(rot.shape[0], -1, 3)
new_rot = rot.clone()
for k in range(frame_num):
action_rot = batch_rodrigues(rot[k].reshape(-1, 3)).reshape(
1, n_bone, 3, 3)
for i in node_list:
rot1 = rest_pose[0, i, :, :]
rot2 = action_rot[0, i, :, :]
nrot = torch.matmul(rot2, torch.inverse(rot1))
nvec = mat2aa(nrot)
new_rot[k, i, :] = nvec
new_rot = self.axis2euler(new_rot) # =# 45,24,3
return new_rot
def transform_rot_with_stdApose(self, rot, rest_pose):
print('transform_rot_with_stdApose')
rot = rot.reshape(rot.shape[0], -1, 3)
rest_pose = self.axis2euler(rest_pose)
print(rot.shape)
print(rest_pose.shape)
smpl_left_arm_idx = 18
smpl_right_arm_idx = 19
std_arm_rot = torch.tensor([[21.7184, -4.8148, 16.3985],
[-20.1108, 10.7190, -8.9279]])
x = rest_pose[:, smpl_left_arm_idx:smpl_right_arm_idx + 1, :]
delta = (x - std_arm_rot)
rot[:, smpl_left_arm_idx:smpl_right_arm_idx + 1, :] -= delta
return rot
def write(self,
filename,
offset,
rot=None,
action_loc=None,
rest_pose=None,
correct_arm=0): # offset: [24,3], rot:[45,72]
if not isinstance(offset, torch.Tensor):
offset = torch.tensor(offset)
n_bone = offset.shape[0] # 24
pos = offset[0].unsqueeze(0) # 1,3
if rot is None:
rot = np.zeros((1, n_bone, 3))
else: # rot: 45, 72
if rest_pose is None:
rot = self.mapper_rot_mixamo(rot, n_bone)
else:
if correct_arm == 1:
rot = self.mapper_rot_mixamo(rot, n_bone)
print(rot.shape)
node_list_chage = [16, 17]
n_bone = rot.shape[1]
print(rot[0, 19, :])
else:
node_list_chage = [1, 2, 3, 6, 9, 12, 13, 14, 15, 16, 17]
rot = self.transform_rot_with_restpose(
rot, rest_pose, node_list_chage, n_bone)
rest = torch.zeros((1, n_bone * 3))
rest = self.axis2euler(rest)
frames_add = 1
rest = rest.repeat(frames_add, 1, 1)
rot = torch.cat((rest, rot), 0)
pos = pos.repeat(rot.shape[0], 1)
action_len = action_loc.shape[0]
pos[-action_len:, :] = action_loc[..., :]
names = ['%02d' % i for i in range(n_bone)]
write_bvh(self.parents, offset, rot, pos, names, 0.0333, 'xyz',
filename)
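# Illustrative sketch (not part of the original file): write a minimal two-joint
# chain with a single static frame to a BVH file; the output path is a
# placeholder. Rotations are per-frame Euler angles (degrees) in the given order.
def _example_write_bvh(path='/tmp/two_joint.bvh'):
    parent = [-1, 0]                  # joint 1 is a child of the root
    offset = np.zeros((2, 3))
    offset[1, 1] = 1.0                # child offset one unit along +Y
    rotation = np.zeros((1, 2, 3))    # (frames, joints, 3)
    position = np.zeros((1, 3))       # root translation per frame
    return write_bvh(parent, offset, rotation, position,
                     names=['root', 'child'], frametime=1.0 / 30,
                     order='xyz', path=path)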

View File

@@ -0,0 +1,167 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
import numpy as np
import torch
from .bvh_writer import WriterWrapper
from .utils import matrix_to_axis_angle, rotation_6d_to_matrix
def load_smpl_params(pose_fname):
with open(pose_fname, 'rb') as f:
data = pickle.load(f)
pose = torch.from_numpy(data['pose'])
beta = torch.from_numpy(data['betas'])
trans = torch.from_numpy(data['trans'])
if 'joints' in data:
joints = torch.from_numpy(data['joints'])
joints = joints.reshape(1, -1, 3)
else:
joints = None
trans = trans.reshape(1, 3)
beta = beta.reshape(1, -1)[:, :10]
pose = pose.reshape(-1, 24 * 3)
return pose, beta, trans, joints
def set_pose_param(pose, start, end):
pose[:, start * 3:(end + 1) * 3] = 0
return pose
def load_test_anim(filename, device, mode='move'):
anim = np.load(filename)
anim = torch.tensor(anim, device=device, dtype=torch.float)
poses = anim[:, :-3]
loc = anim[:, -3:]
if os.path.basename(filename)[:5] == 'comb_':
loc = loc / 100
repeat = 0
idx = -1
for i in range(poses.shape[0]):
if i == 0:
continue
if repeat >= 5:
idx = i
break
if poses[i].equal(poses[i - 1]):
repeat += 1
else:
repeat = 0
poses = poses[:idx - 5, :]
loc = loc[:idx - 5, :]
if mode == 'inplace':
loc[1:, :] = loc[0, :]
return poses, loc
def load_syn_motion(filename, device, mode='move'):
data = np.load(filename, allow_pickle=True).item()
anim = data['thetas']
n_joint, c, t = anim.shape
anim = torch.tensor(anim, device=device, dtype=torch.float)
anim = anim.permute(2, 0, 1) # 180, 24, 6
poses = anim.reshape(-1, 6)
poses = rotation_6d_to_matrix(poses)
poses = matrix_to_axis_angle(poses)
poses = poses.reshape(-1, 24, 3)
loc = data['root_translation']
loc = torch.tensor(loc, device=device, dtype=torch.float)
loc = loc.permute(1, 0)
if mode == 'inplace':
loc = torch.zeros((t, 3))
print('load %s' % filename)
return poses, loc
def load_action(action_name,
model_dir,
action_dir,
mode='move',
device=torch.device('cpu')):
action_path = os.path.join(action_dir, action_name + '.npy')
if not os.path.exists(action_path):
        print('cannot find action %s, using default action instead' %
              action_name)
action_path = os.path.join(model_dir, '3D-assets', 'SwingDancing.npy')
print('load action %s' % action_path)
test_pose, test_loc = load_test_anim(
action_path, device, mode=mode) # pose:[45,72], loc:[45,1,3]
return test_pose, test_loc
def load_action_list(action,
model_dir,
action_dir,
mode='move',
device=torch.device('cpu')):
action_list = action.split(',')
test_pose, test_loc = load_action(
action_list[0], model_dir, action_dir, mode=mode, device=device)
final_loc = test_loc[-1, :]
idx = 0
if len(action_list) > 1:
for action in action_list:
if idx == 0:
idx += 1
continue
print('load action %s' % action)
pose, loc = load_action(
action, model_dir, action_dir, mode=mode, device=device)
delta_loc = final_loc - loc[0, :]
loc += delta_loc
final_loc = loc[-1, :]
test_pose = torch.cat([test_pose, pose], 0)
test_loc = torch.cat([test_loc, loc], 0)
idx += 1
return test_pose, test_loc
def gen_skeleton_bvh(model_dir, action_dir, case_dir, action, mode='move'):
outpath_a = os.path.join(case_dir, 'skeleton_a.bvh')
device = torch.device('cpu')
assets_dir = os.path.join(model_dir, '3D-assets')
pkl_path = os.path.join(assets_dir, 'smpl.pkl')
    poses, shapes, trans, joints = load_smpl_params(pkl_path)
if action.endswith('.npy'):
skeleton_path = os.path.join(assets_dir, 'skeleton_nohand.npy')
else:
skeleton_path = os.path.join(assets_dir, 'skeleton.npy')
data = np.load(skeleton_path, allow_pickle=True).item()
skeleton = data['skeleton']
parent = data['parent']
skeleton = skeleton.squeeze(0)
bvh_writer = WriterWrapper(parent)
if action.endswith('.npy'):
action_path = action
print('load action %s' % action_path)
test_pose, test_loc = load_syn_motion(action_path, device, mode=mode)
bvh_writer.write(
outpath_a,
skeleton,
test_pose,
action_loc=test_loc,
rest_pose=poses)
else:
print('load action %s' % action)
test_pose, test_loc = load_action_list(
action, model_dir, action_dir, mode='move', device=device)
std_y = torch.tensor(0.99)
test_loc = test_loc + (skeleton[0, 1] - std_y)
bvh_writer.write(outpath_a, skeleton, test_pose, action_loc=test_loc)
print('save %s' % outpath_a)
return 0
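# Illustrative usage sketch (not part of the original file): all paths below are
# placeholders. `model_dir` is expected to contain a 3D-assets folder with
# smpl.pkl and the skeleton .npy files, `action_dir` holds <action>.npy clips,
# and skeleton_a.bvh is written into `case_dir`. 'SwingDancing' is the default
# action the code falls back to when a clip is missing.
def _example_gen_skeleton(model_dir, action_dir, case_dir):
    return gen_skeleton_bvh(model_dir, action_dir, case_dir,
                            action='SwingDancing', mode='move')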

View File

@@ -0,0 +1,316 @@
# ------------------------------------------------------------------------
# Modified from https://github.com/facebookresearch/pytorch3d
# All Rights Reserved.
# ------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
def batch_mm(matrix, matrix_batch):
"""
https://github.com/pytorch/pytorch/issues/14489#issuecomment-607730242
:param matrix: Sparse or dense matrix, size (m, n).
:param matrix_batch: Batched dense matrices, size (b, n, k).
:return: The batched matrix-matrix product,
size (m, n) x (b, n, k) = (b, m, k).
"""
batch_size = matrix_batch.shape[0]
# Stack the vector batch into columns. (b, n, k) -> (n, b, k) -> (n, b*k)
vectors = matrix_batch.transpose(0, 1).reshape(matrix.shape[1], -1)
# A matrix-matrix product is a batched matrix-vector
# product of the columns.
# And then reverse the reshaping.
# (m, n) x (n, b*k) = (m, b*k) -> (m, b, k) -> (b, m, k)
return matrix.mm(vectors).reshape(matrix.shape[0], batch_size,
-1).transpose(1, 0)
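# Shape check for batch_mm (illustrative, not part of the original file): a dense
# (m, n) matrix times a batch of (n, k) matrices yields a (b, m, k) result.
def _example_batch_mm():
    matrix = torch.randn(5, 4)           # (m, n)
    matrix_batch = torch.randn(3, 4, 2)  # (b, n, k)
    out = batch_mm(matrix, matrix_batch)
    assert out.shape == (3, 5, 2)
    return out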
def aa2quat(rots, form='wxyz', unified_orient=True):
"""
Convert angle-axis representation to wxyz quaternion
and to the half plan (w >= 0)
@param rots: angle-axis rotations, (*, 3)
@param form: quaternion format, either 'wxyz' or 'xyzw'
@param unified_orient: Use unified orientation for quaternion
(quaternion is dual cover of SO3)
:return:
"""
angles = rots.norm(dim=-1, keepdim=True)
norm = angles.clone()
norm[norm < 1e-8] = 1
axis = rots / norm
quats = torch.empty(
rots.shape[:-1] + (4, ), device=rots.device, dtype=rots.dtype)
angles = angles * 0.5
if form == 'wxyz':
quats[..., 0] = torch.cos(angles.squeeze(-1))
quats[..., 1:] = torch.sin(angles) * axis
elif form == 'xyzw':
quats[..., :3] = torch.sin(angles) * axis
quats[..., 3] = torch.cos(angles.squeeze(-1))
if unified_orient:
idx = quats[..., 0] < 0
quats[idx, :] *= -1
return quats
def quat2aa(quats):
"""
Convert wxyz quaternions to angle-axis representation
:param quats:
:return:
"""
_cos = quats[..., 0]
xyz = quats[..., 1:]
_sin = xyz.norm(dim=-1)
norm = _sin.clone()
norm[norm < 1e-7] = 1
axis = xyz / norm.unsqueeze(-1)
angle = torch.atan2(_sin, _cos) * 2
return axis * angle.unsqueeze(-1)
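# Round-trip sanity check (illustrative, not part of the original file): an
# axis-angle vector converted to a wxyz quaternion and back should reproduce the
# input, up to the half-plane convention enforced by unified_orient.
def _example_aa_quat_roundtrip():
    aa = torch.tensor([[0.3, -0.2, 0.5]])
    recovered = quat2aa(aa2quat(aa))
    assert torch.allclose(recovered, aa, atol=1e-5)
    return recovered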
def quat2mat(quats: torch.Tensor):
"""
Convert (w, x, y, z) quaternions to 3x3 rotation matrix
:param quats: quaternions of shape (..., 4)
:return: rotation matrices of shape (..., 3, 3)
"""
qw = quats[..., 0]
qx = quats[..., 1]
qy = quats[..., 2]
qz = quats[..., 3]
x2 = qx + qx
y2 = qy + qy
z2 = qz + qz
xx = qx * x2
yy = qy * y2
wx = qw * x2
xy = qx * y2
yz = qy * z2
wy = qw * y2
xz = qx * z2
zz = qz * z2
wz = qw * z2
m = torch.empty(
quats.shape[:-1] + (3, 3), device=quats.device, dtype=quats.dtype)
m[..., 0, 0] = 1.0 - (yy + zz)
m[..., 0, 1] = xy - wz
m[..., 0, 2] = xz + wy
m[..., 1, 0] = xy + wz
m[..., 1, 1] = 1.0 - (xx + zz)
m[..., 1, 2] = yz - wx
m[..., 2, 0] = xz - wy
m[..., 2, 1] = yz + wx
m[..., 2, 2] = 1.0 - (xx + yy)
return m
def quat2euler(q, order='xyz', degrees=True):
"""
Convert (w, x, y, z) quaternions to xyz euler angles.
This is used for bvh output.
"""
q0 = q[..., 0]
q1 = q[..., 1]
q2 = q[..., 2]
q3 = q[..., 3]
es = torch.empty(q0.shape + (3, ), device=q.device, dtype=q.dtype)
if order == 'xyz':
es[..., 2] = torch.atan2(2 * (q0 * q3 - q1 * q2),
q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3)
es[..., 1] = torch.asin((2 * (q1 * q3 + q0 * q2)).clip(-1, 1))
es[..., 0] = torch.atan2(2 * (q0 * q1 - q2 * q3),
q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3)
else:
raise NotImplementedError('Cannot convert to ordering %s' % order)
if degrees:
es = es * 180 / np.pi
return es
def aa2mat(rots):
"""
Convert angle-axis representation to rotation matrix
:param rots: angle-axis representation
:return:
"""
quat = aa2quat(rots)
mat = quat2mat(quat)
return mat
def inv_affine(mat):
"""
Calculate the inverse of any affine transformation
"""
affine = torch.zeros((mat.shape[:2] + (1, 4)))
affine[..., 3] = 1
vert_mat = torch.cat((mat, affine), dim=2)
vert_mat_inv = torch.inverse(vert_mat)
return vert_mat_inv[..., :3, :]
def inv_rigid_affine(mat):
"""
Calculate the inverse of a rigid affine transformation
"""
res = mat.clone()
res[..., :3] = mat[..., :3].transpose(-2, -1)
res[...,
3] = -torch.matmul(res[..., :3], mat[..., 3].unsqueeze(-1)).squeeze(-1)
return res
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9, )), dim=-1)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
))
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0]**2, m21 - m12, m02 - m20, m10 - m01],
dim=-1),
torch.stack([m21 - m12, q_abs[..., 1]**2, m10 + m01, m02 + m20],
dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2]**2, m12 + m21],
dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3]**2],
dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
return quat_candidates[F.one_hot(q_abs.argmax(
dim=-1), num_classes=4) > 0.5, :].reshape(batch_dim + (4, ))
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to axis/angle.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles])
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48)
return quaternions[..., 1:] / sin_half_angles_over_angles
def mat2aa(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def batch_rodrigues(rot_vecs: Tensor, epsilon: float = 1e-8) -> Tensor:
''' Calculates the rotation matrices for a batch of rotation vectors
Parameters
----------
rot_vecs: torch.tensor Nx3
array of N axis-angle vectors
Returns
-------
R: torch.tensor Nx3x3
The rotation matrices for the given axis-angle parameters
'''
assert len(rot_vecs.shape) == 2, (
f'Expects an array of size Bx3, but received {rot_vecs.shape}')
batch_size = rot_vecs.shape[0]
device = rot_vecs.device
dtype = rot_vecs.dtype
angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True, p=2)
rot_dir = rot_vecs / angle
cos = torch.unsqueeze(torch.cos(angle), dim=1)
sin = torch.unsqueeze(torch.sin(angle), dim=1)
# Bx1 arrays
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
.view((batch_size, 3, 3))
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
return rot_mat
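# Consistency check (illustrative, not part of the original file): batch_rodrigues
# and aa2mat implement the same axis-angle to rotation-matrix conversion, one via
# Rodrigues' formula and one via quaternions, so their outputs should agree.
def _example_rodrigues_vs_quat():
    rot_vecs = torch.tensor([[0.1, 0.4, -0.3], [1.2, 0.0, 0.5]])
    via_rodrigues = batch_rodrigues(rot_vecs)  # (2, 3, 3)
    via_quaternion = aa2mat(rot_vecs)          # (2, 3, 3)
    assert torch.allclose(via_rodrigues, via_quaternion, atol=1e-4)
    return via_rodrigues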

View File

@@ -0,0 +1,375 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import cv2
import numpy as np
import nvdiffrast.torch as dr
import torch
import torch.nn.functional as F
def read_obj(obj_path, print_shape=False):
with open(obj_path, 'r') as f:
bfm_lines = f.readlines()
vertices = []
faces = []
uvs = []
vns = []
faces_uv = []
faces_normal = []
max_face_length = 0
for line in bfm_lines:
if line[:2] == 'v ':
vertex = [
float(a) for a in line.strip().split(' ')[1:] if len(a) > 0
]
vertices.append(vertex)
if line[:2] == 'f ':
items = line.strip().split(' ')[1:]
face = [int(a.split('/')[0]) for a in items if len(a) > 0]
max_face_length = max(max_face_length, len(face))
faces.append(face)
if '/' in items[0] and len(items[0].split('/')[1]) > 0:
face_uv = [int(a.split('/')[1]) for a in items if len(a) > 0]
faces_uv.append(face_uv)
if '/' in items[0] and len(items[0].split('/')) >= 3 and len(
items[0].split('/')[2]) > 0:
face_normal = [
int(a.split('/')[2]) for a in items if len(a) > 0
]
faces_normal.append(face_normal)
if line[:3] == 'vt ':
items = line.strip().split(' ')[1:]
uv = [float(a) for a in items if len(a) > 0]
uvs.append(uv)
if line[:3] == 'vn ':
items = line.strip().split(' ')[1:]
vn = [float(a) for a in items if len(a) > 0]
vns.append(vn)
vertices = np.array(vertices).astype(np.float32)
if max_face_length <= 3:
faces = np.array(faces).astype(np.int32)
else:
print('not a triangle face mesh!')
if vertices.shape[1] == 3:
mesh = {
'vertices': vertices,
'faces': faces,
}
else:
mesh = {
'vertices': vertices[:, :3],
'colors': vertices[:, 3:],
'faces': faces,
}
if len(uvs) > 0:
uvs = np.array(uvs).astype(np.float32)
mesh['uvs'] = uvs
if len(vns) > 0:
vns = np.array(vns).astype(np.float32)
mesh['normals'] = vns
if len(faces_uv) > 0:
if max_face_length <= 3:
faces_uv = np.array(faces_uv).astype(np.int32)
mesh['faces_uv'] = faces_uv
if len(faces_normal) > 0:
if max_face_length <= 3:
faces_normal = np.array(faces_normal).astype(np.int32)
mesh['faces_normal'] = faces_normal
if print_shape:
print('num of vertices', len(vertices))
print('num of faces', len(faces))
return mesh
def write_obj(save_path, mesh):
save_dir = os.path.dirname(save_path)
save_name = os.path.splitext(os.path.basename(save_path))[0]
if 'texture_map' in mesh:
cv2.imwrite(
os.path.join(save_dir, save_name + '.png'), mesh['texture_map'])
with open(os.path.join(save_dir, save_name + '.mtl'), 'w') as wf:
wf.write('newmtl material_0\n')
wf.write('Ka 1.000000 0.000000 0.000000\n')
wf.write('Kd 1.000000 1.000000 1.000000\n')
wf.write('Ks 0.000000 0.000000 0.000000\n')
wf.write('Tr 0.000000\n')
wf.write('illum 0\n')
wf.write('Ns 0.000000\n')
wf.write('map_Kd {}\n'.format(save_name + '.png'))
with open(save_path, 'w') as wf:
if 'texture_map' in mesh:
wf.write('# Create by ModelScope\n')
wf.write('mtllib ./{}.mtl\n'.format(save_name))
if 'colors' in mesh:
for i, v in enumerate(mesh['vertices']):
wf.write('v {} {} {} {} {} {}\n'.format(
v[0], v[1], v[2], mesh['colors'][i][0],
mesh['colors'][i][1], mesh['colors'][i][2]))
else:
for v in mesh['vertices']:
wf.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
if 'uvs' in mesh:
for uv in mesh['uvs']:
wf.write('vt {} {}\n'.format(uv[0], uv[1]))
if 'normals' in mesh:
for vn in mesh['normals']:
wf.write('vn {} {} {}\n'.format(vn[0], vn[1], vn[2]))
if 'faces' in mesh:
for ind, face in enumerate(mesh['faces']):
if 'faces_uv' in mesh or 'faces_normal' in mesh:
if 'faces_uv' in mesh:
face_uv = mesh['faces_uv'][ind]
else:
face_uv = face
if 'faces_normal' in mesh:
face_normal = mesh['faces_normal'][ind]
else:
face_normal = face
row = 'f ' + ' '.join([
'{}/{}/{}'.format(face[i], face_uv[i], face_normal[i])
for i in range(len(face))
]) + '\n'
else:
row = 'f ' + ' '.join(
['{}'.format(face[i])
for i in range(len(face))]) + '\n'
wf.write(row)
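# Illustrative round trip (not part of the original file): `in_path` and
# `out_path` are placeholder .obj paths. write_obj also emits a .mtl and .png
# next to the .obj when the mesh carries a 'texture_map'.
def _example_obj_roundtrip(in_path, out_path):
    mesh = read_obj(in_path, print_shape=True)
    write_obj(out_path, mesh)
    return mesh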
def projection(x=0.1, n=1.0, f=50.0):
return np.array([[n / x, 0, 0, 0], [0, n / x, 0, 0],
[0, 0, -(f + n) / (f - n), -(2 * f * n) / (f - n)],
[0, 0, -1, 0]]).astype(np.float32)
def translate(x, y, z):
return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z],
[0, 0, 0, 1]]).astype(np.float32)
def rotate_x(a):
s, c = np.sin(a), np.cos(a)
return np.array([[1, 0, 0, 0], [0, c, s, 0], [0, -s, c, 0],
[0, 0, 0, 1]]).astype(np.float32)
def rotate_y(a):
s, c = np.sin(a), np.cos(a)
return np.array([[c, 0, s, 0], [0, 1, 0, 0], [-s, 0, c, 0],
[0, 0, 0, 1]]).astype(np.float32)
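# Illustrative sketch (not part of the original file): the helpers above compose a
# simple model-view-projection matrix of the kind consumed by transform_pos and
# render below; the angle and camera distance are arbitrary examples.
def _example_mvp(angle_y=0.4, distance=2.5):
    proj = projection(x=0.1, n=1.0, f=50.0)
    model_view = translate(0, 0, -distance) @ rotate_y(angle_y)
    return (proj @ model_view).astype(np.float32)  # 4x4 clip-space transform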
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.sum(x * y, -1, keepdim=True)
def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
return 2 * dot(x, n) * n - x
def length(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
return torch.sqrt(torch.clamp(
dot(x, x),
min=eps)) # Clamp to avoid nan gradients because grad(sqrt(0)) = NaN
def safe_normalize(x: torch.Tensor, eps: float = 1e-20) -> torch.Tensor:
return x / length(x, eps)
def transform_pos(mtx, pos):
t_mtx = torch.from_numpy(mtx).cuda() if isinstance(mtx,
np.ndarray) else mtx
posw = torch.cat([pos, torch.ones([pos.shape[0], 1]).cuda()], axis=1)
return torch.matmul(posw, t_mtx.t())[None, ...]
def render(glctx, mtx, pos, pos_idx, uv, uv_idx, tex, resolution, enable_mip,
max_mip_level):
pos_clip = transform_pos(mtx, pos)
rast_out, rast_out_db = dr.rasterize(
glctx, pos_clip, pos_idx, resolution=[resolution, resolution])
if enable_mip:
texc, texd = dr.interpolate(
uv[None, ...],
rast_out,
uv_idx,
rast_db=rast_out_db,
diff_attrs='all')
color = dr.texture(
tex[None, ...],
texc,
texd,
filter_mode='linear-mipmap-linear',
max_mip_level=max_mip_level)
else:
texc, _ = dr.interpolate(uv[None, ...], rast_out, uv_idx)
color = dr.texture(tex[None, ...], texc, filter_mode='linear')
pos_idx = pos_idx.type(torch.long)
v0 = pos[pos_idx[:, 0], :]
v1 = pos[pos_idx[:, 1], :]
v2 = pos[pos_idx[:, 2], :]
face_normals = safe_normalize(torch.cross(v1 - v0, v2 - v0))
face_normal_indices = (torch.arange(
0, face_normals.shape[0], dtype=torch.int64,
device='cuda')[:, None]).repeat(1, 3)
gb_geometric_normal, _ = dr.interpolate(face_normals[None, ...], rast_out,
face_normal_indices.int())
normal = (gb_geometric_normal + 1) * 0.5
mask = torch.clamp(rast_out[..., -1:], 0, 1)
color = color * mask + (1 - mask) * torch.ones_like(color)
normal = normal * mask + (1 - mask) * torch.ones_like(normal)
return color, mask, normal
# The following code is based on https://github.com/Mathux/ACTOR.git
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# Check PYTORCH3D_LICENCE before use
def _copysign(a, b):
"""
    Return a tensor where each element has the absolute value taken from the
corresponding element of a, with sign taken from the corresponding
element of b. This is like the standard copysign floating-point operation,
but is not careful about negative 0 and NaN.
Args:
a: source tensor.
b: tensor whose signs will be used, of the same shape as a.
Returns:
Tensor of the same shape as a with the signs of b.
"""
signs_differ = (a < 0) != (b < 0)
return torch.where(signs_differ, -a, a)
def _sqrt_positive_part(x):
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix):
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f'Invalid rotation matrix shape {matrix.shape}.')
m00 = matrix[..., 0, 0]
m11 = matrix[..., 1, 1]
m22 = matrix[..., 2, 2]
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
return torch.stack((o0, o1, o2, o3), -1)
def quaternion_to_axis_angle(quaternions):
"""
Convert rotations given as quaternions to axis/angle.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles])
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48)
return quaternions[..., 1:] / sin_half_angles_over_angles
def matrix_to_axis_angle(matrix):
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
"""
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
using Gram--Schmidt orthogonalisation per Section B of [1].
Args:
d6: 6D rotation representation, of size (*, 6)
Returns:
batch of rotation matrices of size (*, 3, 3)
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
On the Continuity of Rotation Representations in Neural Networks.
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
Retrieved from http://arxiv.org/abs/1812.07035
"""
a1, a2 = d6[..., :3], d6[..., 3:]
b1 = F.normalize(a1, dim=-1)
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
b2 = F.normalize(b2, dim=-1)
b3 = torch.cross(b1, b2, dim=-1)
return torch.stack((b1, b2, b3), dim=-2)
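# Illustrative sketch (not part of the original file): the 6D-rotation pipeline
# used by load_syn_motion in generate_skeleton.py, converting per-joint 6D
# features to rotation matrices and then to axis-angle vectors.
def _example_rot6d_pipeline():
    d6 = torch.randn(24, 6)               # one pose, 24 joints, 6D per joint
    mats = rotation_6d_to_matrix(d6)      # (24, 3, 3) rotation matrices
    return matrix_to_axis_angle(mats)     # (24, 3) axis-angle vectors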

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .image_control_3d_portrait import ImageControl3dPortrait
else:
_import_structure = {
'image_control_3d_portrait': ['ImageControl3dPortrait']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,468 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from collections import OrderedDict
from typing import Any, Dict
import cv2
import json
import numpy as np
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from scipy.io import loadmat
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.face_detection.peppa_pig_face.facer import FaceAna
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
from .network.camera_utils import FOV_to_intrinsics, LookAtPoseSampler
from .network.shape_utils import convert_sdf_samples_to_ply
from .network.triplane import TriPlaneGenerator
from .network.triplane_encoder import TriplaneEncoder
logger = get_logger()
__all__ = ['ImageControl3dPortrait']
@MODELS.register_module(
Tasks.image_control_3d_portrait,
module_name=Models.image_control_3d_portrait)
class ImageControl3dPortrait(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the image face fusion model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
logger.info('model params:{}'.format(kwargs))
self.neural_rendering_resolution = kwargs[
'neural_rendering_resolution']
self.cam_radius = kwargs['cam_radius']
self.fov_deg = kwargs['fov_deg']
self.truncation_psi = kwargs['truncation_psi']
self.truncation_cutoff = kwargs['truncation_cutoff']
self.z_dim = kwargs['z_dim']
self.image_size = kwargs['image_size']
self.shape_res = kwargs['shape_res']
self.pitch_range = kwargs['pitch_range']
self.yaw_range = kwargs['yaw_range']
self.max_batch = kwargs['max_batch']
self.num_frames = kwargs['num_frames']
self.box_warp = kwargs['box_warp']
self.save_shape = kwargs['save_shape']
self.save_images = kwargs['save_images']
device = kwargs['device']
self.device = create_device(device)
self.facer = FaceAna(model_dir)
similarity_mat_path = os.path.join(model_dir, 'BFM',
'similarity_Lm3D_all.mat')
self.lm3d_std = self.load_lm3d(similarity_mat_path)
init_model_json = os.path.join(model_dir, 'configs',
'init_encoder.json')
with open(init_model_json, 'r') as fr:
init_kwargs_encoder = json.load(fr)
encoder_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
self.model = TriplaneEncoder(**init_kwargs_encoder)
ckpt_encoder = torch.load(encoder_path, map_location='cpu')
model_state = self.convert_state_dict(ckpt_encoder['state_dict'])
self.model.load_state_dict(model_state)
self.model = self.model.to(self.device)
self.model.eval()
init_args_G = ()
init_netG_json = os.path.join(model_dir, 'configs', 'init_G.json')
with open(init_netG_json, 'r') as fr:
init_kwargs_G = json.load(fr)
self.netG = TriPlaneGenerator(*init_args_G, **init_kwargs_G)
netG_path = os.path.join(model_dir, 'ffhqrebalanced512-128.pth')
ckpt_G = torch.load(netG_path)
self.netG.load_state_dict(ckpt_G['G_ema'], strict=False)
self.netG.neural_rendering_resolution = self.neural_rendering_resolution
self.netG = self.netG.to(self.device)
self.netG.eval()
self.intrinsics = FOV_to_intrinsics(self.fov_deg, device=self.device)
col, row = np.meshgrid(
np.arange(self.image_size), np.arange(self.image_size))
np_coord = np.stack((col, row), axis=2) / self.image_size # [0,1]
self.coord = torch.from_numpy(np_coord.astype(
np.float32)).unsqueeze(0).permute(0, 3, 1, 2).to(self.device)
self.image_transform = transforms.Compose([
transforms.Resize((self.image_size, self.image_size)),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
logger.info('init done')
def convert_state_dict(self, state_dict):
if not next(iter(state_dict)).startswith('module.'):
return state_dict
new_state_dict = OrderedDict()
split_index = 0
for cur_key, cur_value in state_dict.items():
if cur_key.startswith('module.model'):
split_index = 13
elif cur_key.startswith('module'):
split_index = 7
break
for k, v in state_dict.items():
name = k[split_index:]
new_state_dict[name] = v
return new_state_dict
def detect_face(self, img):
src_h, src_w, _ = img.shape
boxes, landmarks, _ = self.facer.run(img)
if boxes.shape[0] == 0:
return None
elif boxes.shape[0] > 1:
max_area = 0
max_index = 0
for i in range(boxes.shape[0]):
bbox_width = boxes[i][2] - boxes[i][0]
bbox_height = boxes[i][3] - boxes[i][1]
area = int(bbox_width) * int(bbox_height)
if area > max_area:
max_index = i
max_area = area
return landmarks[max_index]
else:
return landmarks[0]
def get_f5p(self, landmarks, np_img):
eye_left = self.find_pupil(landmarks[36:41], np_img)
eye_right = self.find_pupil(landmarks[42:47], np_img)
if eye_left is None or eye_right is None:
logger.warning(
                'cannot find pupils with find_pupil, using eye landmark means instead.')
eye_left = landmarks[36:41].mean(axis=0)
eye_right = landmarks[42:47].mean(axis=0)
nose = landmarks[30]
mouth_left = landmarks[48]
mouth_right = landmarks[54]
f5p = [[eye_left[0], eye_left[1]], [eye_right[0], eye_right[1]],
[nose[0], nose[1]], [mouth_left[0], mouth_left[1]],
[mouth_right[0], mouth_right[1]]]
return np.array(f5p)
def find_pupil(self, landmarks, np_img):
h, w, _ = np_img.shape
xmax = int(landmarks[:, 0].max())
xmin = int(landmarks[:, 0].min())
ymax = int(landmarks[:, 1].max())
ymin = int(landmarks[:, 1].min())
if ymin >= ymax or xmin >= xmax or ymin < 0 or xmin < 0 or ymax > h or xmax > w:
return None
eye_img_bgr = np_img[ymin:ymax, xmin:xmax, :]
eye_img = cv2.cvtColor(eye_img_bgr, cv2.COLOR_BGR2GRAY)
eye_img = cv2.equalizeHist(eye_img)
n_marks = landmarks - np.array([xmin, ymin]).reshape([1, 2])
eye_mask = cv2.fillConvexPoly(
np.zeros_like(eye_img), n_marks.astype(np.int32), 1)
ret, thresh = cv2.threshold(eye_img, 100, 255,
cv2.THRESH_BINARY | cv2.THRESH_OTSU)
thresh = (1 - thresh / 255.) * eye_mask
cnt = 0
xm = []
ym = []
for i in range(thresh.shape[0]):
for j in range(thresh.shape[1]):
if thresh[i, j] > 0.5:
xm.append(j)
ym.append(i)
cnt += 1
if cnt != 0:
xm.sort()
ym.sort()
xm = xm[cnt // 2]
ym = ym[cnt // 2]
else:
xm = thresh.shape[1] / 2
ym = thresh.shape[0] / 2
return xm + xmin, ym + ymin
def load_lm3d(self, similarity_mat_path):
Lm3D = loadmat(similarity_mat_path)
Lm3D = Lm3D['lm']
lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
lm_data1 = Lm3D[lm_idx[0], :]
lm_data2 = np.mean(Lm3D[lm_idx[[1, 2]], :], 0)
lm_data3 = np.mean(Lm3D[lm_idx[[3, 4]], :], 0)
lm_data4 = Lm3D[lm_idx[5], :]
lm_data5 = Lm3D[lm_idx[6], :]
Lm3D = np.stack([lm_data1, lm_data2, lm_data3, lm_data4, lm_data5],
axis=0)
Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
return Lm3D
def POS(self, xp, x):
npts = xp.shape[1]
A = np.zeros([2 * npts, 8])
A[0:2 * npts - 1:2, 0:3] = x.transpose()
A[0:2 * npts - 1:2, 3] = 1
A[1:2 * npts:2, 4:7] = x.transpose()
A[1:2 * npts:2, 7] = 1
b = np.reshape(xp.transpose(), [2 * npts, 1])
k, _, _, _ = np.linalg.lstsq(A, b)
R1 = k[0:3]
R2 = k[4:7]
sTx = k[3]
sTy = k[7]
s = (np.linalg.norm(R1) + np.linalg.norm(R2)) / 2
t = np.stack([sTx, sTy], axis=0)
return t, s
def resize_n_crop_img(self, img, lm, t, s, target_size=224., mask=None):
w0, h0 = img.size
w = (w0 * s).astype(np.int32)
h = (h0 * s).astype(np.int32)
left = (w / 2 - target_size / 2 + float(
(t[0] - w0 / 2) * s)).astype(np.int32)
right = left + target_size
up = (h / 2 - target_size / 2 + float(
(h0 / 2 - t[1]) * s)).astype(np.int32)
below = up + target_size
img = img.resize((w, h), resample=Image.BICUBIC)
img = img.crop((left, up, right, below))
if mask is not None:
mask = mask.resize((w, h), resample=Image.BICUBIC)
mask = mask.crop((left, up, right, below))
lm = np.stack([lm[:, 0] - t[0] + w0 / 2, lm[:, 1] - t[1] + h0 / 2],
axis=1) * s
lm = lm - np.reshape(
np.array([(w / 2 - target_size / 2),
(h / 2 - target_size / 2)]), [1, 2])
return img, lm, mask
def align_img(self,
img,
lm,
lm3D,
mask=None,
target_size=224.,
rescale_factor=102.):
w0, h0 = img.size
lm5p = lm
t, s = self.POS(lm5p.transpose(), lm3D.transpose())
s = rescale_factor / s
img_new, lm_new, mask_new = self.resize_n_crop_img(
img, lm, t, s, target_size=target_size, mask=mask)
trans_params = np.array([w0, h0, s, t[0], t[1]], dtype=object)
return trans_params, img_new, lm_new, mask_new
def crop_image(self, img, lm):
_, H = img.size
lm[:, -1] = H - 1 - lm[:, -1]
target_size = 1024.
rescale_factor = 300
center_crop_size = 700
output_size = 512
        _, im_high, _, _ = self.align_img(
img,
lm,
self.lm3d_std,
target_size=target_size,
rescale_factor=rescale_factor)
left = int(im_high.size[0] / 2 - center_crop_size / 2)
upper = int(im_high.size[1] / 2 - center_crop_size / 2)
right = left + center_crop_size
lower = upper + center_crop_size
im_cropped = im_high.crop((left, upper, right, lower))
im_cropped = im_cropped.resize((output_size, output_size),
resample=Image.LANCZOS)
logger.info('crop image done!')
return im_cropped
def create_samples(self, N=256, voxel_origin=[0, 0, 0], cube_length=2.0):
voxel_origin = np.array(voxel_origin) - cube_length / 2
voxel_size = cube_length / (N - 1)
overall_index = torch.arange(0, N**3, 1, out=torch.LongTensor())
samples = torch.zeros(N**3, 3)
samples[:, 2] = overall_index % N
samples[:, 1] = (overall_index.float() / N) % N
samples[:, 0] = ((overall_index.float() / N) / N) % N
samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
return samples.unsqueeze(0), voxel_origin, voxel_size
def numpy_array_to_video(self, numpy_list, video_out_path):
assert len(numpy_list) > 0
video_height = numpy_list[0].shape[0]
video_width = numpy_list[0].shape[1]
out_video_size = (video_width, video_height)
output_video_fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
video_write_capture = cv2.VideoWriter(video_out_path,
output_video_fourcc, 30,
out_video_size)
for frame in numpy_list:
            frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
video_write_capture.write(frame_bgr)
video_write_capture.release()
def inference(self, image_path, save_dir):
basename = os.path.basename(image_path).split('.')[0]
img = Image.open(image_path).convert('RGB')
img_array = np.array(img)
img_bgr = img_array[:, :, ::-1]
landmark = self.detect_face(img_array)
        if landmark is None:
            logger.warning('No face detected in the image!')
            raise ValueError('no face detected in the input image')
f5p = self.get_f5p(landmark, img_bgr)
logger.info('f5p is:{}'.format(f5p))
img_cropped = self.crop_image(img, f5p)
img_cropped.save(os.path.join(save_dir, 'crop.jpg'))
in_image = self.image_transform(img_cropped).unsqueeze(0).to(
self.device)
input = torch.cat((in_image, self.coord), 1)
save_video_path = os.path.join(save_dir, f'{basename}.mp4')
pred_imgs = []
for frame_idx in range(self.num_frames):
cam_pivot = torch.tensor([0, 0, 0.2], device=self.device)
cam2world_pose = LookAtPoseSampler.sample(
3.14 / 2 + self.yaw_range
* np.sin(2 * 3.14 * frame_idx / self.num_frames),
3.14 / 2 - 0.05 + self.pitch_range
* np.cos(2 * 3.14 * frame_idx / self.num_frames),
cam_pivot,
radius=self.cam_radius,
device=self.device)
camera_params = torch.cat([
cam2world_pose.reshape(-1, 16),
self.intrinsics.reshape(-1, 9)
], 1)
conditioning_cam2world_pose = LookAtPoseSampler.sample(
np.pi / 2,
np.pi / 2,
cam_pivot,
radius=self.cam_radius,
device=self.device)
conditioning_params = torch.cat([
conditioning_cam2world_pose.reshape(-1, 16),
self.intrinsics.reshape(-1, 9)
], 1)
z = torch.from_numpy(np.random.randn(1,
self.z_dim)).to(self.device)
with torch.no_grad():
ws = self.netG.mapping(
z,
conditioning_params,
truncation_psi=self.truncation_psi,
truncation_cutoff=self.truncation_cutoff)
planes, pred_depth, pred_feature, pred_rgb, pred_sr, _, _, _, _ = self.model(
ws, input, camera_params, None)
pred_img = (pred_sr.permute(0, 2, 3, 1) * 127.5 + 128).clamp(
0, 255).to(torch.uint8)
pred_img = pred_img.squeeze().cpu().numpy()
if self.save_images:
cv2.imwrite(
os.path.join(save_dir, '{}.jpg'.format(frame_idx)),
pred_img[:, :, ::-1])
pred_imgs.append(pred_img)
self.numpy_array_to_video(pred_imgs, save_video_path)
if self.save_shape:
max_batch = 1000000
samples, voxel_origin, voxel_size = self.create_samples(
N=self.shape_res,
voxel_origin=[0, 0, 0],
cube_length=self.box_warp)
samples = samples.to(z.device)
sigmas = torch.zeros((samples.shape[0], samples.shape[1], 1),
device=z.device)
transformed_ray_directions_expanded = torch.zeros(
(samples.shape[0], max_batch, 3), device=z.device)
transformed_ray_directions_expanded[..., -1] = -1
head = 0
with torch.no_grad():
while head < samples.shape[1]:
torch.manual_seed(0)
sigma = self.model.sample(
samples[:, head:head + max_batch],
transformed_ray_directions_expanded[:, :samples.
shape[1] - head],
planes)['sigma']
sigmas[:, head:head + max_batch] = sigma
head += max_batch
sigmas = sigmas.reshape((self.shape_res, self.shape_res,
self.shape_res)).cpu().numpy()
sigmas = np.flip(sigmas, 0)
pad = int(30 * self.shape_res / 256)
pad_value = -1000
sigmas[:pad] = pad_value
sigmas[-pad:] = pad_value
sigmas[:, :pad] = pad_value
sigmas[:, -pad:] = pad_value
sigmas[:, :, :pad] = pad_value
sigmas[:, :, -pad:] = pad_value
convert_sdf_samples_to_ply(
np.transpose(sigmas, (2, 1, 0)), [0, 0, 0],
1,
os.path.join(save_dir, f'{basename}.ply'),
level=10)
logger.info('model inference done')

View File

@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
Helper functions for constructing camera parameter matrices. Primarily used in visualization and inference scripts.
"""
import math
import torch
import torch.nn as nn
from .volumetric_rendering import math_utils
class GaussianCameraPoseSampler:
"""
Samples pitch and yaw from a Gaussian distribution and returns a camera pose.
Camera is specified as looking at the origin.
If horizontal and vertical stddev (specified in radians) are zero, gives a
deterministic camera pose with yaw=horizontal_mean, pitch=vertical_mean.
The coordinate system is specified with y-up, z-forward, x-left.
Horizontal mean is the azimuthal angle (rotation around y axis) in radians,
vertical mean is the polar angle (angle from the y axis) in radians.
A point along the z-axis has azimuthal_angle=0, polar_angle=pi/2.
Example:
For a camera pose looking at the origin with the camera at position [0, 0, 1]:
cam2world = GaussianCameraPoseSampler.sample(math.pi/2, math.pi/2, radius=1)
"""
@staticmethod
def sample(horizontal_mean,
vertical_mean,
horizontal_stddev=0,
vertical_stddev=0,
radius=1,
batch_size=1,
device='cpu'):
h = torch.randn((batch_size, 1),
device=device) * horizontal_stddev + horizontal_mean
v = torch.randn(
(batch_size, 1), device=device) * vertical_stddev + vertical_mean
v = torch.clamp(v, 1e-5, math.pi - 1e-5)
theta = h
v = v / math.pi
phi = torch.arccos(1 - 2 * v)
camera_origins = torch.zeros((batch_size, 3), device=device)
camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
- theta)
camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
- theta)
camera_origins[:, 1:2] = radius * torch.cos(phi)
forward_vectors = math_utils.normalize_vecs(-camera_origins)
return create_cam2world_matrix(forward_vectors, camera_origins)
class LookAtPoseSampler:
"""
Same as GaussianCameraPoseSampler, except the
camera is specified as looking at 'lookat_position', a 3-vector.
Example:
For a camera pose looking at the origin with the camera at position [0, 0, 1]:
cam2world = LookAtPoseSampler.sample(math.pi/2, math.pi/2, torch.tensor([0, 0, 0]), radius=1)
"""
@staticmethod
def sample(horizontal_mean,
vertical_mean,
lookat_position,
horizontal_stddev=0,
vertical_stddev=0,
radius=1,
batch_size=1,
device='cpu'):
h = torch.randn((batch_size, 1),
device=device) * horizontal_stddev + horizontal_mean
v = torch.randn(
(batch_size, 1), device=device) * vertical_stddev + vertical_mean
v = torch.clamp(v, 1e-5, math.pi - 1e-5)
theta = h
v = v / math.pi
phi = torch.arccos(1 - 2 * v)
camera_origins = torch.zeros((batch_size, 3), device=device)
camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
- theta)
camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
- theta)
camera_origins[:, 1:2] = radius * torch.cos(phi)
# forward_vectors = math_utils.normalize_vecs(-camera_origins)
forward_vectors = math_utils.normalize_vecs(lookat_position
- camera_origins)
return create_cam2world_matrix(forward_vectors, camera_origins)
class UniformCameraPoseSampler:
"""
Same as GaussianCameraPoseSampler, except the
pose is sampled from a uniform distribution with range +-[horizontal/vertical]_stddev.
Example:
For a batch of random camera poses looking at the origin with yaw sampled from [-pi/2, +pi/2] radians:
cam2worlds = UniformCameraPoseSampler.sample
(math.pi/2, math.pi/2, horizontal_stddev=math.pi/2, radius=1, batch_size=16)
"""
@staticmethod
def sample(horizontal_mean,
vertical_mean,
horizontal_stddev=0,
vertical_stddev=0,
radius=1,
batch_size=1,
device='cpu'):
h = (torch.rand((batch_size, 1), device=device) * 2
- 1) * horizontal_stddev + horizontal_mean
v = (torch.rand((batch_size, 1), device=device) * 2
- 1) * vertical_stddev + vertical_mean
v = torch.clamp(v, 1e-5, math.pi - 1e-5)
theta = h
v = v / math.pi
phi = torch.arccos(1 - 2 * v)
camera_origins = torch.zeros((batch_size, 3), device=device)
camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi
- theta)
camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi
- theta)
camera_origins[:, 1:2] = radius * torch.cos(phi)
forward_vectors = math_utils.normalize_vecs(-camera_origins)
return create_cam2world_matrix(forward_vectors, camera_origins)
def create_cam2world_matrix(forward_vector, origin):
"""
Takes in the direction the camera is pointing and the camera origin and returns a cam2world matrix.
Works on batches of forward_vectors, origins. Assumes y-axis is up and that there is no camera roll.
"""
forward_vector = math_utils.normalize_vecs(forward_vector)
up_vector = torch.tensor([0, 1, 0],
dtype=torch.float,
device=origin.device).expand_as(forward_vector)
right_vector = -math_utils.normalize_vecs(
torch.cross(up_vector, forward_vector, dim=-1))
up_vector = math_utils.normalize_vecs(
torch.cross(forward_vector, right_vector, dim=-1))
rotation_matrix = torch.eye(
4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0],
1, 1)
rotation_matrix[:, :3, :3] = torch.stack(
(right_vector, up_vector, forward_vector), axis=-1)
translation_matrix = torch.eye(
4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0],
1, 1)
translation_matrix[:, :3, 3] = origin
cam2world = (translation_matrix @ rotation_matrix)[:, :, :]
assert (cam2world.shape[1:] == (4, 4))
return cam2world
def FOV_to_intrinsics(fov_degrees, device='cpu'):
"""
Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees.
Note the intrinsics are returned as normalized by image size, rather than in pixel units.
Assumes principal point is at image center.
"""
focal_length = float(1 / (math.tan(fov_degrees * 3.14159 / 360) * 1.414))
intrinsics = torch.tensor(
[[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
device=device)
return intrinsics
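# Illustrative sketch (not part of the original file): this mirrors how
# ImageControl3dPortrait.inference assembles the 25-dim camera vector, i.e. a
# flattened 4x4 cam2world pose followed by the flattened 3x3 normalized
# intrinsics. The FOV, pivot and radius values are typical examples, not
# authoritative defaults.
def _example_camera_params(fov_deg=18.837, radius=2.7, device='cpu'):
    intrinsics = FOV_to_intrinsics(fov_deg, device=device)
    cam2world = LookAtPoseSampler.sample(
        math.pi / 2, math.pi / 2,
        torch.tensor([0, 0, 0.2], device=device),
        radius=radius, device=device)
    return torch.cat(
        [cam2world.reshape(-1, 16),
         intrinsics.reshape(-1, 9)], dim=1)  # shape (1, 25)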

File diff suppressed because it is too large

View File

@@ -0,0 +1,65 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
Utils for extracting 3D shapes using marching cubes. Based on code from DeepSDF (Park et al.)
Takes as input an .mrc file and extracts a mesh.
Ex.
python shape_utils.py my_shape.mrc
Ex.
python shape_utils.py myshapes_directory --level=12
"""
import numpy as np
import plyfile
import skimage.measure
def convert_sdf_samples_to_ply(numpy_3d_sdf_tensor,
voxel_grid_origin,
voxel_size,
ply_filename_out,
offset=None,
scale=None,
level=0.0):
verts, faces, normals, values = skimage.measure.marching_cubes(
numpy_3d_sdf_tensor, level=level, spacing=[voxel_size] * 3)
mesh_points = np.zeros_like(verts)
mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
if scale is not None:
mesh_points = mesh_points / scale
if offset is not None:
mesh_points = mesh_points - offset
num_verts = verts.shape[0]
num_faces = faces.shape[0]
verts_tuple = np.zeros((num_verts, ),
dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
for i in range(0, num_verts):
verts_tuple[i] = tuple(mesh_points[i, :])
faces_building = []
for i in range(0, num_faces):
faces_building.append(((faces[i, :].tolist(), )))
faces_tuple = np.array(
faces_building, dtype=[('vertex_indices', 'i4', (3, ))])
el_verts = plyfile.PlyElement.describe(verts_tuple, 'vertex')
el_faces = plyfile.PlyElement.describe(faces_tuple, 'face')
ply_data = plyfile.PlyData([el_verts, el_faces])
ply_data.write(ply_filename_out)
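# Illustrative sketch (not part of the original file): extract a mesh from a
# synthetic volume. The scalar field below is largest at the center and crosses
# the chosen level on a sphere, so the exported PLY approximates a sphere; the
# output path is a placeholder.
def _example_sphere_to_ply(out_path='/tmp/sphere.ply', n=64):
    coords = np.linspace(-1.0, 1.0, n)
    x, y, z = np.meshgrid(coords, coords, coords, indexing='ij')
    volume = 1.0 - np.sqrt(x**2 + y**2 + z**2)  # positive inside the unit sphere
    convert_sdf_samples_to_ply(
        volume, voxel_grid_origin=[-1.0, -1.0, -1.0],
        voxel_size=2.0 / (n - 1), ply_filename_out=out_path, level=0.2)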

View File

@@ -0,0 +1,493 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Superresolution network architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
import numpy as np
import torch
from modelscope.ops.image_control_3d_portrait.torch_utils import (misc,
persistence)
from modelscope.ops.image_control_3d_portrait.torch_utils.ops import upfirdn2d
from .networks_stylegan2 import (Conv2dLayer, SynthesisBlock, SynthesisLayer,
ToRGBLayer)
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8X(torch.nn.Module):
def __init__(
self,
channels,
img_resolution,
sr_num_fp16_res,
sr_antialias,
num_fp16_res=4,
conv_clamp=None,
channel_base=None,
channel_max=None, # IGNORE
**block_kwargs):
super().__init__()
assert img_resolution == 512
use_fp16 = sr_num_fp16_res > 0
self.input_resolution = 128
self.sr_antialias = sr_antialias
self.block0 = SynthesisBlock(
channels,
128,
w_dim=512,
resolution=256,
img_channels=3,
is_last=False,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.block1 = SynthesisBlock(
128,
64,
w_dim=512,
resolution=512,
img_channels=3,
is_last=True,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.register_buffer('resample_filter',
upfirdn2d.setup_filter([1, 3, 3, 1]))
def forward(self, rgb, x, ws, **block_kwargs):
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(
rgb,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
return rgb
# for 256x256 generation
@persistence.persistent_class
class SuperresolutionHybrid4X(torch.nn.Module):
def __init__(
self,
channels,
img_resolution,
sr_num_fp16_res,
sr_antialias,
num_fp16_res=4,
conv_clamp=None,
channel_base=None,
channel_max=None, # IGNORE
**block_kwargs):
super().__init__()
assert img_resolution == 256
use_fp16 = sr_num_fp16_res > 0
self.sr_antialias = sr_antialias
self.input_resolution = 128
self.block0 = SynthesisBlockNoUp(
channels,
128,
w_dim=512,
resolution=128,
img_channels=3,
is_last=False,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.block1 = SynthesisBlock(
128,
64,
w_dim=512,
resolution=256,
img_channels=3,
is_last=True,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.register_buffer('resample_filter',
upfirdn2d.setup_filter([1, 3, 3, 1]))
def forward(self, rgb, x, ws, **block_kwargs):
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] < self.input_resolution:
x = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(
rgb,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
return rgb
# for 128 x 128 generation
@persistence.persistent_class
class SuperresolutionHybrid2X(torch.nn.Module):
def __init__(
self,
channels,
img_resolution,
sr_num_fp16_res,
sr_antialias,
num_fp16_res=4,
conv_clamp=None,
channel_base=None,
channel_max=None, # IGNORE
**block_kwargs):
super().__init__()
assert img_resolution == 128
use_fp16 = sr_num_fp16_res > 0
self.input_resolution = 64
self.sr_antialias = sr_antialias
self.block0 = SynthesisBlockNoUp(
channels,
128,
w_dim=512,
resolution=64,
img_channels=3,
is_last=False,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.block1 = SynthesisBlock(
128,
64,
w_dim=512,
resolution=128,
img_channels=3,
is_last=True,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.register_buffer('resample_filter',
upfirdn2d.setup_filter([1, 3, 3, 1]))
def forward(self, rgb, x, ws, **block_kwargs):
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(
rgb,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
return rgb
# TODO: Delete (here for backwards compatibility with old 256x256 models)
@persistence.persistent_class
class SuperresolutionHybridDeepfp32(torch.nn.Module):
def __init__(
self,
channels,
img_resolution,
sr_num_fp16_res,
num_fp16_res=4,
conv_clamp=None,
channel_base=None,
channel_max=None, # IGNORE
**block_kwargs):
super().__init__()
assert img_resolution == 256
use_fp16 = sr_num_fp16_res > 0
self.input_resolution = 128
self.block0 = SynthesisBlockNoUp(
channels,
128,
w_dim=512,
resolution=128,
img_channels=3,
is_last=False,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.block1 = SynthesisBlock(
128,
64,
w_dim=512,
resolution=256,
img_channels=3,
is_last=True,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.register_buffer('resample_filter',
upfirdn2d.setup_filter([1, 3, 3, 1]))
def forward(self, rgb, x, ws, **block_kwargs):
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] < self.input_resolution:
x = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False)
rgb = torch.nn.functional.interpolate(
rgb,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
return rgb
@persistence.persistent_class
class SynthesisBlockNoUp(torch.nn.Module):
def __init__(
self,
in_channels, # Number of input channels, 0 = first block.
out_channels, # Number of output channels.
w_dim, # Intermediate latent (W) dimensionality.
resolution, # Resolution of this block.
img_channels, # Number of output color channels.
is_last, # Is this the last block?
architecture='skip', # Architecture: 'orig', 'skip', 'resnet'.
resample_filter=[
1, 3, 3, 1
], # Low-pass filter to apply when resampling activations.
conv_clamp=256, # Clamp the output of convolution layers to +-X, None = disable clamping.
use_fp16=False, # Use FP16 for this block?
fp16_channels_last=False, # Use channels-last memory format with FP16?
fused_modconv_default=True, # Default value of fused_modconv.
**layer_kwargs, # Arguments for SynthesisLayer.
):
assert architecture in ['orig', 'skip', 'resnet']
super().__init__()
self.in_channels = in_channels
self.w_dim = w_dim
self.resolution = resolution
self.img_channels = img_channels
self.is_last = is_last
self.architecture = architecture
self.use_fp16 = use_fp16
self.channels_last = (use_fp16 and fp16_channels_last)
self.fused_modconv_default = fused_modconv_default
self.register_buffer('resample_filter',
upfirdn2d.setup_filter(resample_filter))
self.num_conv = 0
self.num_torgb = 0
if in_channels == 0:
self.const = torch.nn.Parameter(
torch.randn([out_channels, resolution, resolution]))
if in_channels != 0:
self.conv0 = SynthesisLayer(
in_channels,
out_channels,
w_dim=w_dim,
resolution=resolution,
conv_clamp=conv_clamp,
channels_last=self.channels_last,
**layer_kwargs)
self.num_conv += 1
self.conv1 = SynthesisLayer(
out_channels,
out_channels,
w_dim=w_dim,
resolution=resolution,
conv_clamp=conv_clamp,
channels_last=self.channels_last,
**layer_kwargs)
self.num_conv += 1
if is_last or architecture == 'skip':
self.torgb = ToRGBLayer(
out_channels,
img_channels,
w_dim=w_dim,
conv_clamp=conv_clamp,
channels_last=self.channels_last)
self.num_torgb += 1
if in_channels != 0 and architecture == 'resnet':
self.skip = Conv2dLayer(
in_channels,
out_channels,
kernel_size=1,
bias=False,
up=2,
resample_filter=resample_filter,
channels_last=self.channels_last)
def forward(self,
x,
img,
ws,
force_fp32=False,
fused_modconv=None,
update_emas=False,
**layer_kwargs):
_ = update_emas # unused
misc.assert_shape(ws,
[None, self.num_conv + self.num_torgb, self.w_dim])
w_iter = iter(ws.unbind(dim=1))
if ws.device.type != 'cuda':
force_fp32 = True
dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
if fused_modconv is None:
fused_modconv = self.fused_modconv_default
if fused_modconv == 'inference_only':
fused_modconv = (not self.training)
# Input.
if self.in_channels == 0:
x = self.const.to(dtype=dtype, memory_format=memory_format)
x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
else:
misc.assert_shape(
x, [None, self.in_channels, self.resolution, self.resolution])
x = x.to(dtype=dtype, memory_format=memory_format)
# Main layers.
if self.in_channels == 0:
x = self.conv1(
x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
elif self.architecture == 'resnet':
y = self.skip(x, gain=np.sqrt(0.5))
x = self.conv0(
x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(
x,
next(w_iter),
fused_modconv=fused_modconv,
gain=np.sqrt(0.5),
**layer_kwargs)
x = y.add_(x)
else:
x = self.conv0(
x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
x = self.conv1(
x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
# ToRGB.
# if img is not None:
# misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
# img = upfirdn2d.upsample2d(img, self.resample_filter)
if self.is_last or self.architecture == 'skip':
y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
y = y.to(
dtype=torch.float32, memory_format=torch.contiguous_format)
img = img.add_(y) if img is not None else y
assert x.dtype == dtype
assert img is None or img.dtype == torch.float32
return x, img
def extra_repr(self):
return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8XDC(torch.nn.Module):
def __init__(
self,
channels,
img_resolution,
sr_num_fp16_res,
sr_antialias,
num_fp16_res=4,
conv_clamp=None,
channel_base=None,
channel_max=None, # IGNORE
**block_kwargs):
super().__init__()
assert img_resolution == 512
use_fp16 = sr_num_fp16_res > 0
self.input_resolution = 128
self.sr_antialias = sr_antialias
self.block0 = SynthesisBlock(
channels,
256,
w_dim=512,
resolution=256,
img_channels=3,
is_last=False,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
self.block1 = SynthesisBlock(
256,
128,
w_dim=512,
resolution=512,
img_channels=3,
is_last=True,
use_fp16=use_fp16,
conv_clamp=(256 if use_fp16 else None),
**block_kwargs)
def forward(self, rgb, x, ws, **block_kwargs):
ws = ws[:, -1:, :].repeat(1, 3, 1)
if x.shape[-1] != self.input_resolution:
x = torch.nn.functional.interpolate(
x,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
rgb = torch.nn.functional.interpolate(
rgb,
size=(self.input_resolution, self.input_resolution),
mode='bilinear',
align_corners=False,
antialias=self.sr_antialias)
x, rgb = self.block0(x, rgb, ws, **block_kwargs)
x, rgb = self.block1(x, rgb, ws, **block_kwargs)
return rgb
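A rough shape sketch for the 8x super-resolution module above; the shapes mirror how it is called from TriPlaneGenerator.synthesis later in this diff, and the concrete sizes are illustrative. On CPU this should fall back to the reference (non-fused) code paths of the StyleGAN2 ops.

import torch

sr = SuperresolutionHybrid8XDC(
    channels=32, img_resolution=512, sr_num_fp16_res=0, sr_antialias=True)
ws = torch.randn(1, 14, 512)                     # only the last w is used (repeated 3x internally)
feature_image = torch.randn(1, 32, 128, 128)     # neural-rendering feature image
rgb_image = feature_image[:, :3]                 # raw RGB slice of the feature image
sr_image = sr(rgb_image, feature_image, ws, noise_mode='const')   # -> (1, 3, 512, 512)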

View File

@@ -0,0 +1,242 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from modelscope.ops.image_control_3d_portrait.torch_utils import persistence
from .networks_stylegan2 import FullyConnectedLayer
from .networks_stylegan2 import Generator as StyleGAN2Backbone
from .superresolution import SuperresolutionHybrid8XDC
from .volumetric_rendering.ray_sampler import RaySampler
from .volumetric_rendering.renderer import ImportanceRenderer
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
def __init__(
self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
sr_num_fp16_res=0,
mapping_kwargs={}, # Arguments for MappingNetwork.
rendering_kwargs={},
sr_kwargs={},
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_channels = img_channels
self.renderer = ImportanceRenderer()
self.ray_sampler = RaySampler()
self.backbone = StyleGAN2Backbone(
z_dim,
c_dim,
w_dim,
img_resolution=256,
img_channels=32 * 3,
mapping_kwargs=mapping_kwargs,
**synthesis_kwargs)
self.superresolution = SuperresolutionHybrid8XDC(
channels=32,
img_resolution=img_resolution,
sr_num_fp16_res=sr_num_fp16_res,
sr_antialias=rendering_kwargs['sr_antialias'],
**sr_kwargs)
self.decoder = OSGDecoder(
32, {
'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
'decoder_output_dim': 32
})
self.neural_rendering_resolution = 64
self.rendering_kwargs = rendering_kwargs
self._last_planes = None
def mapping(self,
z,
c,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False):
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
return self.backbone.mapping(
z,
c * self.rendering_kwargs.get('c_scale', 0),
truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
update_emas=update_emas)
def synthesis(self,
ws,
c,
neural_rendering_resolution=None,
update_emas=False,
cache_backbone=False,
use_cached_backbone=False,
**synthesis_kwargs):
cam2world_matrix = c[:, :16].view(-1, 4, 4)
intrinsics = c[:, 16:25].view(-1, 3, 3)
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
else:
self.neural_rendering_resolution = neural_rendering_resolution
# Create a batch of rays for volume rendering
ray_origins, ray_directions = self.ray_sampler(
cam2world_matrix, intrinsics, neural_rendering_resolution)
# Create triplanes by running StyleGAN backbone
N, M, _ = ray_origins.shape
if use_cached_backbone and self._last_planes is not None:
planes = self._last_planes
else:
planes = self.backbone.synthesis(
ws, update_emas=update_emas, **synthesis_kwargs)
if cache_backbone:
self._last_planes = planes
# Reshape output into three 32-channel planes
planes = planes.view(
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
# Perform volume rendering
feature_samples, depth_samples, weights_samples = self.renderer(
planes, self.decoder, ray_origins, ray_directions,
self.rendering_kwargs) # channels last
# Reshape into 'raw' neural-rendered image
H = W = self.neural_rendering_resolution
feature_image = feature_samples.permute(0, 2, 1).reshape(
N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# Run superresolution to get final image
rgb_image = feature_image[:, :3]
sr_image = self.superresolution(
rgb_image,
feature_image,
ws,
noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
**{
k: synthesis_kwargs[k]
for k in synthesis_kwargs.keys() if k != 'noise_mode'
})
return {
'image': sr_image,
'image_raw': rgb_image,
'image_depth': depth_image
}
def sample(self,
coordinates,
directions,
z,
c,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False,
**synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
ws = self.mapping(
z,
c,
truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
update_emas=update_emas)
planes = self.backbone.synthesis(
ws, update_emas=update_emas, **synthesis_kwargs)
planes = planes.view(
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
return self.renderer.run_model(planes, self.decoder, coordinates,
directions, self.rendering_kwargs)
def sample_mixed(self,
coordinates,
directions,
ws,
truncation_psi=1,
truncation_cutoff=None,
update_emas=False,
**synthesis_kwargs):
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
planes = self.backbone.synthesis(
ws, update_emas=update_emas, **synthesis_kwargs)
planes = planes.view(
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
return self.renderer.run_model(planes, self.decoder, coordinates,
directions, self.rendering_kwargs)
def forward(self,
z,
c,
truncation_psi=1,
truncation_cutoff=None,
neural_rendering_resolution=None,
update_emas=False,
cache_backbone=False,
use_cached_backbone=False,
**synthesis_kwargs):
# Render a batch of generated images.
ws = self.mapping(
z,
c,
truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
update_emas=update_emas)
return self.synthesis(
ws,
c,
update_emas=update_emas,
neural_rendering_resolution=neural_rendering_resolution,
cache_backbone=cache_backbone,
use_cached_backbone=use_cached_backbone,
**synthesis_kwargs)
class OSGDecoder(torch.nn.Module):
def __init__(self, n_features, options):
super().__init__()
self.hidden_dim = 64
self.net = torch.nn.Sequential(
FullyConnectedLayer(
n_features,
self.hidden_dim,
lr_multiplier=options['decoder_lr_mul']), torch.nn.Softplus(),
FullyConnectedLayer(
self.hidden_dim,
1 + options['decoder_output_dim'],
lr_multiplier=options['decoder_lr_mul']))
def forward(self, sampled_features, ray_directions):
# Aggregate features
sampled_features = sampled_features.mean(1)
x = sampled_features
N, M, C = x.shape
x = x.view(N * M, C)
x = self.net(x)
x = x.view(N, M, -1)
rgb = torch.sigmoid(x[..., 1:]) * (
1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
sigma = x[..., 0:1]
return {'rgb': rgb, 'sigma': sigma}
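An illustrative end-to-end sketch of TriPlaneGenerator. Every rendering_kwargs key below is read somewhere in the generator, renderer, or ray-marcher code in this diff, but the concrete values (ray bounds, depth resolutions, and so on) are assumptions rather than the training configuration; the sketch only documents the expected input and output shapes.

import torch

rendering_kwargs = {
    'c_gen_conditioning_zero': False, 'c_scale': 1.0,
    'superresolution_noise_mode': 'const', 'sr_antialias': True,
    'ray_start': 2.25, 'ray_end': 3.3,            # or 'auto' to intersect the bounding box
    'box_warp': 1.0,
    'depth_resolution': 48, 'depth_resolution_importance': 48,
    'disparity_space_sampling': False, 'clamp_mode': 'softplus',
    'white_back': False, 'density_noise': 0.0, 'decoder_lr_mul': 1.0,
}
G = TriPlaneGenerator(z_dim=512, c_dim=25, w_dim=512,
                      img_resolution=512, img_channels=3,
                      rendering_kwargs=rendering_kwargs)
z = torch.randn(1, 512)
c = torch.randn(1, 25)   # flattened cam2world (16) + intrinsics (9), as sliced apart in synthesis()
out = G(z, c)
# out['image']: (1, 3, 512, 512); out['image_raw']: (1, 3, 64, 64); out['image_depth']: (1, 1, 64, 64)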

View File

@@ -0,0 +1,697 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
from functools import partial
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from .networks_stylegan2 import FullyConnectedLayer
from .superresolution import SuperresolutionHybrid8XDC
from .volumetric_rendering.ray_sampler import RaySampler
from .volumetric_rendering.renderer import ImportanceRenderer
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv = DWConv(hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self,
dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.,
proj_drop=0.,
sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f'dim {dim} should be divided by num_heads {num_heads}.'
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(
dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads,
C // self.num_heads).permute(0, 2, 1, 3)
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads,
C // self.num_heads).permute(
2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads,
C // self.num_heads).permute(
2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
sr_ratio=1):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
sr_ratio=sr_ratio)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
return x
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self,
img_size=224,
patch_size=7,
stride=4,
in_chans=3,
embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[
1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class Encoder_low(nn.Module):
def __init__(self,
img_size=64,
depth=5,
in_chans=256,
embed_dims=1024,
num_head=4,
mlp_ratio=2,
sr_ratio=1,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
self.depth = depth
self.deeplabnet = smp.DeepLabV3(
encoder_name='resnet34',
encoder_depth=5,
encoder_weights=None,
decoder_channels=256,
in_channels=5,
classes=1)
self.deeplabnet.encoder.conv1 = nn.Conv2d(
5,
64,
kernel_size=(7, 7),
stride=(2, 2),
padding=(3, 3),
bias=False)
self.deeplabnet.segmentation_head = nn.Sequential()
self.deeplabnet.encoder.bn1 = nn.Sequential()
self.deeplabnet.encoder.layer1[0].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer1[0].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer1[1].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer1[1].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer1[2].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer1[2].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer2[0].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer2[0].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer2[0].downsample[1] = nn.Sequential()
self.deeplabnet.encoder.layer2[1].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer2[1].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer2[2].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer2[2].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer2[3].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer2[3].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[0].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[0].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[0].downsample[1] = nn.Sequential()
self.deeplabnet.encoder.layer3[1].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[1].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[2].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[2].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[3].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[3].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[4].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[4].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer3[5].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer3[5].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer4[0].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer4[0].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer4[0].downsample[1] = nn.Sequential()
self.deeplabnet.encoder.layer4[1].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer4[1].bn2 = nn.Sequential()
self.deeplabnet.encoder.layer4[2].bn1 = nn.Sequential()
self.deeplabnet.encoder.layer4[2].bn2 = nn.Sequential()
self.deeplabnet.decoder[0].convs[0][1] = nn.Sequential()
self.deeplabnet.decoder[0].convs[1][1] = nn.Sequential()
self.deeplabnet.decoder[0].convs[2][1] = nn.Sequential()
self.deeplabnet.decoder[0].convs[3][1] = nn.Sequential()
self.deeplabnet.decoder[0].convs[4][2] = nn.Sequential()
self.deeplabnet.decoder[0].project[1] = nn.Sequential()
self.deeplabnet.decoder[2] = nn.Sequential()
self.patch_embed = OverlapPatchEmbed(
img_size=img_size,
patch_size=3,
stride=2,
in_chans=in_chans,
embed_dim=embed_dims)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
cur = 0
self.vit_block = nn.ModuleList([
Block(
dim=embed_dims,
num_heads=num_head,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + i],
norm_layer=norm_layer,
sr_ratio=sr_ratio) for i in range(depth)
])
self.norm1 = norm_layer(embed_dims)
self.ps = nn.PixelShuffle(2)
self.upsample1 = nn.UpsamplingBilinear2d(scale_factor=2)
self.conv1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.ReLU()
self.upsample2 = nn.UpsamplingBilinear2d(scale_factor=2)
self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.ReLU()
self.conv3 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, input):
B = input.shape[0]
f_low = self.deeplabnet(input)
x, H, W = self.patch_embed(f_low)
for i, blk in enumerate(self.vit_block):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
x = self.ps(x)
x = self.relu1(self.conv1(self.upsample1(x)))
x = self.relu2(self.conv2(self.upsample2(x)))
x = self.conv3(x)
return x
class Encoder_high(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(5, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.LeakyReLU(0.01)
self.conv2 = nn.Conv2d(64, 96, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.LeakyReLU(0.01)
self.conv3 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
self.relu3 = nn.LeakyReLU(0.01)
self.conv4 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
self.relu4 = nn.LeakyReLU(0.01)
self.conv5 = nn.Conv2d(96, 96, kernel_size=3, stride=1, padding=1)
self.relu5 = nn.LeakyReLU(0.01)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.relu1(self.conv1(x))
x = self.relu2(self.conv2(x))
x = self.relu3(self.conv3(x))
x = self.relu4(self.conv4(x))
x = self.relu5(self.conv5(x))
return x
class MixFeature(nn.Module):
def __init__(self,
img_size=256,
depth=1,
in_chans=128,
embed_dims=1024,
num_head=2,
mlp_ratio=2,
sr_ratio=2,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6)):
super().__init__()
self.conv1 = nn.Conv2d(192, 256, kernel_size=3, stride=1, padding=1)
self.relu1 = nn.LeakyReLU(0.01)
self.conv2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
self.relu2 = nn.LeakyReLU(0.01)
self.patch_embed = OverlapPatchEmbed(
img_size=img_size,
patch_size=3,
stride=2,
in_chans=in_chans,
embed_dim=embed_dims)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
cur = 0
self.vit_block = nn.ModuleList([
Block(
dim=embed_dims,
num_heads=num_head,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[cur + i],
norm_layer=norm_layer,
sr_ratio=sr_ratio) for i in range(depth)
])
self.norm1 = norm_layer(embed_dims)
self.ps = nn.PixelShuffle(2)
self.conv3 = nn.Conv2d(352, 256, kernel_size=3, stride=1, padding=1)
self.relu3 = nn.LeakyReLU(0.01)
self.conv4 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
self.relu4 = nn.LeakyReLU(0.01)
self.conv5 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
self.relu5 = nn.LeakyReLU(0.01)
self.conv6 = nn.Conv2d(128, 96, kernel_size=3, stride=1, padding=1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x_low, x_high):
x = torch.cat((x_low, x_high), 1)
B = x.shape[0]
x = self.relu1(self.conv1(x))
x = self.relu2(self.conv2(x))
x, H, W = self.patch_embed(x)
for i, blk in enumerate(self.vit_block):
x = blk(x, H, W)
x = self.norm1(x)
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
x = self.ps(x)
x = torch.cat((x, x_low), 1)
x = self.relu3(self.conv3(x))
x = self.relu4(self.conv4(x))
x = self.relu5(self.conv5(x))
x = self.conv6(x)
return x
class OSGDecoder(torch.nn.Module):
def __init__(self, n_features, options):
super().__init__()
self.hidden_dim = 64
self.net = torch.nn.Sequential(
FullyConnectedLayer(
n_features,
self.hidden_dim,
lr_multiplier=options['decoder_lr_mul']), torch.nn.Softplus(),
FullyConnectedLayer(
self.hidden_dim,
1 + options['decoder_output_dim'],
lr_multiplier=options['decoder_lr_mul']))
def forward(self, sampled_features, ray_directions):
sampled_features = sampled_features.mean(1)
x = sampled_features
N, M, C = x.shape
x = x.view(N * M, C)
x = self.net(x)
x = x.view(N, M, -1)
rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
sigma = x[..., 0:1]
return {'rgb': rgb, 'sigma': sigma}
class TriplaneEncoder(nn.Module):
def __init__(self,
img_resolution,
sr_num_fp16_res=0,
rendering_kwargs={},
sr_kwargs={}):
super().__init__()
self.encoder_low = Encoder_low(
img_size=64,
depth=5,
in_chans=256,
embed_dims=1024,
num_head=4,
mlp_ratio=2,
sr_ratio=1,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6))
self.encoder_high = Encoder_high()
self.mix = MixFeature(
img_size=256,
depth=1,
in_chans=128,
embed_dims=1024,
num_head=2,
mlp_ratio=2,
sr_ratio=2,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6))
self.renderer = ImportanceRenderer()
self.ray_sampler = RaySampler()
self.superresolution = SuperresolutionHybrid8XDC(
channels=32,
img_resolution=img_resolution,
sr_num_fp16_res=sr_num_fp16_res,
sr_antialias=rendering_kwargs['sr_antialias'],
**sr_kwargs)
self.decoder = OSGDecoder(
32, {
'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
'decoder_output_dim': 32
})
self.neural_rendering_resolution = 128
self.rendering_kwargs = rendering_kwargs
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def gen_interfeats(self, ws, planes, camera_params):
planes = planes.view(
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
cam2world_matrix = camera_params[:, :16].view(-1, 4, 4)
intrinsics = camera_params[:, 16:25].view(-1, 3, 3)
H = W = self.neural_rendering_resolution
ray_origins, ray_directions = self.ray_sampler(
cam2world_matrix, intrinsics, self.neural_rendering_resolution)
N, M, _ = ray_origins.shape
feature_samples, depth_samples, weights_samples = self.renderer(
planes, self.decoder, ray_origins, ray_directions,
self.rendering_kwargs)
feature_image = feature_samples.permute(0, 2, 1).reshape(
N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
rgb_image = feature_image[:, :3]
sr_image = self.superresolution(
rgb_image, feature_image, ws, noise_mode='const')
return depth_image, feature_image, rgb_image, sr_image
def sample(self, coordinates, directions, planes):
planes = planes.view(
len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
return self.renderer.run_model(planes, self.decoder, coordinates,
directions, self.rendering_kwargs)
def forward(self, ws, x, camera_ref, camera_mv):
f = self.encoder_low(x)
f_high = self.encoder_high(x)
planes = self.mix(f, f_high)
depth_ref, feature_ref, rgb_ref, sr_ref = self.gen_interfeats(
ws, planes, camera_ref)
if camera_mv is not None:
depth_mv, feature_mv, rgb_mv, sr_mv = self.gen_interfeats(
ws, planes, camera_mv)
else:
depth_mv = feature_mv = rgb_mv = sr_mv = None
return planes, depth_ref, feature_ref, rgb_ref, sr_ref, depth_mv, feature_mv, rgb_mv, sr_mv
def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
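A quick, self-contained check of the parameter-counting helper above; the tiny throwaway network is illustrative (in this file the helper would typically be applied to a TriplaneEncoder instance).

import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
print(get_parameter_number(net))   # {'Total': 49, 'Trainable': 49}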

View File

@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
# empty

View File

@@ -0,0 +1,137 @@
# MIT License
# Copyright (c) 2022 Petr Kellnhofer
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import torch
def transform_vectors(matrix: torch.Tensor,
vectors4: torch.Tensor) -> torch.Tensor:
"""
Left-multiplies each row of the NxM batch of vectors by the MxM matrix (matrix @ vector per row). Returns NxM.
"""
res = torch.matmul(vectors4, matrix.T)
return res
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
"""
Normalize vector lengths.
"""
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
def torch_dot(x: torch.Tensor, y: torch.Tensor):
"""
Dot product of two tensors.
"""
return (x * y).sum(-1)
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor,
box_side_length):
"""
Author: Petr Kellnhofer
Intersects rays with the [-1, 1] NDC volume.
Returns min and max distance of entry.
Returns -1 for no intersection.
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
"""
o_shape = rays_o.shape
rays_o = rays_o.detach().reshape(-1, 3)
rays_d = rays_d.detach().reshape(-1, 3)
temp_min_1 = -1 * (box_side_length / 2)
temp_min_2 = -1 * (box_side_length / 2)
temp_min_3 = -1 * (box_side_length / 2)
bb_min = [temp_min_1, temp_min_2, temp_min_3]
temp_max_1 = 1 * (box_side_length / 2)
temp_max_2 = 1 * (box_side_length / 2)
temp_max_3 = 1 * (box_side_length / 2)
bb_max = [temp_max_1, temp_max_2, temp_max_3]
bounds = torch.tensor([bb_min, bb_max],
dtype=rays_o.dtype,
device=rays_o.device)
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
# Precompute inverse for stability.
invdir = 1 / rays_d
sign = (invdir < 0).long()
# Intersect with YZ plane.
tmin = (bounds.index_select(0, sign[..., 0])[..., 0]
- rays_o[..., 0]) * invdir[..., 0]
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0]
- rays_o[..., 0]) * invdir[..., 0]
# Intersect with XZ plane.
tymin = (bounds.index_select(0, sign[..., 1])[..., 1]
- rays_o[..., 1]) * invdir[..., 1]
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1]
- rays_o[..., 1]) * invdir[..., 1]
# Resolve parallel rays.
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
# Use the shortest intersection.
tmin = torch.max(tmin, tymin)
tmax = torch.min(tmax, tymax)
# Intersect with XY plane.
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2]
- rays_o[..., 2]) * invdir[..., 2]
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2]
- rays_o[..., 2]) * invdir[..., 2]
# Resolve parallel rays.
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
# Use the shortest intersection.
tmin = torch.max(tmin, tzmin)
tmax = torch.min(tmax, tzmax)
# Mark invalid.
tmin[torch.logical_not(is_valid)] = -1
tmax[torch.logical_not(is_valid)] = -2
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
"""
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
Replicates the multi-dimensional behaviour of numpy.linspace in PyTorch.
"""
# create a tensor of 'num' steps from 0 to 1
steps = torch.arange(
num, dtype=torch.float32, device=start.device) / (
num - 1)
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcasting
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
#   "cannot statically infer the expected size of a list in this context", hence the code below
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# the output starts at 'start' and increments until 'stop' in each dimension
out = start[None] + steps * (stop - start)[None]
return out
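A small numeric sketch of the two ray/box helpers above; the single axis-aligned ray is chosen so that the expected near/far values are easy to verify by hand.

import torch

rays_o = torch.tensor([[0.0, 0.0, -2.0]])      # one ray starting outside the box
rays_d = torch.tensor([[0.0, 0.0, 1.0]])       # pointing through the box along +z
t_near, t_far = get_ray_limits_box(rays_o, rays_d, box_side_length=2.0)
# t_near == 1.0 and t_far == 3.0 for the side-2 box centred at the origin
depths = linspace(t_near, t_far, 5)            # (5, 1, 1): 1.0, 1.5, 2.0, 2.5, 3.0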

View File

@@ -0,0 +1,67 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The ray marcher takes the raw output of the implicit representation and
uses the volume rendering equation to produce composited colors and depths.
Based on the implementation in MipNeRF (though this one does not do any cone tracing).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class MipRayMarcher2(nn.Module):
def __init__(self):
super().__init__()
def run_forward(self, colors, densities, depths, rendering_options):
deltas = depths[:, :, 1:] - depths[:, :, :-1]
colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
if rendering_options['clamp_mode'] == 'softplus':
densities_mid = F.softplus(
densities_mid
- 1) # activation bias of -1 makes things initialize better
else:
assert False, 'MipRayMarcher only supports `clamp_mode`=`softplus`!'
density_delta = densities_mid * deltas
alpha = 1 - torch.exp(-density_delta)
alpha_shifted = torch.cat(
[torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
composite_rgb = torch.sum(weights * colors_mid, -2)
weight_total = weights.sum(2)
composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
# clip the composite to min/max range of depths
composite_depth = torch.nan_to_num(composite_depth, float('inf'))
composite_depth = torch.clamp(composite_depth, torch.min(depths),
torch.max(depths))
if rendering_options.get('white_back', False):
composite_rgb = composite_rgb + 1 - weight_total
composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
return composite_rgb, composite_depth, weights
def forward(self, colors, densities, depths, rendering_options):
composite_rgb, composite_depth, weights = self.run_forward(
colors, densities, depths, rendering_options)
return composite_rgb, composite_depth, weights

View File

@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The ray sampler is a module that takes in camera matrices and a resolution and produces batches of rays.
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
"""
import torch
class RaySampler(torch.nn.Module):
def __init__(self):
super().__init__()
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = \
None, None, None, None, None
def forward(self, cam2world_matrix, intrinsics, resolution):
"""
Create batches of rays and return origins and directions.
cam2world_matrix: (N, 4, 4)
intrinsics: (N, 3, 3)
resolution: int
ray_origins: (N, M, 3)
ray_dirs: (N, M, 3)
"""
N, M = cam2world_matrix.shape[0], resolution**2
cam_locs_world = cam2world_matrix[:, :3, 3]
fx = intrinsics[:, 0, 0]
fy = intrinsics[:, 1, 1]
cx = intrinsics[:, 0, 2]
cy = intrinsics[:, 1, 2]
sk = intrinsics[:, 0, 1]
uv = torch.stack(
torch.meshgrid(
torch.arange(
resolution,
dtype=torch.float32,
device=cam2world_matrix.device),
torch.arange(
resolution,
dtype=torch.float32,
device=cam2world_matrix.device))) * (1. / resolution) + (
0.5 / resolution)
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
x_cam = uv[:, :, 0].view(N, -1)
y_cam = uv[:, :, 1].view(N, -1)
z_cam = torch.ones((N, M), device=cam2world_matrix.device)
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)
* sk.unsqueeze(-1) / fy.unsqueeze(-1) - sk.unsqueeze(-1)
* y_cam / fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
cam_rel_points = torch.stack(
(x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
world_rel_points = torch.bmm(cam2world_matrix,
cam_rel_points.permute(0, 2, 1)).permute(
0, 2, 1)[:, :, :3]
ray_dirs = world_rel_points - cam_locs_world[:, None, :]
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
ray_origins = cam_locs_world.unsqueeze(1).repeat(
1, ray_dirs.shape[1], 1)
return ray_origins, ray_dirs
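A usage sketch for the ray sampler above: one camera with identity rotation placed at (0, 0, 2.7), looking along +z under the OpenCV convention noted in the module docstring. The pose and the normalized intrinsics values are illustrative; the sketch only exercises the output shapes.

import torch

sampler = RaySampler()
cam2world = torch.eye(4).unsqueeze(0)
cam2world[0, 2, 3] = 2.7                                   # camera 2.7 units along +z
intrinsics = torch.tensor([[[4.26, 0.0, 0.5],
                            [0.0, 4.26, 0.5],
                            [0.0, 0.0, 1.0]]])
ray_origins, ray_dirs = sampler(cam2world, intrinsics, resolution=64)
# ray_origins: (1, 4096, 3), all equal to the camera position; ray_dirs: (1, 4096, 3), unit length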

View File

@@ -0,0 +1,341 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
The renderer is a module that takes in rays, decides where to sample along each
ray, and computes pixel colors using the volume rendering equation.
"""
import math
import torch
import torch.nn as nn
from . import math_utils
from .ray_marcher import MipRayMarcher2
def generate_planes():
"""
Defines planes by the three vectors that form the "axes" of the
plane. Should work with arbitrary number of planes and planes of
arbitrary orientation.
"""
return torch.tensor(
[[[1, 0, 0], [0, 1, 0], [0, 0, 1]], [[1, 0, 0], [0, 0, 1], [0, 1, 0]],
[[0, 0, 1], [1, 0, 0], [0, 1, 0]]],
dtype=torch.float32)
def project_onto_planes(planes, coordinates):
"""
Does a projection of a 3D point onto a batch of 2D planes,
returning 2D plane coordinates.
Takes plane axes of shape (n_planes, 3, 3)
and coordinates of shape (N, M, 3);
returns projections of shape (N*n_planes, M, 2).
"""
N, M, C = coordinates.shape
n_planes, _, _ = planes.shape
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1,
-1).reshape(
N * n_planes, M, 3)
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(
N, -1, -1, -1).reshape(N * n_planes, 3, 3).to(coordinates.device)
projections = torch.bmm(coordinates, inv_planes)
return projections[..., :2]
def sample_from_planes(plane_axes,
plane_features,
coordinates,
mode='bilinear',
padding_mode='zeros',
box_warp=None):
assert padding_mode == 'zeros'
N, n_planes, C, H, W = plane_features.shape
_, M, _ = coordinates.shape
plane_features = plane_features.view(N * n_planes, C, H, W)
coordinates = (2 / box_warp) * coordinates # TODO: add specific box bounds
projected_coordinates = project_onto_planes(plane_axes,
coordinates).unsqueeze(1)
output_features = torch.nn.functional.grid_sample(
plane_features,
projected_coordinates.float(),
mode=mode,
padding_mode=padding_mode,
align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
return output_features
def sample_from_3dgrid(grid, coordinates):
"""
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
Expects grid in shape (1, channels, H, W, D)
(Also works if grid has batch size)
Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
"""
batch_size, n_coords, n_dims = coordinates.shape
sampled_features = torch.nn.functional.grid_sample(
grid.expand(batch_size, -1, -1, -1, -1),
coordinates.reshape(batch_size, 1, 1, -1, n_dims),
mode='bilinear',
padding_mode='zeros',
align_corners=False)
N, C, H, W, D = sampled_features.shape
sampled_features = sampled_features.permute(0, 4, 3, 2,
1).reshape(N, H * W * D, C)
return sampled_features
class ImportanceRenderer(torch.nn.Module):
def __init__(self):
super().__init__()
self.ray_marcher = MipRayMarcher2()
self.plane_axes = generate_planes()
def forward(self, planes, decoder, ray_origins, ray_directions,
rendering_options):
self.plane_axes = self.plane_axes.to(ray_origins.device)
if rendering_options['ray_start'] == rendering_options[
'ray_end'] == 'auto':
ray_start, ray_end = math_utils.get_ray_limits_box(
ray_origins,
ray_directions,
box_side_length=rendering_options['box_warp'])
is_ray_valid = ray_end > ray_start
if torch.any(is_ray_valid).item():
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
depths_coarse = self.sample_stratified(
ray_origins, ray_start, ray_end,
rendering_options['depth_resolution'],
rendering_options['disparity_space_sampling'])
else:
# Create stratified depth samples
depths_coarse = self.sample_stratified(
ray_origins, rendering_options['ray_start'],
rendering_options['ray_end'],
rendering_options['depth_resolution'],
rendering_options['disparity_space_sampling'])
batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape
# Coarse Pass
sample_coordinates = (
ray_origins.unsqueeze(-2)
+ depths_coarse * ray_directions.unsqueeze(-2)).reshape(
batch_size, -1, 3)
sample_directions = ray_directions.unsqueeze(-2).expand(
-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
out = self.run_model(planes, decoder, sample_coordinates,
sample_directions, rendering_options)
colors_coarse = out['rgb']
densities_coarse = out['sigma']
colors_coarse = colors_coarse.reshape(batch_size, num_rays,
samples_per_ray,
colors_coarse.shape[-1])
densities_coarse = densities_coarse.reshape(batch_size, num_rays,
samples_per_ray, 1)
# Fine Pass
N_importance = rendering_options['depth_resolution_importance']
if N_importance > 0:
_, _, weights = self.ray_marcher(colors_coarse, densities_coarse,
depths_coarse, rendering_options)
depths_fine = self.sample_importance(depths_coarse, weights,
N_importance)
sample_directions = ray_directions.unsqueeze(-2).expand(
-1, -1, N_importance, -1).reshape(batch_size, -1, 3)
sample_coordinates = (
ray_origins.unsqueeze(-2)
+ depths_fine * ray_directions.unsqueeze(-2)).reshape(
batch_size, -1, 3)
out = self.run_model(planes, decoder, sample_coordinates,
sample_directions, rendering_options)
colors_fine = out['rgb']
densities_fine = out['sigma']
colors_fine = colors_fine.reshape(batch_size, num_rays,
N_importance,
colors_fine.shape[-1])
densities_fine = densities_fine.reshape(batch_size, num_rays,
N_importance, 1)
all_depths, all_colors, all_densities = self.unify_samples(
depths_coarse, colors_coarse, densities_coarse, depths_fine,
colors_fine, densities_fine)
# Aggregate
rgb_final, depth_final, weights = self.ray_marcher(
all_colors, all_densities, all_depths, rendering_options)
else:
rgb_final, depth_final, weights = self.ray_marcher(
colors_coarse, densities_coarse, depths_coarse,
rendering_options)
return rgb_final, depth_final, weights.sum(2)
def run_model(self, planes, decoder, sample_coordinates, sample_directions,
options):
sampled_features = sample_from_planes(
self.plane_axes,
planes,
sample_coordinates,
padding_mode='zeros',
box_warp=options['box_warp'])
out = decoder(sampled_features, sample_directions)
if options.get('density_noise', 0) > 0:
out['sigma'] += torch.randn_like(
out['sigma']) * options['density_noise']
return out
def sort_samples(self, all_depths, all_colors, all_densities):
_, indices = torch.sort(all_depths, dim=-2)
all_depths = torch.gather(all_depths, -2, indices)
all_colors = torch.gather(
all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
all_densities = torch.gather(all_densities, -2,
indices.expand(-1, -1, -1, 1))
return all_depths, all_colors, all_densities
def unify_samples(self, depths1, colors1, densities1, depths2, colors2,
densities2):
all_depths = torch.cat([depths1, depths2], dim=-2)
all_colors = torch.cat([colors1, colors2], dim=-2)
all_densities = torch.cat([densities1, densities2], dim=-2)
_, indices = torch.sort(all_depths, dim=-2)
all_depths = torch.gather(all_depths, -2, indices)
all_colors = torch.gather(
all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
all_densities = torch.gather(all_densities, -2,
indices.expand(-1, -1, -1, 1))
return all_depths, all_colors, all_densities
def sample_stratified(self,
ray_origins,
ray_start,
ray_end,
depth_resolution,
disparity_space_sampling=False):
"""
Return depths of approximately uniformly spaced samples along rays.
"""
N, M, _ = ray_origins.shape
if disparity_space_sampling:
depths_coarse = torch.linspace(
0, 1, depth_resolution,
device=ray_origins.device).reshape(1, 1, depth_resolution,
1).repeat(N, M, 1, 1)
depth_delta = 1 / (depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
depths_coarse = 1. / (1. / ray_start * (1. - depths_coarse)
+ 1. / ray_end * depths_coarse)
else:
if type(ray_start) == torch.Tensor:
depths_coarse = math_utils.linspace(ray_start, ray_end,
depth_resolution).permute(
1, 2, 0, 3)
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[
..., None]
else:
depths_coarse = torch.linspace(
ray_start,
ray_end,
depth_resolution,
device=ray_origins.device).reshape(1, 1, depth_resolution,
1).repeat(N, M, 1, 1)
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
return depths_coarse
def sample_importance(self, z_vals, weights, N_importance):
"""
Return depths of importance sampled points along rays. See NeRF importance sampling for more.
"""
with torch.no_grad():
batch_size, num_rays, samples_per_ray, _ = z_vals.shape
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
weights = weights.reshape(
batch_size * num_rays,
-1) # -1 to account for loss of 1 sample in MipRayMarcher
# smooth weights
weights = torch.nn.functional.max_pool1d(
weights.unsqueeze(1).float(), 2, 1, padding=1)
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
weights = weights + 0.01
z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
N_importance).detach().reshape(
batch_size, num_rays,
N_importance, 1)
return importance_z_vals
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
"""
Sample @N_importance samples from @bins with distribution defined by @weights.
Inputs:
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
weights: (N_rays, N_samples_)
N_importance: the number of samples to draw from the distribution
det: deterministic or not
eps: a small number to prevent division by zero
Outputs:
samples: the sampled samples
"""
N_rays, N_samples_ = weights.shape
weights = weights + eps # prevent division by zero (don't do inplace op!)
pdf = weights / torch.sum(
weights, -1, keepdim=True) # (N_rays, N_samples_)
cdf = torch.cumsum(
pdf, -1) # (N_rays, N_samples), cumulative distribution function
cdf = torch.cat([torch.zeros_like(cdf[:, :1]), cdf],
-1) # (N_rays, N_samples_+1)
# padded to 0~1 inclusive
if det:
u = torch.linspace(0, 1, N_importance, device=bins.device)
u = u.expand(N_rays, N_importance)
else:
u = torch.rand(N_rays, N_importance, device=bins.device)
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.clamp_min(inds - 1, 0)
above = torch.clamp_max(inds, N_samples_)
inds_sampled = torch.stack([below, above],
-1).view(N_rays, 2 * N_importance)
cdf_g = torch.gather(cdf, 1,
inds_sampled).view(N_rays, N_importance, 2)
bins_g = torch.gather(bins, 1,
inds_sampled).view(N_rays, N_importance, 2)
denom = cdf_g[..., 1] - cdf_g[..., 0]
denom[denom < eps] = 1
samples = bins_g[..., 0] + (u - cdf_g[..., 0]) / denom * (
bins_g[..., 1] - bins_g[..., 0])
return samples
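# Usage sketch of the inverse-CDF sampling above (illustrative values; `renderer`
# stands for an instance of this class):
#
#     bins = torch.linspace(2.25, 3.3, steps=49).expand(1024, 49)  # (N_rays, N_samples_+1)
#     weights = torch.rand(1024, 48)                               # (N_rays, N_samples_)
#     z_fine = renderer.sample_pdf(bins, weights, N_importance=64, det=True)
#     # z_fine: (1024, 64); samples concentrate in bins with larger weights.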

View File

@@ -0,0 +1,20 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .image_view_transform_infer import ImageViewTransform
else:
_import_structure = {'image_view_transform_infer': ['ImageViewTransform']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,219 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import os
import sys
import time
from contextlib import nullcontext
from functools import partial
import cv2
import diffusers # 0.12.1
import fire
import numpy as np
import rich
import torch
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from rich import print
from torch import autocast
from torchvision import transforms
from modelscope.fileio import load
from modelscope.metainfo import Models
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .ldm.ddim import DDIMSampler
from .util import instantiate_from_config, load_and_preprocess
logger = get_logger()
def load_model_from_config(model, config, ckpt, device, verbose=False):
print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')
if 'global_step' in pl_sd:
print(f'Global Step: {pl_sd["global_step"]}')
sd = pl_sd['state_dict']
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
model.to(device)
model.eval()
return model
@MODELS.register_module(
Tasks.image_view_transform, module_name=Models.image_view_transform)
class ImageViewTransform(TorchModel):
"""initialize the image view translation model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
def __init__(self, model_dir, device='cpu', *args, **kwargs):
super().__init__(model_dir=model_dir, device=device, *args, **kwargs)
self.device = torch.device(
device if torch.cuda.is_available() else 'cpu')
config = os.path.join(model_dir,
'sd-objaverse-finetune-c_concat-256.yaml')
ckpt = os.path.join(model_dir, 'zero123-xl.ckpt')
config = OmegaConf.load(config)
self.model = None
self.model = load_model_from_config(
self.model, config, ckpt, device=self.device)
def forward(self, model_path, x, y):
pred_results = _infer(self.model, model_path, x, y, self.device)
return pred_results
def infer(genmodel, model_path, image_path, target_view_path, device):
output_ims = genmodel(model_path, image_path, target_view_path)
return output_ims
@torch.no_grad()
def sample_model(input_im, model, sampler, precision, h, w, ddim_steps,
n_samples, scale, ddim_eta, x, y, z):
precision_scope = autocast if precision == 'autocast' else nullcontext
with precision_scope('cuda'):
with model.ema_scope():
c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
T = torch.tensor([
math.radians(x),
math.sin(math.radians(y)),
math.cos(math.radians(y)), z
])
T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device)
c = torch.cat([c, T], dim=-1)
c = model.cc_projection(c)
cond = {}
cond['c_crossattn'] = [c]
cond['c_concat'] = [
model.encode_first_stage(
(input_im.to(c.device))).mode().detach().repeat(
n_samples, 1, 1, 1)
]
if scale != 1.0:
uc = {}
uc['c_concat'] = [
torch.zeros(n_samples, 4, h // 8, w // 8).to(c.device)
]
uc['c_crossattn'] = [torch.zeros_like(c).to(c.device)]
else:
uc = None
shape = [4, h // 8, w // 8]
samples_ddim, _ = sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=scale,
unconditional_conditioning=uc,
eta=ddim_eta,
x_T=None)
# samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
x_samples_ddim = model.decode_first_stage(samples_ddim)
return torch.clamp(
(x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
def preprocess_image(models, input_im, preprocess, carvekit_path):
'''
:param input_im (PIL Image).
:return input_im (H, W, 3) array in [0, 1].
'''
print('old input_im:', input_im.size)
if preprocess:
# model_carvekit = create_carvekit_interface()
model_carvekit = torch.load(carvekit_path)
input_im = load_and_preprocess(model_carvekit, input_im)
input_im = (input_im / 255.0).astype(np.float32)
# (H, W, 3) array in [0, 1].
else:
input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
input_im = np.asarray(input_im, dtype=np.float32) / 255.0
alpha = input_im[:, :, 3:4]
white_im = np.ones_like(input_im)
input_im = alpha * input_im + (1.0 - alpha) * white_im
input_im = input_im[:, :, 0:3]
# (H, W, 3) array in [0, 1].
return input_im
def main_run(models,
device,
return_what,
x=0.0,
y=0.0,
z=0.0,
raw_im=None,
carvekit_path=None,
preprocess=True,
scale=3.0,
n_samples=4,
ddim_steps=50,
ddim_eta=1.0,
precision='fp32',
h=256,
w=256):
'''
:param raw_im (PIL Image).
'''
raw_im.thumbnail([1536, 1536], Image.Resampling.LANCZOS)
input_im = preprocess_image(models, raw_im, preprocess, carvekit_path)
if 'gen' in return_what:
input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
input_im = input_im * 2 - 1
input_im = transforms.functional.resize(input_im, [h, w])
sampler = DDIMSampler(models)
# used_x = -x # NOTE: Polar makes more sense in Basile's opinion this way!
used_x = x # NOTE: Set this way for consistency.
x_samples_ddim = sample_model(input_im, models, sampler, precision, h,
w, ddim_steps, n_samples, scale,
ddim_eta, used_x, y, z)
output_ims = []
for x_sample in x_samples_ddim:
image = x_sample.detach().cpu().squeeze().numpy()
image = np.transpose(image, (1, 2, 0)) * 255
image = np.uint8(image)
bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
output_ims.append(bgr)
return output_ims
def _infer(genmodel, model_path, image_path, target_view_path, device):
if isinstance(image_path, str):
raw_image = load(image_path)
print(type(raw_image))
else:
raw_image = image_path
if isinstance(target_view_path, str):
views = load(target_view_path)
else:
views = target_view_path
# views = views.astype(np.float32)
carvekit_path = os.path.join(model_path, 'carvekit.pth')
output_ims = main_run(genmodel, device, 'angles_gen', views[0], views[1],
views[2], raw_image, carvekit_path, views[3],
views[4], views[5], views[6], views[7])
return output_ims
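# Usage sketch (illustrative names and paths; `model` is an ImageViewTransform
# instance and `model_dir` its checkpoint directory containing carvekit.pth):
#
#     views = [30.0, 45.0, 0.0, True, 3.0, 4, 50, 1.0]
#     #        x     y     z    preprocess scale n_samples ddim_steps ddim_eta
#     outputs = model(model_dir, 'input.png', views)
#     # outputs: list of n_samples BGR uint8 arrays rendered from the target view.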

View File

@@ -0,0 +1,294 @@
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from .util_diffusion import checkpoint
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(
qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(
out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
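# Shape sketch (illustrative sizes):
#
#     attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
#     x = torch.randn(2, 4096, 320)       # e.g. flattened image tokens
#     ctx = torch.randn(2, 77, 768)       # e.g. conditioning tokens
#     out = attn(x, context=ctx)          # (2, 4096, 320)
#     # With context=None the module falls back to self-attention over x.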
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else
None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding) and reshape it to (b, t, d).
Then apply standard transformer attention.
Finally, reshape back to an image.
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn) for d in range(depth)
])
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1, padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in
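# Shape sketch (illustrative sizes):
#
#     st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, context_dim=768)
#     feat = torch.randn(2, 320, 32, 32)
#     ctx = torch.randn(2, 77, 768)
#     out = st(feat, context=ctx)         # (2, 320, 32, 32), residual added to the input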

View File

@@ -0,0 +1,555 @@
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from ..util import instantiate_from_config
from .distributions import DiagonalGaussianDistribution
from .model import Decoder, Encoder
class VQModel(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(
n_embed,
embed_dim,
beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(
f'{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.'
)
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f'{context}: Restored training weights')
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
f'Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys'
)
if len(missing) > 0:
print(f'Missing Keys: {missing}')
print(f'Unexpected Keys: {unexpected}')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_, _, ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(
np.arange(lower_size, upper_size + 16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode='bicubic')
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train',
predicted_indices=ind)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
self._validation_step(batch, batch_idx, suffix='_ema')
return log_dict
def _validation_step(self, batch, batch_idx, suffix=''):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(
qloss,
x,
xrec,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind)
discloss, log_dict_disc = self.loss(
qloss,
x,
xrec,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + suffix,
predicted_indices=ind)
rec_loss = log_dict_ae[f'val{suffix}/rec_loss']
self.log(
f'val{suffix}/rec_loss',
rec_loss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)
self.log(
f'val{suffix}/aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=False,
on_epoch=True,
sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f'val{suffix}/rec_loss']
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor * self.learning_rate
print('lr_d', lr_d)
print('lr_g', lr_g)
opt_ae = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters())
+ list(self.quantize.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr_g,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print('Setting up LambdaLR scheduler...')
scheduler = [
{
'scheduler':
LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler':
LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log['inputs'] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['inputs'] = x
log['reconstructions'] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3:
xrec_ema = self.to_rgb(xrec_ema)
log['reconstructions_ema'] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f'Restored from {path}')
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val')
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val')
self.log('val/rec_loss', log_dict_ae['val/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(
list(self.encoder.parameters()) + list(self.decoder.parameters())
+ list(self.quant_conv.parameters())
+ list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
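# Usage sketch for AutoencoderKL (illustrative; ddconfig/lossconfig come from the
# model's yaml config, embed_dim=4 and a downsampling factor of 8 are assumed):
#
#     ae = AutoencoderKL(ddconfig=ddconfig, lossconfig=lossconfig, embed_dim=4)
#     posterior = ae.encode(torch.randn(1, 3, 256, 256))   # DiagonalGaussianDistribution
#     z = posterior.sample()                               # (1, 4, 32, 32)
#     rec = ae.decode(z)                                   # (1, 3, 256, 256)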

View File

@@ -0,0 +1,433 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from einops import rearrange
from tqdm import tqdm
from .sampling_util import (norm_thresholding, renorm_thresholding,
spatial_norm_thresholding)
from .util_diffusion import (extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps, noise_like)
class DDIMSampler(object):
def __init__(self, model, schedule='linear', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def to(self, device):
"""Same as to in torch module
Don't really underestand why this isn't a module in the first place"""
for k, v in self.__dict__.items():
if isinstance(v, torch.Tensor):
new_v = getattr(self, k).to(device)
setattr(self, k, new_v)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device('cuda'):
attr = attr.to(torch.device('cuda'))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.,
verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
alpha_1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
alpha_2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
alpha_1 * alpha_2)
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None,
**kwargs):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None,
t_start=-1):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
timesteps = timesteps[:t_start]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback:
img = callback(i, img, pred_x0)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat(
[unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat(
[unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
if use_original_steps:
alphas_prev = self.model.alphas_cumprod_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
else:
alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
sqrt_one_minus_alphas[index],
device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[
0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0], ),
i,
device=self.model.device,
dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
alp_1 = (1 / alphas_next[i] - 1).sqrt()
alp_2 = (1 / alphas[i] - 1).sqrt()
weighted_noise_pred = alphas_next[i].sqrt() * (
alp_1 - alp_2) * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (num_steps // return_intermediates
) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise)
@torch.no_grad()
def decode(self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0], ),
step,
device=x_latent.device,
dtype=torch.long)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
return x_dec
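# Usage sketch (illustrative; `ldm` is a trained latent-diffusion model such as the
# one loaded by ImageViewTransform, `cond`/`uc` its conditional and unconditional
# conditionings):
#
#     sampler = DDIMSampler(ldm)
#     samples, _ = sampler.sample(S=50, batch_size=4, shape=[4, 32, 32],
#                                 conditioning=cond, eta=0.0,
#                                 unconditional_guidance_scale=3.0,
#                                 unconditional_conditioning=uc)
#     imgs = ldm.decode_first_stage(samples)               # back to pixel space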

File diff suppressed because it is too large

View File

@@ -0,0 +1,92 @@
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(
self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two Gaussians.
Shapes are automatically broadcast, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
comp_1 = -1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
comp_2 = ((mean1 - mean2)**2) * torch.exp(-logvar2)
return 0.5 * (comp_1 + comp_2)
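# normal_kl implements the closed form, element-wise with broadcasting:
#     KL(N(mu1, exp(logvar1)) || N(mu2, exp(logvar2)))
#         = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
#                  + (mu1 - mu2)^2 * exp(-logvar2) - 1)
#
# Usage sketch for DiagonalGaussianDistribution (illustrative sizes):
#     moments = torch.randn(1, 8, 32, 32)   # 2*embed_dim channels: mean and logvar
#     dist = DiagonalGaussianDistribution(moments)
#     z = dist.sample()                     # (1, 4, 32, 32)
#     kl = dist.kl()                        # (1,), summed over channel and spatial dims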

View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
num_1 = (1 + self.num_updates)
num_2 = (10 + self.num_updates)
decay = min(self.decay, num_1 / num_2)
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
param_1 = (shadow_params[sname] - m_param[key])
shadow_params[sname].sub_(one_minus_decay * param_1)
else:
assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
else:
assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
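# Usage sketch (illustrative), mirroring the ema_scope pattern used elsewhere in
# this model:
#
#     ema = LitEma(model, decay=0.9999)
#     ema(model)                            # after each optimizer step: update shadow weights
#     ema.store(model.parameters())         # stash the raw training weights
#     ema.copy_to(model)                    # evaluate/save with EMA weights
#     ema.restore(model.parameters())       # switch back to training weights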

View File

@@ -0,0 +1,131 @@
# https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import (AdaptiveAvgPool2d, BatchNorm2d, Conv2d, MaxPool2d,
Module, PReLU, ReLU, Sequential, Sigmoid)
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
pass
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)
] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError(
'Invalid number of layers: {}. Must be one of [50, 100, 152]'.
format(num_layers))
return blocks
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(
channels,
channels // reduction,
kernel_size=1,
padding=0,
bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(
channels // reduction,
channels,
kernel_size=1,
padding=0,
bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth))
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth), Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth), SEModule(depth, 16))
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
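# Example: get_blocks(50) describes the IR(-SE)-50 layout: four stages with
# [3, 4, 14, 3] units and output depths [64, 128, 256, 512]; each stage opens
# with a stride-2 Bottleneck and continues with stride-1 units, which the
# backbone (see model_irse) presumably maps onto bottleneck_IR or
# bottleneck_IR_SE depending on its mode.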

View File

@@ -0,0 +1,27 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn
from .model_irse import Backbone
class IDFeatures(nn.Module):
def __init__(self, model_path):
super(IDFeatures, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(
input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(
torch.load(model_path, map_location='cpu'))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def forward(self, x, crop=False):
# Not sure of the image range here
if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode='area')
x = x[:, :, 35:223, 32:220]
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats
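# Usage sketch (illustrative path; the output dimensionality assumes the standard
# 512-d ArcFace embedding of the ir_se-50 Backbone):
#
#     id_net = IDFeatures('model_ir_se50.pth')
#     feats = id_net(torch.rand(4, 3, 256, 256), crop=True)   # (4, 512)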

View File

@@ -0,0 +1,961 @@
# pytorch_diffusion + derived encoder decoder
import math
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from ..util import instantiate_from_config
from .attention import LinearAttention
def get_timestep_embedding(timesteps, embedding_dim):
"""
Build sinusoidal timestep embeddings, matching the implementation in
Denoising Diffusion Probabilistic Models (adapted from Fairseq).
This also matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
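# Example: get_timestep_embedding(torch.tensor([0, 10, 999]), 320) returns a
# (3, 320) tensor; the first 160 columns are sin terms and the last 160 are cos
# terms over geometrically spaced frequencies (an odd embedding_dim gets one
# zero-padded column).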
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type='vanilla'):
assert attn_type in ['vanilla', 'linear',
'none'], f'attn_type {attn_type} unknown'
print(
f"making attention of type '{attn_type}' with {in_channels} in_channels"
)
if attn_type == 'vanilla':
return AttnBlock(in_channels)
elif attn_type == 'none':
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Model(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type='vanilla'):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t=None, context=None):
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
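# --- Added illustration (not part of the original file) ---------------------
# A minimal sketch of driving the timestep-conditioned Model above. The
# hyperparameters below are assumptions chosen for illustration only, not
# values taken from any ModelScope configuration; the function is defined but
# never called at import time.
def _demo_model_forward():
    model = Model(
        ch=64,
        out_ch=3,
        ch_mult=(1, 2, 4),
        num_res_blocks=2,
        attn_resolutions=[16],
        in_channels=3,
        resolution=64,
        use_timestep=True)
    x = torch.randn(2, 3, 64, 64)        # noisy images
    t = torch.randint(0, 1000, (2, ))    # integer diffusion timesteps
    out = model(x, t)                    # same spatial shape as the input
    assert out.shape == x.shape
    return out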
class Encoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type='vanilla',
**ignore_kwargs):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type='vanilla',
**ignorekwargs):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
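# --- Added illustration (not part of the original file) ---------------------
# A hedged sketch of the Encoder -> Decoder round trip used as a VAE-style
# first stage. Channel multipliers, resolution and z_channels below are
# illustrative assumptions only.
def _demo_encode_decode():
    enc = Encoder(
        ch=32,
        out_ch=3,
        ch_mult=(1, 2, 4),
        num_res_blocks=1,
        attn_resolutions=[],
        in_channels=3,
        resolution=64,
        z_channels=4,
        double_z=False)
    dec = Decoder(
        ch=32,
        out_ch=3,
        ch_mult=(1, 2, 4),
        num_res_blocks=1,
        attn_resolutions=[],
        in_channels=3,
        resolution=64,
        z_channels=4)
    z = enc(torch.randn(1, 3, 64, 64))    # -> (1, 4, 16, 16)
    x_rec = dec(z)                        # -> (1, 3, 64, 64)
    return x_rec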
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0,
dropout=0.0),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)
])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2**(self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self,
factor,
in_channels,
mid_channels,
out_channels,
depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
self.res_block1 = nn.ModuleList([
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)
])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)
])
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(
x,
size=(int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self,
in_size,
out_size,
in_channels,
out_channels,
ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1. + (out_size % in_size)
print(
f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}'
)
self.rescaler = LatentRescaler(
factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode='bilinear'):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(
                f'Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode'
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x,
mode=self.mode,
align_corners=False,
scale_factor=scale_factor)
return x
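# --- Added illustration (not part of the original file) ---------------------
# A small sketch of the non-learned Resize wrapper above; the sizes are
# arbitrary and only meant to show the scale_factor behaviour.
def _demo_resize():
    resize = Resize(learned=False, mode='bilinear')
    x = torch.randn(1, 3, 32, 32)
    assert resize(x, scale_factor=1.0) is x    # identity when the factor is 1
    y = resize(x, scale_factor=0.5)            # bilinear downsample
    assert y.shape == (1, 3, 16, 16)
    return y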
class FirstStagePostProcessor(nn.Module):
def __init__(self,
ch_mult: list,
in_channels,
pretrained_model: nn.Module = None,
reshape=False,
n_channels=None,
dropout=0.,
pretrained_config=None):
super().__init__()
if pretrained_config is None:
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
self.proj = nn.Conv2d(
in_channels, n_channels, kernel_size=3, stride=1, padding=1)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(
ResnetBlock(
in_channels=ch_in,
out_channels=m * n_channels,
dropout=dropout))
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self, x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self, x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model, self.downsampler):
z = submodel(z, temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z, 'b c h w -> b (h w) c')
return z

View File

@@ -0,0 +1,92 @@
from torch.nn import (BatchNorm1d, BatchNorm2d, Conv2d, Dropout, Linear,
Module, PReLU, Sequential)
from .helpers import (Flatten, bottleneck_IR, bottleneck_IR_SE, get_blocks,
l2_norm)
class Backbone(Module):
def __init__(self,
input_size,
num_layers,
mode='ir',
drop_ratio=0.4,
affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], 'input_size should be 112 or 224'
assert num_layers in [50, 100,
152], 'num_layers should be 50, 100 or 152'
assert mode in ['ir', 'ir_se'], 'mode should be ir or ir_se'
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(
Conv2d(3, 64, (3, 3), 1, 1, bias=False), BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 7 * 7, 512), BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(
BatchNorm2d(512), Dropout(drop_ratio), Flatten(),
Linear(512 * 14 * 14, 512), BatchNorm1d(512, affine=affine))
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
unit_module(bottleneck.in_channel, bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(
input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(
input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(
input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(
input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(
input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_152(input_size):
"""Constructs a ir_se-152 model."""
model = Backbone(
input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
return model
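# --- Added illustration (not part of the original file) ---------------------
# A hedged sketch of using one of the ArcFace-style backbones above. It
# assumes the usual get_blocks() layout that downsamples a 112x112 face crop
# to a 7x7 feature map; the random input stands in for an aligned face image.
def _demo_face_embedding():
    import torch
    model = IR_SE_50(112).eval()
    with torch.no_grad():
        emb = model(torch.randn(1, 3, 112, 112))    # l2-normalised embedding
    assert emb.shape == (1, 512)
    return emb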

View File

@@ -0,0 +1,668 @@
import random
from functools import partial
import clip
import kornia
import kornia.augmentation as K
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from transformers import (CLIPTextModel, CLIPTokenizer, CLIPVisionModel,
T5EncoderModel, T5Tokenizer)
from ..util import default, instantiate_from_config
from .id_loss import IDFeatures
from .util_diffusion import extract_into_tensor, make_beta_schedule, noise_like
from .x_transformer import Encoder, TransformerWrapper
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
self.augment = augment
self.retreival_key = retreival_key
def forward(self, img):
encodings = []
with torch.no_grad():
x_offset = 125
if self.retreival_key:
# Assumes retrieved image are packed into the second half of channels
face = img[:, 3:, 190:440, x_offset:(512 - x_offset)]
other = img[:, :3, ...].clone()
else:
face = img[:, :, 190:440, x_offset:(512 - x_offset)]
other = img.clone()
if self.augment:
face = K.RandomHorizontalFlip()(face)
other[:, :, 190:440, x_offset:(512 - x_offset)] *= 0
encodings = [
self.encoder.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class FaceIdClipEncoder(AbstractEncoder):
def __init__(self):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters():
p.requires_grad = False
self.id = FrozenFaceEncoder(
'/home/jpinkney/code/stable-diffusion/model_ir_se50.pth',
augment=True)
def forward(self, img):
encodings = []
with torch.no_grad():
face = kornia.geometry.resize(
img, (256, 256), interpolation='bilinear', align_corners=True)
other = img.clone()
other[:, :, 184:452, 122:396] *= 0
encodings = [
self.id.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros(
(1, 2, 768),
device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self,
n_embed,
n_layer,
vocab_size,
max_seq_len=77,
device='cuda'):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device='cuda', vq_interface=True, max_length=77):
super().__init__()
        from transformers import BertTokenizerFast  # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self,
n_embed,
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
use_tokenizer=True,
embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(
vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(
num_tokens=vocab_size,
max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, text):
# output of length 77
return self(text)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self,
version='google/t5-v1_1-large',
device='cuda',
max_length=77
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False):
super().__init__()
self.loss_fn = IDFeatures(model_path)
# face encoder is frozen
for p in self.loss_fn.parameters():
p.requires_grad = False
# Mapper is trainable
self.mapper = torch.nn.Linear(512, 768)
p = 0.25
if augment:
self.augment = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomEqualize(p=p),
# K.RandomPlanckianJitter(p=p),
# K.RandomPlasmaBrightness(p=p),
# K.RandomPlasmaContrast(p=p),
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
)
        else:
            # keep None so the `is not None` check in forward() skips augmentation
            self.augment = None
def forward(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
if self.augment is not None:
# Transforms require 0-1
img = self.augment((img + 1) / 2)
img = 2 * img - 1
feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1))
return feat
def encode(self, img):
return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
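# --- Added illustration (not part of the original file) ---------------------
# A hedged sketch of turning text prompts into cross-attention conditioning
# with the FrozenCLIPEmbedder above. It assumes the HuggingFace weights for
# 'openai/clip-vit-large-patch14' can be downloaded; device='cpu' keeps the
# sketch runnable without a GPU.
def _demo_clip_text_conditioning():
    embedder = FrozenCLIPEmbedder(device='cpu', max_length=77)
    z = embedder.encode(['a photo of an astronaut riding a horse'])
    # one token embedding per position, padded/truncated to max_length
    assert z.shape == (1, 77, 768)
    return z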
class ClipImageProjector(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(self,
version='openai/clip-vit-large-patch14',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.model = CLIPVisionModel.from_pretrained(version)
self.model.train()
self.max_length = max_length # TODO: typical value?
self.antialias = True
self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
null_cond = self.get_null_cond(version, max_length)
self.register_buffer('null_cond', null_cond)
@torch.no_grad()
def get_null_cond(self, version, max_length):
device = self.mean.device
embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length)
null_cond = embedder([''])
return null_cond
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
if isinstance(x, list):
return self.null_cond
# x is assumed to be in range [-1,1]
x = self.preprocess(x)
outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state)
return F.pad(
last_hidden_state,
[0, 0, 0, self.max_length - last_hidden_state.shape[1], 0, 0])
def encode(self, im):
return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77): # clip-vit-base-patch32
super().__init__()
self.embedder = FrozenCLIPEmbedder(
version=version, device=device, max_length=max_length)
self.projection = torch.nn.Linear(768, 768)
def forward(self, text):
z = self.embedder(text)
return self.projection(z)
def encode(self, text):
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
        Not actually frozen; to freeze it, set cond_stage_trainable=False in the config.
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(
x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, 768, device=device)
return self.model.encode_image(self.preprocess(x)).float()
def encode(self, im):
return self(im).unsqueeze(1)
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
        Not actually frozen; to freeze it, set cond_stage_trainable=False in the config.
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=True,
max_crops=5,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer(
'mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer(
'std',
torch.Tensor([0.26862954, 0.26130258, 0.27577711]),
persistent=False)
self.max_crops = max_crops
def preprocess(self, x):
# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(
224, scale=(0.085, 1.0), ratio=(1, 1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
patches.extend(crops)
x = torch.cat(patches, dim=0)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, self.max_crops, 768, device=device)
batch_tokens = []
for im in x:
patches = self.preprocess(im.unsqueeze(0))
tokens = self.model.encode_image(patches).float()
for t in tokens:
if random.random() < 0.1:
t *= 0
batch_tokens.append(tokens.unsqueeze(0))
return torch.cat(batch_tokens, dim=0)
def encode(self, im):
return self(im)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in [
'nearest', 'linear', 'bilinear', 'trilinear', 'bicubic', 'area'
]
self.multiplier = multiplier
self.interpolator = partial(
torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(
f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.'
)
self.channel_mapper = nn.Conv2d(
in_channels, out_channels, 1, bias=bias)
def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
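# --- Added illustration (not part of the original file) ---------------------
# A minimal sketch of SpatialRescaler: two bilinear halvings followed by an
# optional 1x1 channel remap. All sizes below are illustrative.
def _demo_spatial_rescaler():
    rescaler = SpatialRescaler(
        n_stages=2, method='bilinear', multiplier=0.5,
        in_channels=3, out_channels=4)
    y = rescaler(torch.randn(1, 3, 64, 64))
    assert y.shape == (1, 4, 16, 16)
    return y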
class LowScaleEncoder(nn.Module):
def __init__(self,
model_config,
linear_start,
linear_end,
timesteps=1000,
max_noise_level=250,
output_size=64,
scale_factor=1.0):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(
timesteps=timesteps,
linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(self,
beta_schedule='linear',
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[
0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
* x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
x_start.shape) * noise)
def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0], ), device=x.device).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(
z, size=self.out_size,
mode='nearest') # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)

File diff suppressed because it is too large

View File

@@ -0,0 +1,349 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from .sampling_util import norm_thresholding
from .util_diffusion import (make_ddim_sampling_parameters,
make_ddim_timesteps, noise_like)
class PLMSSampler(object):
def __init__(self, model, schedule='linear', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device('cuda'):
attr = attr.to(torch.device('cuda'))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.,
verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
alp_1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
alp_2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
alp_1 * alp_2)
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
            # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
dynamic_threshold=None,
**kwargs):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(
0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
print(f'Running PLMS Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts_next = torch.full((b, ),
time_range[min(i + 1,
len(time_range) - 1)],
device=device,
dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([
unconditional_conditioning[k][i], c[k][i]
]) for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat(
[unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in,
c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
if use_original_steps:
alphas_prev = self.model.alphas_cumprod_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod
else:
alphas_prev = self.ddim_alphas_prev
sqrt_one_minus_alphas = self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1),
alphas_prev[index],
device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
sqrt_one_minus_alphas[index],
device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2]
- 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
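# --- Added illustration (not part of the original file) ---------------------
# A hedged sketch of driving the PLMSSampler above. `model` is assumed to be
# a LatentDiffusion-style object exposing num_timesteps, betas, alphas_cumprod
# and apply_model(); `c` is the conditioning produced by its cond_stage_model.
# Nothing here is executed at import time.
def _demo_plms_sampling(model, c, batch_size=1, latent_shape=(4, 64, 64)):
    sampler = PLMSSampler(model)
    samples, intermediates = sampler.sample(
        S=50,                              # number of PLMS steps
        batch_size=batch_size,
        shape=latent_shape,                # (C, H, W) of the latent
        conditioning=c,
        unconditional_guidance_scale=1.0,
        verbose=False)
    return samples, intermediates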

View File

@@ -0,0 +1,51 @@
import numpy as np
import torch
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
)
return x[(..., ) + (None, ) * dims_to_append]
def renorm_thresholding(x0, value):
# renorm
pred_max = x0.max()
pred_min = x0.min()
pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1
pred_x0 = 2 * pred_x0 - 1. # -1 ... 1
    # flatten all non-batch dims (equivalent to einops 'b ... -> b (...)')
    s = torch.quantile(pred_x0.flatten(1).abs(), value, dim=-1)
s.clamp_(min=1.0)
s = s.view(-1, *((1, ) * (pred_x0.ndim - 1)))
# clip by threshold
# pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max
# temporary hack: numpy on cpu
pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(),
s.cpu().numpy()) / s.cpu().numpy()
    pred_x0 = torch.tensor(pred_x0).to(x0.device)  # free function: use the input tensor's device
# re.renorm
pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1
pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range
return pred_x0
def norm_thresholding(x0, value):
s = append_dims(
x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
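# --- Added illustration (not part of the original file) ---------------------
# A small sketch of the helpers above: append_dims broadcasts a per-batch
# scalar against an image tensor, and norm_thresholding rescales any sample
# whose per-sample RMS exceeds the given value. Shapes are illustrative only.
def _demo_thresholding():
    x0 = torch.randn(2, 4, 8, 8)
    scale = append_dims(torch.ones(2), x0.ndim)    # -> shape (2, 1, 1, 1)
    assert scale.shape == (2, 1, 1, 1)
    clipped = norm_thresholding(x0, 1.0)           # per-sample RMS <= 1.0
    assert clipped.shape == x0.shape
    return clipped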

View File

@@ -0,0 +1,308 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import math
import os
import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from ..util import instantiate_from_config
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == 'linear':
betas = (
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == 'sqrt':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]]
+ alphacums[ddim_timesteps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
alpha_1 = (1 - alphas_prev) / (1 - alphas)
alpha_2 = (1 - alphas / alphas_prev)
sigmas = eta * np.sqrt(alpha_1 * alpha_2)
if verbose:
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
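# --- Added illustration (not part of the original file) ---------------------
# A hedged example of betas_for_alpha_bar with the squared-cosine alpha_bar
# proposed by Nichol & Dhariwal (improved DDPM); the 0.008 offset is that
# paper's choice, not something mandated by this file.
def _demo_cosine_betas(num_steps=1000):
    def alpha_bar(t):
        return math.cos((t + 0.008) / 1.008 * math.pi / 2)**2

    betas = betas_for_alpha_bar(num_steps, alpha_bar)
    assert betas.shape == (num_steps, ) and betas.max() <= 0.999
    return betas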
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
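# --- Added illustration (not part of the original file) ---------------------
# A quick sketch of timestep_embedding: each integer timestep becomes a
# `dim`-wide sinusoidal vector, matching the docstring above. Sizes are
# illustrative.
def _demo_timestep_embedding():
    t = torch.arange(4)                      # four timesteps
    emb = timestep_embedding(t, dim=128)     # -> (4, 128) sinusoidal features
    assert emb.shape == (4, 128)
    return emb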
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
def repeat_noise():
return torch.randn((1, *shape[1:]),
device=device).repeat(shape[0],
*((1, ) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
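# --- Added illustration (not part of the original file) ---------------------
# A hedged sketch combining the schedule helpers above: build a linear beta
# schedule, pick 50 DDIM steps out of 1000, and derive the corresponding
# sampling parameters. All numbers are illustrative defaults.
def _demo_schedules():
    betas = make_beta_schedule('linear', n_timestep=1000)
    alphas_cumprod = torch.tensor(np.cumprod(1. - betas, axis=0))
    ddim_steps = make_ddim_timesteps(
        'uniform', num_ddim_timesteps=50, num_ddpm_timesteps=1000,
        verbose=False)
    sigmas, alphas, alphas_prev = make_ddim_sampling_parameters(
        alphacums=alphas_cumprod,
        ddim_timesteps=ddim_steps,
        eta=0.0,
        verbose=False)
    return sigmas, alphas, alphas_prev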

View File

@@ -0,0 +1,680 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
from collections import namedtuple
from functools import partial
from inspect import isfunction
import torch
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum, nn
# constants
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates',
['pre_softmax_attn', 'post_softmax_attn'])
LayerIntermediates = namedtuple('Intermediates',
['hiddens', 'attn_intermediates'])
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()
def init_(self):
nn.init.normal_(self.emb.weight, std=0.02)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
return self.emb(n)[None, :, :]
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000**(torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(
x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val, )
def string_begins_with(prefix, str):
return str.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(
partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(
map(lambda x: (x[0][len(prefix):], x[1]),
tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
# classes
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.value, *rest)
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim**-0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class Residual(nn.Module):
def forward(self, x, residual):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d'))
return gated_output.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
# attention.
class Attention(nn.Module):
def __init__(self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False):
super().__init__()
if use_entmax15:
raise NotImplementedError(
'Check out entmax activation instead of softmax activation!')
self.scale = dim_head**-0.5
self.heads = heads
self.causal = causal
self.mask = mask
inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# talking heads
self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
# explicit topk sparse attention
self.sparse_topk = sparse_topk
# entmax
self.attn_fn = F.softmax
# add memory key / values
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(
inner_dim, dim
* 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = torch.cat((mem, k_input), dim=-2)
v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
(q, k, v))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones(
(b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(
k_mask, lambda: torch.ones(
(b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b),
(self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(
input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots,
self.pre_softmax_proj).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
del input_mask
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(
r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn,
self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn)
return self.to_out(out), intermediates
class AttentionLayers(nn.Module):
def __init__(self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(
dim) if position_infused_attn else None
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ('a', 'c', 'f')
elif cross_attend and only_cross:
default_block = ('c', 'f')
else:
default_block = ('a', 'f')
if macaron:
default_block = ('f', ) + default_block
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(
default_block
) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f', ) * (
par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f', ) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
layer_types = ('a', ) * sandwich_coef + default_block * (
depth - sandwich_coef) + ('f', ) * sandwich_coef
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(
dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f'invalid layer type {layer_type}')
if isinstance(layer, Attention) and exists(branch_fn):
layer = branch_fn(layer)
if gate_residual:
residual_fn = GRUGating(dim)
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
def forward(self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False):
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(
zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
hiddens.append(x)
layer_mem = mems.pop(0)
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == 'a':
out, inter = block(
x,
mask=mask,
sinusoidal_emb=self.pia_pos_emb,
rel_pos=self.rel_pos,
prev_attn=prev_attn,
mem=layer_mem)
elif layer_type == 'c':
out, inter = block(
x,
context=context,
mask=mask,
context_mask=context_mask,
prev_attn=prev_cross_attn)
elif layer_type == 'f':
out = block(x)
x = residual_fn(out, residual)
if layer_type in ('a', 'c'):
intermediates.append(inter)
if layer_type == 'a' and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == 'c' and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if not self.pre_norm and not is_last:
x = norm(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens, attn_intermediates=intermediates)
return x, intermediates
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True):
super().__init__()
assert isinstance(
attn_layers, AttentionLayers
), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim,
dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(
dim, num_tokens
) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(
torch.randn(num_memory_tokens, dim))
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
attn_layers.num_memory_tokens = num_memory_tokens
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs):
b, num_mem = x.shape[0], self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(
x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(
map(lambda pair: torch.cat(pair, dim=-2), zip(
mems, hiddens))) if exists(mems) else hiddens
new_mems = list(
map(lambda t: t[..., -self.max_mem_len:, :].detach(),
new_mems))
return out, new_mems
if return_attn:
attn_maps = list(
map(lambda t: t.post_softmax_attn,
intermediates.attn_intermediates))
return out, attn_maps
return out
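# Minimal usage sketch of the wrapper above; the hyperparameters are arbitrary
# examples and assume an AbsolutePositionalEmbedding is defined earlier in this
# module, as referenced in __init__.
def _example_transformer_wrapper():
    model = TransformerWrapper(
        num_tokens=1000,
        max_seq_len=128,
        attn_layers=Encoder(dim=64, depth=2, heads=4))
    tokens = torch.randint(0, 1000, (2, 128))
    logits = model(tokens)  # (2, 128, 1000)
    return logits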

View File

@@ -0,0 +1,297 @@
import importlib
import os
import time
from inspect import isfunction
import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def pil_rectangle_crop(im):
width, height = im.size # Get dimensions
if width <= height:
left = 0
right = width
top = (height - width) / 2
bottom = (height + width) / 2
else:
top = 0
bottom = height
left = (width - height) / 2
right = (width + height) / 2
# Crop the center of the image
im = im.crop((left, top, right, bottom))
return im
def add_margin(pil_img, color, size=256):
width, height = pil_img.size
result = Image.new(pil_img.mode, (size, size), color)
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
return result
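# Minimal usage sketch combining the two helpers above: square-crop an image,
# shrink it, and paste it onto a 256x256 white canvas. The file path is a
# hypothetical example.
def _example_square_and_pad(path='example.png'):
    im = Image.open(path).convert('RGB')
    im = pil_rectangle_crop(im)  # center square crop
    im.thumbnail([200, 200], Image.Resampling.LANCZOS)
    return add_margin(im, (255, 255, 255), size=256)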
# def create_carvekit_interface():
# # Check doc strings for more information
# interface = HiInterface(
# object_type='object', # Can be "object" or "hairs-like".
# batch_size_seg=5,
# batch_size_matting=1,
# device='cuda' if torch.cuda.is_available() else 'cpu',
# seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
# matting_mask_size=2048,
# trimap_prob_threshold=231,
# trimap_dilation=30,
# trimap_erosion_iters=5,
# fp16=False)
# return interface
def load_and_preprocess(interface, input_im):
'''
:param interface: background removal interface, called as `interface([image])`.
:param input_im: input image (PIL Image).
:return: (H, W, 3) uint8 array in [0, 255], cropped to the foreground and padded to 256x256.
'''
# See https://github.com/Ir1d/image-background-remove-tool
image = input_im.convert('RGB')
image_without_background = interface([image])[0]
image_without_background = np.array(image_without_background)
est_seg = image_without_background > 127
image = np.array(image)
foreground = est_seg[:, :, -1].astype(np.bool_)
image[~foreground] = [255., 255., 255.]
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
image = image[y:y + h, x:x + w, :]
image = PIL.Image.fromarray(np.array(image))
# resize image such that the long edge is at most 200
image.thumbnail([200, 200], Image.Resampling.LANCZOS)
image = add_margin(image, (255, 255, 255), size=256)
image = np.array(image)
return image
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = '\n'.join(xc[bi][start:start + nc]
for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config):
if 'target' not in config:
if config == '__is_first_stage__':
return None
elif config == '__is_unconditional__':
return None
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(**config.get('params', dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit('.', 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
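# Minimal usage sketch of the config-driven factory above: `target` is a dotted
# import path and `params` are constructor kwargs. The target below is only an
# illustrative choice.
def _example_instantiate_from_config():
    config = {
        'target': 'torch.nn.Linear',
        'params': {'in_features': 8, 'out_features': 4},
    }
    return instantiate_from_config(config)  # -> torch.nn.Linear(8, 4)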
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self,
params,
lr=1.e-3,
betas=(0.9, 0.999),
eps=1.e-8,
weight_decay=1.e-2,
amsgrad=False,
ema_decay=0.9999,
ema_power=1.,
param_names=()):
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= eps:
raise ValueError('Invalid epsilon value: {}'.format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
'AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
for param, ema_param in zip(params_with_grad,
ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay)
return loss
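# Minimal usage sketch of the optimizer above: it behaves like AdamW but also
# keeps an EMA copy of each parameter in its state ('param_exp_avg'). The model
# and shapes below are arbitrary examples.
def _example_adamw_with_ema():
    model = torch.nn.Linear(4, 2)
    opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.999)
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    return opt.state[model.weight]['param_exp_avg']  # EMA weights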

View File

@@ -36,6 +36,7 @@ class OCRDetection(TorchModel):
self.return_polygon = cfgs.model.inference_kwargs.return_polygon
self.backbone = cfgs.model.backbone
self.detector = None
self.onnx_export = False
if self.backbone == 'resnet50':
self.detector = VLPTModel()
elif self.backbone == 'resnet18':
@@ -62,11 +63,20 @@ class OCRDetection(TorchModel):
org_shape (`List`): image original shape,
value is [height, width].
"""
pred = self.detector(input['img'])
if type(input) is dict:
pred = self.detector(input['img'])
else:
# for ONNX conversion
input = {'img': input, 'org_shape': [800, 800]}
pred = self.detector(input['img'])
return {'results': pred, 'org_shape': input['org_shape']}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
pred = inputs['results'][0]
if self.onnx_export:
return pred
height, width = inputs['org_shape']
segmentation = pred > self.thresh
if self.return_polygon:

View File

@@ -164,15 +164,17 @@ def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
return boxes, scores
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height, is_numpy=False):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
assert _bitmap.size(0) == 1
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
if is_numpy:
bitmap = _bitmap[0]
pred = pred[0]
else:
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
height, width = bitmap.shape
boxes = []
scores = []

View File

@@ -109,8 +109,8 @@ class OCRRecognition(TorchModel):
with open(dict_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
cnt = 1
# ConvNextViT model start from index=2
if self.do_chunking:
# ConvNextViT and LightweightEdge model start from index=2
if cfgs.model.recognizer == 'ConvNextViT' or cfgs.model.recognizer == 'LightweightEdge':
cnt += 1
for line in lines:
line = line.strip('\n')

View File

@@ -2,6 +2,7 @@
from collections import OrderedDict
import torch
import torch.nn as nn
from .nas_block import plnas_linear_mix_se
@@ -16,27 +17,20 @@ class LightweightEdge(nn.Module):
def __init__(self):
super(LightweightEdge, self).__init__()
self.FeatureExtraction = plnas_linear_mix_se(3, 123)
self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d(
(None, 1)) # Transform final (imgH/16-1) -> 1
self.dropout = nn.Dropout(0.3)
self.Prediction = nn.Sequential(
OrderedDict([
('fc1', nn.Linear(123, 120)),
('bn', nn.BatchNorm1d(120)),
('fc2', nn.Linear(120, 7642)),
]))
self.our_nas_model = plnas_linear_mix_se(1, 128)
self.embed_dim = 128
self.head = nn.Linear(self.embed_dim, 7644)
def forward(self, input):
visual_feature = self.FeatureExtraction(input)
visual_feature = self.AdaptiveAvgPool(
visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
visual_feature = visual_feature.squeeze(3)
visual_feature = self.dropout(visual_feature)
prediction = self.Prediction.fc1(visual_feature.contiguous())
b, t, c = prediction.shape
prediction = self.Prediction.bn(prediction.view(b * t,
c)).view(b, t, c)
prediction = self.Prediction.fc2(prediction)
# RGB2GRAY
input = (input[:, 0:1, :, :] * 0.2989 + input[:, 1:2, :, :] * 0.5870
+ input[:, 2:3, :, :] * 0.1140)
x = self.our_nas_model(input)
x = torch.squeeze(x, 2)
x = torch.transpose(x, 1, 2)
b, s, e = x.size()
x = x.reshape(b * s, e)
prediction = self.head(x).view(b, s, -1)
return prediction

View File

@@ -126,7 +126,7 @@ def plnas_linear_mix_se(input_channel, output_channel):
stride_stages = [(2, 2), (2, 1), (2, 1), (2, 1)]
n_cell_stages = [5, 5, 5, 5]
width_stages = [32, 64, 96, 123]
width_stages = [32, 64, 96, 128]
conv_op_ids = [
2, 23, 24, 26, 2, 2, 11, 27, 27, 27, 27, 2, 0, 2, 16, 10, 27, 2, 2, 2,
22, 10, 27, 3

View File

@@ -26,7 +26,9 @@ class YOLOXONNX(object):
options.intra_op_num_threads = 1
options.inter_op_num_threads = 1
self.ort_session = ort.InferenceSession(
self.onnx_path, sess_options=options)
self.onnx_path,
sess_options=options,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
self.with_p6 = False
self.multi_detect = multi_detect

View File

@@ -16,6 +16,7 @@ from modelscope.models.cv.s2net_panorama_depth_estimation.networks.util_helper i
compute_hp_info, render_depth_map)
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import create_device
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -35,8 +36,7 @@ class PanoramaDepthEstimation(TorchModel):
"""
super().__init__(model_dir, **kwargs)
if 'device' in kwargs:
self.device = torch.device('cuda' if 'gpu' in
kwargs['device'] else 'cpu')
self.device = create_device(kwargs['device'])
else:
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')

View File

@@ -9,7 +9,8 @@ import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from .common import Upsample, resize

View File

@@ -11,7 +11,8 @@ from collections import OrderedDict
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from torch import nn

View File

@@ -8,7 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from timm.models.layers import drop, drop_path, trunc_normal_
from timm.layers.drop import drop_path
from timm.layers.weight_init import trunc_normal_
from .common import resize

View File

@@ -0,0 +1,660 @@
# Copyright © Alibaba, Inc. and its affiliates.
# The implementation here is modified based on StableDiffusionControlNetInpaintPipeline,
# originally Apache 2.0 License and public available at
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
import os
from typing import Any, Callable, Dict, List, Optional, Union
import cv2
import numpy as np
import PIL
import PIL.Image as Image
import torch
import torchvision.transforms as transforms
from diffusers import (AutoencoderKL, ControlNetModel, DiffusionPipeline,
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionInpaintPipeline, StableDiffusionPipeline,
UNet2DConditionModel)
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import (deprecate, is_accelerate_available,
is_accelerate_version, is_compiled_module,
logging, randn_tensor, replace_example_docstring)
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
from modelscope.metainfo import Models
from modelscope.models.base import Tensor, TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.text_texture_generation.lib2.camera import *
from modelscope.models.cv.text_texture_generation.lib2.init_view import *
from modelscope.models.cv.text_texture_generation.utils import *
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> init_image = load_image(image_path)
>>> init_image = init_image.resize((512, 512))
>>> generator = torch.Generator(device="cpu").manual_seed(1)
>>> mask_image = load_image(mask_path)
>>> mask_image = mask_image.resize((512, 512))
>>> def make_inpaint_condition(image, image_mask):
... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
... image[image_mask > 0.5] = -1.0 # set as masked pixel
... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
... image = torch.from_numpy(image)
... return image
>>> control_image = make_inpaint_condition(init_image, mask_image)
>>> controlnet = ControlNetModel.from_pretrained(
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
>>> pipe.enable_model_cpu_offload()
>>> image = pipe(
... "a handsome man with ray-ban sunglasses",
... num_inference_steps=20,
... generator=generator,
... eta=1.0,
... image=init_image,
... mask_image=mask_image,
... control_image=control_image,
... ).images[0]
```
"""
@MODELS.register_module(
Tasks.text_texture_generation, module_name=Models.text_texture_generation)
class Tex2Texture(TorchModel):
def __init__(self, model_dir, *args, **kwargs):
"""The Tex2Texture is modified based on TEXTure and Text2Tex, publicly available at
https://github.com/TEXTurePaper/TEXTurePaper &
https://github.com/daveredrum/Text2Tex
Args:
model_dir: the root directory of the model files
"""
super().__init__(model_dir=model_dir, *args, **kwargs)
if torch.cuda.is_available():
self.device = torch.device('cuda')
logger.info('Use GPU: {}'.format(self.device))
else:
print('no GPU available')
exit()
model_path = model_dir + '/base_model/'
controlmodel_path = model_dir + '/control_model/'
inpaintmodel_path = model_dir + '/inpaint_model/'
torch_dtype = kwargs.get('torch_dtype', torch.float16)
self.controlnet = ControlNetModel.from_pretrained(
controlmodel_path, torch_dtype=torch_dtype).to(self.device)
self.inpaintmodel = StableDiffusionInpaintPipeline.from_pretrained(
inpaintmodel_path,
torch_dtype=torch_dtype,
).to(self.device)
self.pipe = StableDiffusionControlinpaintPipeline.from_pretrained(
model_path, controlnet=self.controlnet,
torch_dtype=torch_dtype).to(self.device)
logger.info('model load over')
def init_mesh(self, mesh_path):
verts, faces, aux = load_obj(mesh_path, device=self.device)
mesh = load_objs_as_meshes([mesh_path], device=self.device)
return mesh, verts, faces, aux
def normalize_mesh(self, mesh):
bbox = mesh.get_bounding_boxes()
num_verts = mesh.verts_packed().shape[0]
mesh_center = bbox.mean(dim=2).repeat(num_verts, 1)
mesh = mesh.offset_verts(-mesh_center)
lens = bbox[0, :, 1] - bbox[0, :, 0]
max_len = lens.max()
scale = 0.9 / max_len
scale = scale.unsqueeze(0).repeat(num_verts)
# mesh.scale_verts_(scale)
new_mesh = mesh.scale_verts(scale)
return new_mesh.verts_packed(), new_mesh, mesh_center, scale
def save_normalized_obj(self, verts, faces, aux, path='normalized.obj'):
print('=> saving normalized mesh file...')
obj_path = path
save_obj(
obj_path,
verts=verts,
faces=faces.verts_idx,
decimal_places=5,
verts_uvs=aux.verts_uvs,
faces_uvs=faces.textures_idx,
texture_map=aux.texture_images[list(aux.texture_images.keys())[0]])
def mesh_normalized(self, mesh_path, save_path='normalized.obj'):
mesh, verts, faces, aux = self.init_mesh(mesh_path)
verts, mesh, mesh_center, scale = self.normalize_mesh(mesh)
self.save_normalized_obj(verts, faces, aux, save_path)
return mesh, verts, faces, aux, mesh_center, scale
def prepare_mask_and_masked_image(image,
mask,
height,
width,
return_image=False):
if image is None:
raise ValueError('`image` input cannot be undefined.')
if mask is None:
raise ValueError('`mask_image` input cannot be undefined.')
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(
f'`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not'
)
# Batch single image
if image.ndim == 3:
assert image.shape[
0] == 3, 'Image outside a batch should be of shape (3, H, W)'
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert image.ndim == 4 and mask.ndim == 4, 'Image and Mask must have 4 dimensions'
assert image.shape[-2:] == mask.shape[
-2:], 'Image and Mask must have the same spatial dimensions'
assert image.shape[0] == mask.shape[
0], 'Image and Mask must have the same batch size'
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError('Image should be in [-1, 1] range')
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError('Mask should be in [0, 1] range')
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(
f'`mask` is a torch.Tensor but `image` (type: {type(image)}) is not'
)
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
# resize all images w.r.t. the passed height and width
image = [
i.resize((width, height), resample=PIL.Image.LANCZOS)
for i in image
]
image = [np.array(i.convert('RGB'))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = [
i.resize((width, height), resample=PIL.Image.LANCZOS)
for i in mask
]
mask = np.concatenate(
[np.array(m.convert('L'))[None, None, :] for m in mask],
axis=0)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
# n.b. ensure backwards compatibility as old function does not return image
if return_image:
return mask, masked_image, image
return mask, masked_image
class StableDiffusionControlinpaintPipeline(
StableDiffusionControlNetInpaintPipeline):
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.Tensor, PIL.Image.Image] = None,
mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
control_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray,
List[torch.FloatTensor], List[PIL.Image.Image],
List[np.ndarray], ] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 1.0,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator,
List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = 'pil',
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor],
None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
guess_mode: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
specified in init, images must be passed as a list such that each element of the list can be correctly
batched for input to a single controlnet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
strength (`float`, *optional*, defaults to 1.):
Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
`strength`. The number of denoising steps depends on the amount of noise initially added. When
`strength` is 1, added noise will be maximum and the denoising process will run for the full number of
iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
portion of the reference `image`.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
than for [`~StableDiffusionControlNetPipeline.__call__`].
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
control_image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
controlnet = self.controlnet._orig_mod if is_compiled_module(
self.controlnet) else self.controlnet
if isinstance(controlnet, MultiControlNetModel) and isinstance(
controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale
] * len(controlnet.nets)
global_pool_conditions = (
controlnet.config.global_pool_conditions if isinstance(
controlnet, ControlNetModel) else
controlnet.nets[0].config.global_pool_conditions)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get('scale', None)
if cross_attention_kwargs is not None else None)
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
control_image = self.prepare_control_image(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
elif isinstance(controlnet, MultiControlNetModel):
control_images = []
for control_image_ in control_image:
control_image_ = self.prepare_control_image(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
control_images.append(control_image_)
control_image = control_images
else:
assert False
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
mask, masked_image, init_image = prepare_mask_and_masked_image(
image, mask_image, height, width, return_image=True)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps,
strength=strength,
device=device)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size
* num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
# 6. Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_unet = self.unet.config.in_channels
return_image_latents = num_channels_unet == 4
latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
image=init_image,
timestep=latent_timestep,
is_strength_max=is_strength_max,
return_noise=True,
return_image_latents=return_image_latents,
)
if return_image_latents:
latents, noise, image_latents = latents_outputs
else:
latents, noise = latents_outputs
# 7. Prepare mask latent variables
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat(
[latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(
control_model_input, t)
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image,
conditioning_scale=controlnet_conditioning_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in down_block_res_samples
]
mid_block_res_sample = torch.cat([
torch.zeros_like(mid_block_res_sample),
mid_block_res_sample
])
# predict the noise residual
if num_channels_unet == 9:
latent_model_input = torch.cat(
[latent_model_input, mask, masked_image_latents],
dim=1)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred,
t,
latents,
**extra_step_kwargs,
return_dict=False)[0]
if num_channels_unet == 4:
init_latents_proper = image_latents[:1]
init_mask = mask[:1]
if i < len(timesteps) - 1:
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([t]))
latents = (1 - init_mask
) * init_latents_proper + init_mask * latents
if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order
== 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(
self,
'final_offload_hook') and self.final_offload_hook is not None:
self.unet.to('cpu')
self.controlnet.to('cpu')
torch.cuda.empty_cache()
if not output_type == 'latent':
image = self.vae.decode(
latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize)
if hasattr(
self,
'final_offload_hook') and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -0,0 +1,165 @@
# customized
import sys
import numpy as np
import torch
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from sklearn.metrics.pairwise import cosine_similarity
from modelscope.models.cv.text_texture_generation.lib2.init_view import \
VIEWPOINTS
sys.path.append('.')
# ---------------- UTILS ----------------------
def degree_to_radian(d):
return d * np.pi / 180
def radian_to_degree(r):
return 180 * r / np.pi
def xyz_to_polar(xyz):
""" assume y-axis is the up axis """
x, y, z = xyz
theta = 180 * np.arccos(z) / np.pi
phi = 180 * np.arccos(y) / np.pi
return theta, phi
def polar_to_xyz(theta, phi, dist):
""" assume y-axis is the up axis """
theta = degree_to_radian(theta)
phi = degree_to_radian(phi)
x = np.sin(phi) * np.sin(theta) * dist
y = np.cos(phi) * dist
z = np.sin(phi) * np.cos(theta) * dist
return [x, y, z]
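# Minimal sanity-check sketch for the conversions above: with y as the up axis,
# phi=90 and theta=0 point along +z. The values below are a toy example.
def _example_polar_to_xyz():
    x, y, z = polar_to_xyz(theta=0, phi=90, dist=1.0)
    # x ~ 0.0, y ~ 0.0, z ~ 1.0
    return x, y, z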
# ---------------- VIEWPOINTS ----------------------
def filter_viewpoints(pre_viewpoints: dict, viewpoints: dict):
""" return the binary mask of viewpoints to be filtered """
filter_mask = [0 for _ in viewpoints.keys()]
for i, v in viewpoints.items():
x_v, y_v, z_v = polar_to_xyz(v['azim'], 90 - v['elev'], v['dist'])
for _, pv in pre_viewpoints.items():
x_pv, y_pv, z_pv = polar_to_xyz(pv['azim'], 90 - pv['elev'],
pv['dist'])
sim = cosine_similarity(
np.array([[x_v, y_v, z_v]]), np.array([[x_pv, y_pv, z_pv]]))[0,
0]
if sim > 0.9:
filter_mask[i] = 1
return filter_mask
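# Minimal usage sketch of filter_viewpoints: a candidate whose viewing direction
# nearly coincides with an already-used one (cosine similarity above 0.9) is
# flagged with 1. The viewpoint dicts below are toy examples.
def _example_filter_viewpoints():
    pre = {0: {'azim': 0, 'elev': 0, 'dist': 1.0}}
    new = {
        0: {'azim': 5, 'elev': 0, 'dist': 1.0},    # nearly the same -> 1
        1: {'azim': 180, 'elev': 0, 'dist': 1.0},  # opposite side -> 0
    }
    return filter_viewpoints(pre, new)  # [1, 0]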
def init_viewpoints(init_dist,
init_elev,
init_azim,
use_principle=True,
use_shapenet=False,
use_objaverse=False):
sample_space = 12
(dist_list, elev_list, azim_list,
sector_list) = init_predefined_viewpoints(sample_space, init_dist,
init_elev)
# per-view punishments, so the same view is not selected over and over
view_punishments = [1 for _ in range(len(dist_list))]
if use_principle:
(dist_list, elev_list, azim_list, sector_list,
view_punishments) = init_principle_viewpoints(dist_list, elev_list,
azim_list, sector_list,
view_punishments,
use_shapenet,
use_objaverse)
azim_list = [v - init_azim for v in azim_list]
elev_list = [v - init_elev for v in elev_list]
return dist_list, elev_list, azim_list, sector_list, view_punishments
def init_principle_viewpoints(dist_list,
elev_list,
azim_list,
sector_list,
view_punishments,
use_shapenet=False,
use_objaverse=False):
if use_shapenet:
key = 'shapenet'
pre_elev_list = [v for v in VIEWPOINTS[key]['elev']]
pre_azim_list = [v for v in VIEWPOINTS[key]['azim']]
pre_sector_list = [v for v in VIEWPOINTS[key]['sector']]
num_principle = 10
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
pre_view_punishments = [0 for _ in range(num_principle)]
elif use_objaverse:
key = 'objaverse'
pre_elev_list = [v for v in VIEWPOINTS[key]['elev']]
pre_azim_list = [v for v in VIEWPOINTS[key]['azim']]
pre_sector_list = [v for v in VIEWPOINTS[key]['sector']]
num_principle = 10
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
pre_view_punishments = [0 for _ in range(num_principle)]
else:
num_principle = 12
pre_elev_list = [v for v in VIEWPOINTS[num_principle]['elev']]
pre_azim_list = [v for v in VIEWPOINTS[num_principle]['azim']]
pre_sector_list = [v for v in VIEWPOINTS[num_principle]['sector']]
pre_dist_list = [dist_list[0] for _ in range(num_principle)]
pre_view_punishments = [0 for _ in range(num_principle)]
dist_list = pre_dist_list + dist_list
elev_list = pre_elev_list + elev_list
azim_list = pre_azim_list + azim_list
sector_list = pre_sector_list + sector_list
view_punishments = pre_view_punishments + view_punishments
return dist_list, elev_list, azim_list, sector_list, view_punishments
def init_predefined_viewpoints(sample_space, init_dist, init_elev):
viewpoints = VIEWPOINTS[sample_space]
assert sample_space == len(viewpoints['sector'])
dist_list = [init_dist
for _ in range(sample_space)] # always the same dist
elev_list = [viewpoints['elev'][i] for i in range(sample_space)]
azim_list = [viewpoints['azim'][i] for i in range(sample_space)]
sector_list = [viewpoints['sector'][i] for i in range(sample_space)]
return dist_list, elev_list, azim_list, sector_list
def init_camera(dist, elev, azim, image_size, device):
R, T = look_at_view_transform(dist, elev, azim)
image_size = torch.tensor([image_size, image_size]).unsqueeze(0)
T[0][2] = dist
cameras = PerspectiveCameras(
R=R, T=T, device=device, image_size=image_size)
return cameras

View File

@@ -0,0 +1,229 @@
PALETTE = {
0: [255, 255, 255], # white - background
1: [204, 50, 50], # red - old
2: [231, 180, 22], # yellow - update
3: [45, 201, 55] # green - new
}
QUAD_WEIGHTS = {
0: 0, # background
1: 0.1, # old
2: 0.5, # update
3: 1 # new
}
VIEWPOINTS = {
2: {
'azim': [0, 180],
'elev': [0, 0],
'sector': ['front', 'back']
},
4: {
'azim': [
45,
315,
135,
225,
],
'elev': [
0,
0,
0,
0,
],
'sector': [
'front right',
'front left',
'back right',
'back left',
]
},
6: {
'azim': [0, 90, 270, 0, 180, 0],
'elev': [0, 0, 0, 90, 0, -90],
'sector': [
'front',
'right',
'left',
'top',
'back',
'bottom',
]
},
10: {
'azim': [270, 315, 225, 0, 180, 45, 135, 90, 270, 270],
'elev': [15, 15, 15, 15, 15, 15, 15, 15, 90, -90],
'sector': [
'front',
'front right',
'front left',
'right',
'left',
'back right',
'back left',
'back',
'top',
'bottom',
]
},
12: {
'azim': [
0,
45,
315,
135,
225,
180,
45,
315,
90,
270,
90,
270,
],
'elev': [
0,
0,
0,
0,
0,
0,
30,
30,
15,
15,
90,
-90,
],
'sector': [
'front',
'front right',
'front left',
'back right',
'back left',
'back',
'front right',
'front left',
'right',
'left',
'top',
'bottom',
]
},
36: {
'azim': [
45,
315,
135,
225,
0,
45,
315,
90,
270,
135,
225,
180,
0,
45,
315,
90,
270,
135,
225,
180,
22.5,
337.5,
67.5,
292.5,
112.5,
247.5,
157.5,
202.5,
22.5,
337.5,
67.5,
292.5,
112.5,
247.5,
157.5,
202.5,
],
'elev': [
0,
0,
0,
0,
30,
30,
30,
30,
30,
30,
30,
30,
60,
60,
60,
60,
60,
60,
60,
60,
15,
15,
15,
15,
15,
15,
15,
15,
45,
45,
45,
45,
45,
45,
45,
45,
],
'sector': [
'front right',
'front left',
'back right',
'back left',
'front',
'front right',
'front left',
'right',
'left',
'back right',
'back left',
'back',
'top front',
'top right',
'top left',
'top right',
'top left',
'top right',
'top left',
'top back',
'front right',
'front left',
'front right',
'front left',
'back right',
'back left',
'back right',
'back left',
'front right',
'front left',
'front right',
'front left',
'back right',
'back left',
'back right',
'back left',
]
}
}

View File

@@ -0,0 +1,655 @@
import os
import random
# customized
import sys
from typing import NamedTuple, Sequence
import cv2
import numpy as np
import torch
from PIL import Image
from pytorch3d.io import save_obj
from pytorch3d.ops import interpolate_face_attributes
from pytorch3d.renderer import (AmbientLights, MeshRasterizer,
MeshRendererWithFragments,
RasterizationSettings, SoftPhongShader,
TexturesUV)
from pytorch3d.renderer.mesh.shader import ShaderBase
from torchvision import transforms
from tqdm import tqdm
from modelscope.models.cv.text_texture_generation.lib2.camera import \
init_camera
from modelscope.models.cv.text_texture_generation.lib2.init_view import *
from modelscope.models.cv.text_texture_generation.lib2.viusel import (
visualize_outputs, visualize_quad_mask)
sys.path.append('.')
class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color: Sequence = (1, 1, 1)
class FlatTexelShader(ShaderBase):
def __init__(self,
device='cpu',
cameras=None,
lights=None,
materials=None,
blend_params=None):
super().__init__(device, cameras, lights, materials, blend_params)
def forward(self, fragments, meshes, **_kwargs):
texels = meshes.sample_textures(fragments)
texels[(fragments.pix_to_face == -1), :] = 0
return texels.squeeze(-2)
def init_soft_phong_shader(camera, blend_params, device):
lights = AmbientLights(device=device)
shader = SoftPhongShader(
cameras=camera,
lights=lights,
device=device,
blend_params=blend_params)
return shader
def init_flat_texel_shader(camera, device):
shader = FlatTexelShader(cameras=camera, device=device)
return shader
def init_renderer(camera, shader, image_size, faces_per_pixel):
raster_settings = RasterizationSettings(
image_size=image_size, faces_per_pixel=faces_per_pixel)
renderer = MeshRendererWithFragments(
rasterizer=MeshRasterizer(
cameras=camera, raster_settings=raster_settings),
shader=shader)
return renderer
@torch.no_grad()
def render(mesh, renderer, pad_value=10):
def phong_normal_shading(meshes, fragments) -> torch.Tensor:
faces = meshes.faces_packed() # (F, 3)
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_normals = vertex_normals[faces]
pixel_normals = interpolate_face_attributes(fragments.pix_to_face,
fragments.bary_coords,
faces_normals)
return pixel_normals
def similarity_shading(meshes, fragments):
faces = meshes.faces_packed() # (F, 3)
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_normals = vertex_normals[faces]
vertices = meshes.verts_packed() # (V, 3)
face_positions = vertices[faces]
view_directions = torch.nn.functional.normalize(
(renderer.shader.cameras.get_camera_center().reshape(1, 1, 3)
- face_positions),
p=2,
dim=2)
cosine_similarity = torch.nn.CosineSimilarity(dim=2)(faces_normals,
view_directions)
pixel_similarity = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords,
cosine_similarity.unsqueeze(-1))
return pixel_similarity
def get_relative_depth_map(fragments, pad_value=pad_value):
absolute_depth = fragments.zbuf[..., 0] # B, H, W
no_depth = -1
depth_min, depth_max = absolute_depth[absolute_depth != no_depth].min(
), absolute_depth[absolute_depth != no_depth].max()
target_min, target_max = 50, 255
depth_value = absolute_depth[absolute_depth != no_depth]
depth_value = depth_max - depth_value # reverse values
depth_value /= (depth_max - depth_min)
depth_value = depth_value * (target_max - target_min) + target_min
relative_depth = absolute_depth.clone()
relative_depth[absolute_depth != no_depth] = depth_value
relative_depth[absolute_depth == no_depth] = pad_value
return relative_depth
images, fragments = renderer(mesh)
normal_maps = phong_normal_shading(mesh, fragments).squeeze(-2)
similarity_maps = similarity_shading(mesh, fragments).squeeze(-2) # -1 - 1
depth_maps = get_relative_depth_map(fragments)
# normalize similarity mask to 0 - 1
similarity_maps = torch.abs(similarity_maps) # 0 - 1
# HACK erode, eliminate isolated dots
non_zero_similarity = (similarity_maps > 0).float()
non_zero_similarity = (non_zero_similarity * 255.).cpu().numpy().astype(
np.uint8)[0]
non_zero_similarity = cv2.erode(
non_zero_similarity, kernel=np.ones((3, 3), np.uint8), iterations=2)
non_zero_similarity = torch.from_numpy(non_zero_similarity).to(
similarity_maps.device).unsqueeze(0) / 255.
similarity_maps = non_zero_similarity.unsqueeze(-1) * similarity_maps
return images, normal_maps, similarity_maps, depth_maps, fragments
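# Usage sketch (variable values are illustrative): one render call yields all
# per-view maps for a textured pytorch3d mesh.
#   camera = init_camera(dist, elev, azim, image_size, device)
#   renderer = init_renderer(
#       camera,
#       shader=init_soft_phong_shader(camera, BlendParams(), device),
#       image_size=image_size,
#       faces_per_pixel=faces_per_pixel)
#   images, normal_maps, similarity_maps, depth_maps, fragments = render(mesh, renderer)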
@torch.no_grad()
def check_visible_faces(mesh, fragments):
pix_to_face = fragments.pix_to_face
visible_map = pix_to_face.unique() # (num_visible_faces)
return visible_map
def get_all_4_locations(values_y, values_x):
y_0 = torch.floor(values_y)
y_1 = torch.ceil(values_y)
x_0 = torch.floor(values_x)
x_1 = torch.ceil(values_x)
return torch.cat([y_0, y_0, y_1, y_1],
0).long(), torch.cat([x_0, x_1, x_0, x_1], 0).long()
def compose_quad_mask(new_mask_image, update_mask_image, old_mask_image,
device):
"""
compose quad mask:
-> 0: background
-> 1: old
-> 2: update
-> 3: new
"""
new_mask_tensor = transforms.ToTensor()(new_mask_image).to(device)
update_mask_tensor = transforms.ToTensor()(update_mask_image).to(device)
old_mask_tensor = transforms.ToTensor()(old_mask_image).to(device)
all_mask_tensor = new_mask_tensor + update_mask_tensor + old_mask_tensor
quad_mask_tensor = torch.zeros_like(all_mask_tensor)
quad_mask_tensor[old_mask_tensor == 1] = 1
quad_mask_tensor[update_mask_tensor == 1] = 2
quad_mask_tensor[new_mask_tensor == 1] = 3
return old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor, quad_mask_tensor
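# QUAD_WEIGHTS is star-imported from init_view: the heat of a view is the weighted
# fraction of its pixels falling into each quad-mask label, so views whose pixels
# land in heavily weighted labels score higher.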
def compute_view_heat(similarity_tensor, quad_mask_tensor):
num_total_pixels = quad_mask_tensor.reshape(-1).shape[0]
heat = 0
for idx in QUAD_WEIGHTS:
heat += (quad_mask_tensor
== idx).sum() * QUAD_WEIGHTS[idx] / num_total_pixels
return heat
def select_viewpoint(selected_view_ids,
view_punishments,
mode,
dist_list,
elev_list,
azim_list,
sector_list,
view_idx,
similarity_texture_cache,
exist_texture,
mesh,
faces,
verts_uvs,
image_size,
faces_per_pixel,
init_image_dir,
mask_image_dir,
normal_map_dir,
depth_map_dir,
similarity_map_dir,
device,
use_principle=False):
if mode == 'sequential':
num_views = len(dist_list)
dist = dist_list[view_idx % num_views]
elev = elev_list[view_idx % num_views]
azim = azim_list[view_idx % num_views]
sector = sector_list[view_idx % num_views]
selected_view_ids.append(view_idx % num_views)
elif mode == 'heuristic':
if use_principle and view_idx < 6:
selected_view_idx = view_idx
else:
selected_view_idx = None
max_heat = 0
print('=> selecting next view...')
view_heat_list = []
for sample_idx in tqdm(range(len(dist_list))):
view_heat, *_ = render_one_view_and_build_masks(
dist_list[sample_idx], elev_list[sample_idx],
azim_list[sample_idx], sample_idx, sample_idx,
view_punishments, similarity_texture_cache, exist_texture,
mesh, faces, verts_uvs, image_size, faces_per_pixel,
init_image_dir, mask_image_dir, normal_map_dir,
depth_map_dir, similarity_map_dir, device)
if view_heat > max_heat:
selected_view_idx = sample_idx
max_heat = view_heat
view_heat_list.append(view_heat.item())
print(view_heat_list)
print('select view {} with heat {}'.format(selected_view_idx,
max_heat))
dist = dist_list[selected_view_idx]
elev = elev_list[selected_view_idx]
azim = azim_list[selected_view_idx]
sector = sector_list[selected_view_idx]
selected_view_ids.append(selected_view_idx)
view_punishments[selected_view_idx] *= 0.01
elif mode == 'random':
selected_view_idx = random.choice(range(len(dist_list)))
dist = dist_list[selected_view_idx]
elev = elev_list[selected_view_idx]
azim = azim_list[selected_view_idx]
sector = sector_list[selected_view_idx]
selected_view_ids.append(selected_view_idx)
else:
raise NotImplementedError()
return dist, elev, azim, sector, selected_view_ids, view_punishments
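# build_backproject_mask rasterizes the mesh from `cameras`, looks up the UV
# coordinate of every covered pixel and scatters the reference image's values into
# a uv_size x uv_size texture; the single returned channel is used below to cache
# per-texel view similarity.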
@torch.no_grad()
def build_backproject_mask(mesh, faces, verts_uvs, cameras, reference_image,
faces_per_pixel, image_size, uv_size, device):
# construct pixel UVs
renderer_scaled = init_renderer(
cameras,
shader=init_soft_phong_shader(
camera=cameras, blend_params=BlendParams(), device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel)
fragments_scaled = renderer_scaled.rasterizer(mesh)
# get UV coordinates for each pixel
faces_verts_uvs = verts_uvs[faces.textures_idx]
pixel_uvs = interpolate_face_attributes(fragments_scaled.pix_to_face,
fragments_scaled.bary_coords,
faces_verts_uvs) # NxHsxWsxKx2
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(-1, 2)
texture_locations_y, texture_locations_x = get_all_4_locations(
(1 - pixel_uvs[:, 1]).reshape(-1) * (uv_size - 1),
pixel_uvs[:, 0].reshape(-1) * (uv_size - 1))
K = faces_per_pixel
texture_values = torch.from_numpy(
np.array(reference_image.resize(
(image_size, image_size)))).float() / 255.
texture_values = texture_values.to(device).unsqueeze(0).expand(
[4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
# texture
texture_tensor = torch.zeros(uv_size, uv_size, 3).to(device)
texture_tensor[texture_locations_y,
texture_locations_x, :] = texture_values.reshape(-1, 3)
return texture_tensor[:, :, 0]
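# build_diffusion_mask splits the current view into "new" texels that have never
# been painted, "update" texels whose best-scoring view equals target_value, and
# "old" texels that keep their existing texture; texels whose view similarity falls
# below view_threshold are treated as invisible.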
@torch.no_grad()
def build_diffusion_mask(mesh_stuff,
renderer,
exist_texture,
similarity_texture_cache,
target_value,
device,
image_size,
smooth_mask=False,
view_threshold=0.01):
mesh, faces, verts_uvs = mesh_stuff
mask_mesh = mesh.clone() # NOTE in-place operation - DANGER!!!
# visible mask => the whole region
exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(
-1, -1, -1, 3).to(device)
mask_mesh.textures = TexturesUV(
maps=torch.ones_like(exist_texture_expand),
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode='nearest')
# visible_mask_tensor, *_ = render(mask_mesh, renderer)
visible_mask_tensor, _, similarity_map_tensor, *_ = render(
mask_mesh, renderer)
# faces that are too rotated away from the viewpoint will be treated as invisible
valid_mask_tensor = (similarity_map_tensor >= view_threshold).float()
visible_mask_tensor *= valid_mask_tensor
# nonexist mask <=> new mask
exist_texture_expand = exist_texture.unsqueeze(0).unsqueeze(-1).expand(
-1, -1, -1, 3).to(device)
mask_mesh.textures = TexturesUV(
maps=1 - exist_texture_expand,
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode='nearest')
new_mask_tensor, *_ = render(mask_mesh, renderer)
new_mask_tensor *= valid_mask_tensor
# exist mask => visible mask - new mask
exist_mask_tensor = visible_mask_tensor - new_mask_tensor
exist_mask_tensor[
exist_mask_tensor < 0] = 0 # NOTE dilate can lead to overflow
# all update mask
mask_mesh.textures = TexturesUV(
maps=(
similarity_texture_cache.argmax(0) == target_value
# # only consider the views that have already appeared before
# similarity_texture_cache[0:target_value+1].argmax(0) == target_value
).float().unsqueeze(0).unsqueeze(-1).expand(-1, -1, -1, 3).to(device),
faces_uvs=faces.textures_idx[None, ...],
verts_uvs=verts_uvs[None, ...],
sampling_mode='nearest')
all_update_mask_tensor, *_ = render(mask_mesh, renderer)
# current update mask => intersection between all update mask and exist mask
update_mask_tensor = exist_mask_tensor * all_update_mask_tensor
# keep mask => exist mask - update mask
old_mask_tensor = exist_mask_tensor - update_mask_tensor
# convert
new_mask = new_mask_tensor[0].cpu().float().permute(2, 0, 1)
new_mask = transforms.ToPILImage()(new_mask).convert('L')
update_mask = update_mask_tensor[0].cpu().float().permute(2, 0, 1)
update_mask = transforms.ToPILImage()(update_mask).convert('L')
old_mask = old_mask_tensor[0].cpu().float().permute(2, 0, 1)
old_mask = transforms.ToPILImage()(old_mask).convert('L')
exist_mask = exist_mask_tensor[0].cpu().float().permute(2, 0, 1)
exist_mask = transforms.ToPILImage()(exist_mask).convert('L')
return new_mask, update_mask, old_mask, exist_mask
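# render_one_view renders a single candidate viewpoint and returns the camera,
# renderer, the image / normal / similarity / depth tensors and the rasterization
# fragments.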
@torch.no_grad()
def render_one_view(mesh, dist, elev, azim, image_size, faces_per_pixel,
device):
# render the view
# print(image_size)
cameras = init_camera(dist, elev, azim, image_size, device)
renderer = init_renderer(
cameras,
shader=init_soft_phong_shader(
camera=cameras, blend_params=BlendParams(), device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel)
init_images_tensor, normal_maps_tensor, similarity_tensor, depth_maps_tensor, fragments = render(
mesh, renderer)
# print(init_images_tensor.shape, torch.max(init_images_tensor), torch.min(init_images_tensor))
cv2.imwrite('img.png',
(np.array(init_images_tensor.squeeze(0)[:, :, :3].cpu())
* 255).astype(np.uint8))
return (cameras, renderer, init_images_tensor, normal_maps_tensor,
similarity_tensor, depth_maps_tensor, fragments)
@torch.no_grad()
def build_similarity_texture_cache_for_all_views(mesh, faces, verts_uvs,
dist_list, elev_list,
azim_list, image_size,
image_size_scaled, uv_size,
faces_per_pixel, device):
num_candidate_views = len(dist_list)
similarity_texture_cache = torch.zeros(num_candidate_views, uv_size,
uv_size).to(device)
print('=> building similarity texture cache for all views...')
for i in tqdm(range(num_candidate_views)):
cameras, _, _, _, similarity_tensor, _, _ = render_one_view(
mesh, dist_list[i], elev_list[i], azim_list[i], image_size,
faces_per_pixel, device)
similarity_texture_cache[i] = build_backproject_mask(
mesh, faces, verts_uvs, cameras,
transforms.ToPILImage()(similarity_tensor[0, :, :,
0]).convert('RGB'),
faces_per_pixel, image_size_scaled, uv_size, device)
return similarity_texture_cache
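# render_one_view_and_build_masks is the per-view driver: render the view, build
# the diffusion masks, compose the quad mask and score the view; with
# save_intermediate=True all intermediate images are written to the given dirs.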
@torch.no_grad()
def render_one_view_and_build_masks(dist,
elev,
azim,
selected_view_idx,
view_idx,
view_punishments,
similarity_texture_cache,
exist_texture,
mesh,
faces,
verts_uvs,
image_size,
faces_per_pixel,
init_image_dir,
mask_image_dir,
normal_map_dir,
depth_map_dir,
similarity_map_dir,
device,
save_intermediate=False,
smooth_mask=False,
view_threshold=0.01):
# render the view
(cameras, renderer, init_images_tensor, normal_maps_tensor,
similarity_tensor, depth_maps_tensor,
fragments) = render_one_view(mesh, dist, elev, azim, image_size,
faces_per_pixel, device)
init_image = init_images_tensor[0].cpu()
init_image = init_image.permute(2, 0, 1)
init_image = transforms.ToPILImage()(init_image).convert('RGB')
normal_map = normal_maps_tensor[0].cpu()
normal_map = normal_map.permute(2, 0, 1)
normal_map = transforms.ToPILImage()(normal_map).convert('RGB')
depth_map = depth_maps_tensor[0].cpu().numpy()
depth_map = Image.fromarray(depth_map).convert('L')
similarity_map = similarity_tensor[0, :, :, 0].cpu()
similarity_map = transforms.ToPILImage()(similarity_map).convert('L')
flat_renderer = init_renderer(
cameras,
shader=init_flat_texel_shader(camera=cameras, device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel)
new_mask_image, update_mask_image, old_mask_image, exist_mask_image = build_diffusion_mask(
(mesh, faces, verts_uvs),
flat_renderer,
exist_texture,
similarity_texture_cache,
selected_view_idx,
device,
image_size,
smooth_mask=smooth_mask,
view_threshold=view_threshold)
# NOTE the view idx is the absolute idx in the sample space (i.e. `selected_view_idx`)
# it should match with `similarity_texture_cache`
(old_mask_tensor, update_mask_tensor, new_mask_tensor, all_mask_tensor,
quad_mask_tensor) = compose_quad_mask(new_mask_image, update_mask_image,
old_mask_image, device)
view_heat = compute_view_heat(similarity_tensor, quad_mask_tensor)
view_heat *= view_punishments[selected_view_idx]
# save intermediate results
if save_intermediate:
init_image.save(
os.path.join(init_image_dir, '{}.png'.format(view_idx)))
normal_map.save(
os.path.join(normal_map_dir, '{}.png'.format(view_idx)))
depth_map.save(os.path.join(depth_map_dir, '{}.png'.format(view_idx)))
similarity_map.save(
os.path.join(similarity_map_dir, '{}.png'.format(view_idx)))
new_mask_image.save(
os.path.join(mask_image_dir, '{}_new.png'.format(view_idx)))
update_mask_image.save(
os.path.join(mask_image_dir, '{}_update.png'.format(view_idx)))
old_mask_image.save(
os.path.join(mask_image_dir, '{}_old.png'.format(view_idx)))
exist_mask_image.save(
os.path.join(mask_image_dir, '{}_exist.png'.format(view_idx)))
visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx,
view_heat, device)
return (view_heat, renderer, cameras, fragments, init_image, normal_map,
depth_map, init_images_tensor, normal_maps_tensor,
depth_maps_tensor, similarity_tensor, old_mask_image,
update_mask_image, new_mask_image, old_mask_tensor,
update_mask_tensor, new_mask_tensor, all_mask_tensor,
quad_mask_tensor)
def save_full_obj(output_dir, obj_name, verts, faces, verts_uvs, faces_uvs,
projected_texture, device):
print('=> saving OBJ file...')
texture_map = transforms.ToTensor()(projected_texture).to(device)
texture_map = texture_map.permute(1, 2, 0)
obj_path = os.path.join(output_dir, obj_name)
save_obj(
obj_path,
verts=verts,
faces=faces,
decimal_places=5,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
texture_map=texture_map)
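# backproject_from_image writes the (inpainted) reference image back into the UV
# texture: only pixels covered by the new/update masks are projected, and
# exist_texture is updated so later views know which texels are already painted.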
@torch.no_grad()
def backproject_from_image(mesh, faces, verts_uvs, cameras, reference_image,
new_mask_image, update_mask_image, init_texture,
exist_texture, image_size, uv_size, faces_per_pixel,
device):
# construct pixel UVs
renderer_scaled = init_renderer(
cameras,
shader=init_soft_phong_shader(
camera=cameras, blend_params=BlendParams(), device=device),
image_size=image_size,
faces_per_pixel=faces_per_pixel)
fragments_scaled = renderer_scaled.rasterizer(mesh)
# get UV coordinates for each pixel
faces_verts_uvs = verts_uvs[faces.textures_idx]
pixel_uvs = interpolate_face_attributes(fragments_scaled.pix_to_face,
fragments_scaled.bary_coords,
faces_verts_uvs) # NxHsxWsxKx2
pixel_uvs = pixel_uvs.permute(0, 3, 1, 2,
4).reshape(pixel_uvs.shape[-2],
pixel_uvs.shape[1],
pixel_uvs.shape[2], 2)
# the update mask has to be on top of the diffusion mask
new_mask_image_tensor = transforms.ToTensor()(new_mask_image).to(
device).unsqueeze(-1)
update_mask_image_tensor = transforms.ToTensor()(update_mask_image).to(
device).unsqueeze(-1)
project_mask_image_tensor = torch.logical_or(
update_mask_image_tensor, new_mask_image_tensor).float()
project_mask_image = project_mask_image_tensor * 255.
project_mask_image = Image.fromarray(
project_mask_image[0, :, :, 0].cpu().numpy().astype(np.uint8))
    project_mask_image_scaled = project_mask_image.resize(
        (image_size, image_size))  # NOTE: Image.Resampling.NEAREST could be passed as resample

project_mask_image_tensor_scaled = transforms.ToTensor()(
project_mask_image_scaled).to(device)
pixel_uvs_masked = pixel_uvs[project_mask_image_tensor_scaled == 1]
texture_locations_y, texture_locations_x = get_all_4_locations(
(1 - pixel_uvs_masked[:, 1]).reshape(-1) * (uv_size - 1),
pixel_uvs_masked[:, 0].reshape(-1) * (uv_size - 1))
K = pixel_uvs.shape[0]
project_mask_image_tensor_scaled = project_mask_image_tensor_scaled[:,
None, :, :,
None].repeat(
1,
4,
1,
1,
3)
texture_values = torch.from_numpy(
np.array(reference_image.resize((image_size, image_size))))
texture_values = texture_values.to(device).unsqueeze(0).expand(
[4, -1, -1, -1]).unsqueeze(0).expand([K, -1, -1, -1, -1])
texture_values_masked = texture_values.reshape(
-1, 3)[project_mask_image_tensor_scaled.reshape(-1, 3) == 1].reshape(
-1, 3)
# texture
texture_tensor = torch.from_numpy(np.array(init_texture)).to(device)
texture_tensor[texture_locations_y,
texture_locations_x, :] = texture_values_masked
init_texture = Image.fromarray(texture_tensor.cpu().numpy().astype(
np.uint8))
# update texture cache
exist_texture[texture_locations_y, texture_locations_x] = 1
return init_texture, project_mask_image, exist_texture

View File

@@ -0,0 +1,268 @@
import os
import sys
import imageio.v2 as imageio
# visualization
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from modelscope.models.cv.text_texture_generation.lib2.camera import \
polar_to_xyz
from modelscope.models.cv.text_texture_generation.lib2.init_view import *
matplotlib.use('Agg')
sys.path.append('.')
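# PALETTE is star-imported from init_view and maps each quad-mask label to an RGB
# color; the colored overlay is saved as '{view_idx}_quad_{view_score}.png'.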
def visualize_quad_mask(mask_image_dir, quad_mask_tensor, view_idx, view_score,
device):
quad_mask_tensor = quad_mask_tensor.unsqueeze(-1).repeat(1, 1, 1, 3)
quad_mask_image_tensor = torch.zeros_like(quad_mask_tensor)
for idx in PALETTE:
selected = quad_mask_tensor[quad_mask_tensor == idx].reshape(-1, 3)
selected = torch.FloatTensor(
PALETTE[idx]).to(device).unsqueeze(0).repeat(selected.shape[0], 1)
quad_mask_image_tensor[quad_mask_tensor == idx] = selected.reshape(-1)
quad_mask_image_np = quad_mask_image_tensor[0].cpu().numpy().astype(
np.uint8)
quad_mask_image = Image.fromarray(quad_mask_image_np).convert('RGB')
quad_mask_image.save(
os.path.join(mask_image_dir,
'{}_quad_{:.5f}.png'.format(view_idx, view_score)))
def visualize_outputs(output_dir, init_image_dir, mask_image_dir,
inpainted_image_dir, num_views):
# subplot settings
num_col = 3
num_row = 1
sus = 4
summary_image_dir = os.path.join(output_dir, 'summary')
os.makedirs(summary_image_dir, exist_ok=True)
# graph settings
print('=> visualizing results...')
for view_idx in range(num_views):
plt.switch_backend('agg')
fig = plt.figure(dpi=100)
fig.set_size_inches(sus * num_col, sus * (num_row + 1))
fig.set_facecolor('white')
# rendering
plt.subplot2grid((num_row, num_col), (0, 0))
plt.imshow(
Image.open(
os.path.join(init_image_dir, '{}.png'.format(view_idx))))
plt.text(
0,
0,
'Rendering',
fontsize=16,
color='black',
backgroundcolor='white')
plt.axis('off')
# mask
plt.subplot2grid((num_row, num_col), (0, 1))
plt.imshow(
Image.open(
os.path.join(mask_image_dir,
'{}_project.png'.format(view_idx))))
plt.text(
0,
0,
'Project Mask',
fontsize=16,
color='black',
backgroundcolor='white')
plt.set_cmap(cm.Greys_r)
plt.axis('off')
# inpainted
plt.subplot2grid((num_row, num_col), (0, 2))
plt.imshow(
Image.open(
os.path.join(inpainted_image_dir, '{}.png'.format(view_idx))))
plt.text(
0,
0,
'Inpainted',
fontsize=16,
color='black',
backgroundcolor='white')
plt.axis('off')
plt.savefig(
os.path.join(summary_image_dir, '{}.png'.format(view_idx)),
bbox_inches='tight')
fig.clf()
# generate GIF
images = [
imageio.imread(
os.path.join(summary_image_dir, '{}.png'.format(view_idx)))
for view_idx in range(num_views)
]
imageio.mimsave(
os.path.join(summary_image_dir, 'output.gif'), images, duration=1)
print('=> done!')
def visualize_principle_viewpoints(output_dir, dist_list, elev_list,
azim_list):
theta_list = [e for e in azim_list]
phi_list = [90 - e for e in elev_list]
DIST = dist_list[0]
xyz_list = [
polar_to_xyz(theta, phi, DIST)
for theta, phi in zip(theta_list, phi_list)
]
xyz_np = np.array(xyz_list)
color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)
ax = plt.axes(projection='3d')
SCALE = 0.8
ax.set_xlim((-DIST, DIST))
ax.set_ylim((-DIST, DIST))
ax.set_zlim((-SCALE * DIST, SCALE * DIST))
ax.scatter(
xyz_np[:, 0],
xyz_np[:, 2],
xyz_np[:, 1],
s=100,
c=color_np,
depthshade=True,
label='Principle views')
ax.scatter([0], [0], [0],
c=[[1, 0, 0]],
s=100,
depthshade=True,
label='Object center')
# draw hemisphere
# theta inclination angle
# phi azimuthal angle
n_theta = 50 # number of values for theta
n_phi = 200 # number of values for phi
r = DIST # radius of sphere
# theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
theta, phi = np.mgrid[0.0:1 * np.pi:n_theta * 1j,
0.0:2.0 * np.pi:n_phi * 1j]
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)
# Make the grid
ax.quiver(
xyz_np[:, 0],
xyz_np[:, 2],
xyz_np[:, 1],
-xyz_np[:, 0],
-xyz_np[:, 2],
-xyz_np[:, 1],
normalize=True,
length=0.3)
ax.set_xlabel('X Label')
ax.set_ylabel('Z Label')
ax.set_zlabel('Y Label')
ax.view_init(30, 35)
ax.legend()
plt.show()
plt.savefig(os.path.join(output_dir, 'principle_viewpoints.png'))
def visualize_refinement_viewpoints(output_dir, selected_view_ids, dist_list,
elev_list, azim_list):
theta_list = [azim_list[i] for i in selected_view_ids]
phi_list = [90 - elev_list[i] for i in selected_view_ids]
DIST = dist_list[0]
xyz_list = [
polar_to_xyz(theta, phi, DIST)
for theta, phi in zip(theta_list, phi_list)
]
xyz_np = np.array(xyz_list)
color_np = np.array([[0, 0, 0]]).repeat(xyz_np.shape[0], 0)
fig = plt.figure()
ax = plt.axes(projection='3d')
SCALE = 0.8
ax.set_xlim((-DIST, DIST))
ax.set_ylim((-DIST, DIST))
ax.set_zlim((-SCALE * DIST, SCALE * DIST))
ax.scatter(
xyz_np[:, 0],
xyz_np[:, 2],
xyz_np[:, 1],
c=color_np,
depthshade=True,
label='Refinement views')
ax.scatter([0], [0], [0],
c=[[1, 0, 0]],
s=100,
depthshade=True,
label='Object center')
# draw hemisphere
# theta inclination angle
# phi azimuthal angle
n_theta = 50 # number of values for theta
n_phi = 200 # number of values for phi
r = DIST # radius of sphere
# theta, phi = np.mgrid[0.0:0.5*np.pi:n_theta*1j, 0.0:2.0*np.pi:n_phi*1j]
theta, phi = np.mgrid[0.0:1 * np.pi:n_theta * 1j,
0.0:2.0 * np.pi:n_phi * 1j]
x = r * np.sin(theta) * np.cos(phi)
y = r * np.sin(theta) * np.sin(phi)
z = r * np.cos(theta)
ax.plot_surface(x, y, z, rstride=1, cstride=1, alpha=0.25, linewidth=1)
# Make the grid
ax.quiver(
xyz_np[:, 0],
xyz_np[:, 2],
xyz_np[:, 1],
-xyz_np[:, 0],
-xyz_np[:, 2],
-xyz_np[:, 1],
normalize=True,
length=0.3)
ax.set_xlabel('X Label')
ax.set_ylabel('Z Label')
ax.set_zlabel('Y Label')
ax.view_init(30, 35)
ax.legend()
plt.show()
plt.savefig(os.path.join(output_dir, 'refinement_viewpoints.png'))
fig.clear()

View File

@@ -0,0 +1,91 @@
# common utils
import os
import imageio.v2 as imageio
import torch
# pytorch3d
from pytorch3d.io import load_obj, load_objs_as_meshes
from pytorch3d.renderer import (AmbientLights, MeshRasterizer,
MeshRendererWithFragments, PerspectiveCameras,
RasterizationSettings, SoftPhongShader,
look_at_view_transform)
from torchvision import transforms
from tqdm import tqdm
IMAGE_SIZE = 768
def init_mesh(model_path, device):
verts, faces, aux = load_obj(model_path, device=device)
mesh = load_objs_as_meshes([model_path], device=device)
return mesh, verts, faces, aux
def init_camera(num_views, dist, elev, azim, view_idx, device):
interval = 360 // num_views
azim = (azim + interval * view_idx) % 360
R, T = look_at_view_transform(dist, elev, azim)
T[0][2] = dist
image_size = torch.tensor([IMAGE_SIZE, IMAGE_SIZE]).unsqueeze(0)
focal_length = torch.tensor(2.0)
cameras = PerspectiveCameras(
focal_length=focal_length,
R=R,
T=T,
device=device,
image_size=image_size)
return cameras, dist, elev, azim
def init_renderer(camera, device):
raster_settings = RasterizationSettings(image_size=IMAGE_SIZE)
lights = AmbientLights(device=device)
renderer = MeshRendererWithFragments(
rasterizer=MeshRasterizer(
cameras=camera, raster_settings=raster_settings),
shader=SoftPhongShader(cameras=camera, lights=lights, device=device))
return renderer
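# generation_gif renders num_views evenly spaced views around the mesh
# (360 / num_views degree azimuth steps) and assembles them into output.gif and
# output.mp4 under 'GIF-<num_views>/'.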
def generation_gif(mesh_path):
num_views = 72
if torch.cuda.is_available():
DEVICE = torch.device('cuda:0')
torch.cuda.set_device(DEVICE)
else:
        print('no GPU available')
exit()
output_dir = 'GIF-{}'.format(num_views)
os.makedirs(output_dir, exist_ok=True)
mesh, verts, faces, aux = init_mesh(mesh_path, DEVICE)
# rendering
print('=> rendering...')
for view_idx in tqdm(range(num_views)):
init_image_path = os.path.join(output_dir, '{}.png'.format(view_idx))
dist = 1.8
elev = 15
azim = 0
cameras, dist, elev, azim = init_camera(num_views, dist, elev, azim,
view_idx, DEVICE)
renderer = init_renderer(cameras, DEVICE)
init_images_tensor, fragments = renderer(mesh)
# save images
init_image = init_images_tensor[0].cpu()
init_image = init_image.permute(2, 0, 1)
init_image = transforms.ToPILImage()(init_image).convert('RGB')
init_image.save(init_image_path)
# generate GIF
    images = [
        imageio.imread(os.path.join(output_dir, '{}.png'.format(v_id)))
        for v_id in range(num_views)
    ]
imageio.mimsave(
os.path.join(output_dir, 'output.gif'), images, duration=0.1)
imageio.mimsave(os.path.join(output_dir, 'output.mp4'), images, fps=25)
print('=> done!')

View File

@@ -0,0 +1,55 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from collections import OrderedDict
import cv2
import numpy as np
import torch
from diffusers import (ControlNetModel, DDIMScheduler,
StableDiffusionControlNetPipeline)
from diffusers.utils import load_image
from modelscope.models import MODELS, TorchModel
@MODELS.register_module('text-to-head', 'text_to_head')
class TextToHeadModel(TorchModel):
def __init__(self, model_dir, *args, **kwargs):
"""The HeadReconModel is implemented based on HRN, publicly available at
https://github.com/youngLBW/HRN
Args:
model_dir: the root directory of the model files
"""
super().__init__(model_dir, *args, **kwargs)
self.model_dir = model_dir
base_model_path = os.path.join(model_dir, 'base_model')
controlnet_path = os.path.join(model_dir, 'control_net')
controlnet = ControlNetModel.from_pretrained(
controlnet_path, torch_dtype=torch.float16)
self.face_gen_pipeline = StableDiffusionControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16)
self.face_gen_pipeline.scheduler = DDIMScheduler.from_config(
self.face_gen_pipeline.scheduler.config)
self.face_gen_pipeline.enable_model_cpu_offload()
self.add_prompt = ', 4K, good looking face, epic realistic, Sony a7, sharp, ' \
'skin detail pores, soft light, uniform illumination'
self.neg_prompt = 'ugly, cross eye, bangs, teeth, glasses, hat, dark, shadow'
control_pose_path = os.path.join(self.model_dir, 'control_pose.jpg')
self.control_pose = load_image(control_pose_path)
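    # forward expects `input` to be a dict with a 'text' key; the prompt is extended
    # with the quality tags above and the fixed ControlNet pose image constrains the
    # generated head pose.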
def forward(self, input):
prompt = input['text'] + self.add_prompt
image = self.face_gen_pipeline(
prompt,
negative_prompt=self.neg_prompt,
image=self.control_pose,
num_inference_steps=20).images[0] # PIL.Image
return image

View File

@@ -314,7 +314,7 @@ class CLIP4Clip(CLIP4ClipPreTrainedModel):
if key in clip_state_dict:
del clip_state_dict[key]
convert_weights(self.clip)
# convert_weights(self.clip)
# <=== End of CLIP Encoders
self.sim_header = 'seqTransf'

View File

@@ -421,8 +421,10 @@ class Frame_Layer(nn.Module):
tgt = self.norm1(tgt)
memory = self.norm2(memory)
mask_new = adaptive_mask(tgt.shape[0], memory.shape[0], ada_para=0.2)
if torch.cuda.is_available():
mask_new = mask_new.cuda()
tgt2, atten_weights = self.multihead_attn(
tgt, memory, memory, attn_mask=mask_new.cuda())
tgt, memory, memory, attn_mask=mask_new)
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm3(tgt)

View File

@@ -14,7 +14,8 @@ class GenUnifiedTransformer(UnifiedTransformer):
super(GenUnifiedTransformer, self).__init__(model_dir, config, reader,
generator)
self.understand = config.BPETextField.understand
if torch.cuda.is_available():
self.use_gpu = True
if self.use_gpu:
self.cuda()
return
@@ -201,15 +202,21 @@ class GenUnifiedTransformer(UnifiedTransformer):
mask = state['mask']
# shape: [batch_size, 1, 1]
pred_token = state['pred_token']
pred_mask = state['pred_mask']
pred_pos = state['pred_pos']
pred_type = state['pred_type']
pred_turn = state['pred_turn']
if self.use_gpu:
pred_token = state['pred_token'].cuda()
pred_mask = state['pred_mask'].cuda()
pred_pos = state['pred_pos'].cuda()
pred_type = state['pred_type'].cuda()
pred_turn = state['pred_turn'].cuda()
else:
pred_token = state['pred_token']
pred_mask = state['pred_mask']
pred_pos = state['pred_pos']
pred_type = state['pred_type']
pred_turn = state['pred_turn']
# list of shape(len: num_layers): [batch_size, seq_len, hidden_dim]
cache = state['cache']
pred_embed = self.embedder(pred_token, pred_pos, pred_type,
pred_turn).squeeze(-2)
pred_embed = self.embed_layer_norm(pred_embed)

View File

@@ -67,6 +67,8 @@ class SpaceGenerator(object):
self.min_gen_len = config.Generator.min_gen_len
self.max_gen_len = config.Generator.max_gen_len
self.use_gpu = config.use_gpu
if torch.cuda.is_available():
self.use_gpu = True
assert 1 <= self.min_gen_len <= self.max_gen_len
return
@@ -184,7 +186,6 @@ class BeamSearch(SpaceGenerator):
unk_penalty = unk_penalty.cuda()
eos_penalty = eos_penalty.cuda()
scores_after_end = scores_after_end.cuda()
if self.ignore_unk:
scores = scores + unk_penalty
scores = scores + eos_penalty

View File

@@ -0,0 +1,11 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
from .util import EasyDict, make_cache_dir_path

View File

@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Miscellaneous utility classes and functions."""
import os
import sys
import tempfile
from typing import Any, List, Tuple, Union
class EasyDict(dict):
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
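# Example: cfg = EasyDict(lr=1e-4); cfg.lr and cfg['lr'] read the same entry, and
# `del cfg.lr` removes it.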
_dnnlib_cache_dir = None
def set_cache_dir(path: str) -> None:
global _dnnlib_cache_dir
_dnnlib_cache_dir = path
def make_cache_dir_path(*paths: str) -> str:
if _dnnlib_cache_dir is not None:
return os.path.join(_dnnlib_cache_dir, *paths)
if 'DNNLIB_CACHE_DIR' in os.environ:
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
if 'HOME' in os.environ:
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
if 'USERPROFILE' in os.environ:
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib',
*paths)
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)

View File

@@ -0,0 +1,181 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import glob
import hashlib
import importlib
import os
import re
import shutil
import uuid
import torch
import torch.utils.cpp_extension
from torch.utils.file_baton import FileBaton
# Global options.
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
# Internal helper funcs.
def _find_compiler_bindir():
patterns = [
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
]
for pattern in patterns:
matches = sorted(glob.glob(pattern))
if len(matches):
return matches[-1]
return None
def _get_mangled_gpu_name():
name = torch.cuda.get_device_name().lower()
out = []
for c in name:
if re.match('[a-z0-9_-]+', c):
out.append(c)
else:
out.append('-')
return ''.join(out)
# Main entry point for compiling and loading C++/CUDA plugins.
_cached_plugins = dict()
def get_plugin(module_name,
sources,
headers=None,
source_dir=None,
**build_kwargs):
assert verbosity in ['none', 'brief', 'full']
if headers is None:
headers = []
if source_dir is not None:
sources = [os.path.join(source_dir, fname) for fname in sources]
headers = [os.path.join(source_dir, fname) for fname in headers]
# Already cached?
if module_name in _cached_plugins:
return _cached_plugins[module_name]
# Print status.
if verbosity == 'full':
print(f'Setting up PyTorch plugin "{module_name}"...')
elif verbosity == 'brief':
print(
f'Setting up PyTorch plugin "{module_name}"... ',
end='',
flush=True)
verbose_build = (verbosity == 'full')
# Compile and load.
try:
if os.name == 'nt' and os.system('where cl.exe >nul 2>nul') != 0:
compiler_bindir = _find_compiler_bindir()
if compiler_bindir is None:
raise RuntimeError(
f'Could not find MSVC/GCC/CLANG installation on this computer.'
f' Check _find_compiler_bindir() in "{__file__}".')
os.environ['PATH'] += ';' + compiler_bindir
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
# break the build or unnecessarily restrict what's available to nvcc.
# Unset it to let nvcc decide based on what's available on the
# machine.
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
# Incremental build md5sum trickery. Copies all the input source files
# into a cached build directory under a combined md5 digest of the input
# source files. Copying is done only if the combined digest has changed.
# This keeps input file timestamps and filenames the same as in previous
# extension builds, allowing for fast incremental rebuilds.
#
# This optimization is done only in case all the source files reside in
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
# environment variable is set (we take this as a signal that the user
# actually cares about this.)
#
        # EDIT: We now do it regardless of TORCH_EXTENSIONS_DIR, in order to work
# around the *.cu dependency bug in ninja config.
#
all_source_files = sorted(sources + headers)
all_source_dirs = set(
os.path.dirname(fname) for fname in all_source_files)
if len(all_source_dirs
) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
# Compute combined hash digest for all source files.
hash_md5 = hashlib.md5()
for src in all_source_files:
with open(src, 'rb') as f:
hash_md5.update(f.read())
# Select cached build directory name.
source_digest = hash_md5.hexdigest()
build_top_dir = torch.utils.cpp_extension._get_build_directory(
module_name, verbose=verbose_build) # pylint: disable=protected-access
cached_build_dir = os.path.join(
build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
if not os.path.isdir(cached_build_dir):
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
os.makedirs(tmpdir)
for src in all_source_files:
shutil.copyfile(
src, os.path.join(tmpdir, os.path.basename(src)))
try:
os.replace(tmpdir, cached_build_dir) # atomic
except OSError:
# source directory already exists, delete tmpdir and its contents.
shutil.rmtree(tmpdir)
if not os.path.isdir(cached_build_dir):
raise
# Compile.
cached_sources = [
os.path.join(cached_build_dir, os.path.basename(fname))
for fname in sources
]
torch.utils.cpp_extension.load(
name=module_name,
build_directory=cached_build_dir,
verbose=verbose_build,
sources=cached_sources,
**build_kwargs)
else:
torch.utils.cpp_extension.load(
name=module_name,
verbose=verbose_build,
sources=sources,
**build_kwargs)
# Load.
module = importlib.import_module(module_name)
except Exception:
if verbosity == 'brief':
print('Failed!')
raise
# Print status and add to cache dict.
if verbosity == 'full':
print(f'Done setting up PyTorch plugin "{module_name}".')
elif verbosity == 'brief':
print('Done.')
_cached_plugins[module_name] = module
return module

View File

@@ -0,0 +1,325 @@
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import contextlib
import re
import warnings
import numpy as np
import torch
from .. import dnnlib
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
# same constant is used multiple times.
_constant_cache = dict()
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
value = np.asarray(value)
if shape is not None:
shape = tuple(shape)
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
device = torch.device('cpu')
if memory_format is None:
memory_format = torch.contiguous_format
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device,
memory_format)
tensor = _constant_cache.get(key, None)
if tensor is None:
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
if shape is not None:
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
tensor = tensor.contiguous(memory_format=memory_format)
_constant_cache[key] = tensor
return tensor
# Replace NaN/Inf with specified numerical values.
try:
nan_to_num = torch.nan_to_num # 1.8.0a0
except AttributeError:
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
assert isinstance(input, torch.Tensor)
if posinf is None:
posinf = torch.finfo(input.dtype).max
if neginf is None:
neginf = torch.finfo(input.dtype).min
assert nan == 0
return torch.clamp(
input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
# Symbolic assert.
try:
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
except AttributeError:
symbolic_assert = torch.Assert # 1.7.0
# Context manager to temporarily suppress known warnings in torch.jit.trace().
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
@contextlib.contextmanager
def suppress_tracer_warnings():
flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
warnings.filters.insert(0, flt)
yield
warnings.filters.remove(flt)
# Assert that the shape of a tensor matches the given list of integers.
# None indicates that the size of a dimension is allowed to vary.
# Performs symbolic assertion when used in torch.jit.trace().
def assert_shape(tensor, ref_shape):
if tensor.ndim != len(ref_shape):
raise AssertionError(
f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}'
)
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
if ref_size is None:
pass
elif isinstance(ref_size, torch.Tensor):
with suppress_tracer_warnings(
): # as_tensor results are registered as constants
symbolic_assert(
torch.equal(torch.as_tensor(size), ref_size),
f'Wrong size for dimension {idx}')
elif isinstance(size, torch.Tensor):
with suppress_tracer_warnings(
): # as_tensor results are registered as constants
symbolic_assert(
torch.equal(size, torch.as_tensor(ref_size)),
f'Wrong size for dimension {idx}: expected {ref_size}')
elif size != ref_size:
raise AssertionError(
f'Wrong size for dimension {idx}: got {size}, expected {ref_size}'
)
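# Example: assert_shape(x, [None, 3, 256, 256]) accepts any batch size but requires
# the remaining dimensions to match exactly.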
# Function decorator that calls torch.autograd.profiler.record_function().
def profiled_function(fn):
def decorator(*args, **kwargs):
with torch.autograd.profiler.record_function(fn.__name__):
return fn(*args, **kwargs)
decorator.__name__ = fn.__name__
return decorator
# Sampler for torch.utils.data.DataLoader that loops over the dataset
# indefinitely, shuffling items as it goes.
class InfiniteSampler(torch.utils.data.Sampler):
def __init__(self,
dataset,
rank=0,
num_replicas=1,
shuffle=True,
seed=0,
window_size=0.5):
assert len(dataset) > 0
assert num_replicas > 0
assert 0 <= rank < num_replicas
assert 0 <= window_size <= 1
super().__init__(dataset)
self.dataset = dataset
self.rank = rank
self.num_replicas = num_replicas
self.shuffle = shuffle
self.seed = seed
self.window_size = window_size
def __iter__(self):
order = np.arange(len(self.dataset))
rnd = None
window = 0
if self.shuffle:
rnd = np.random.RandomState(self.seed)
rnd.shuffle(order)
window = int(np.rint(order.size * self.window_size))
idx = 0
while True:
i = idx % order.size
if idx % self.num_replicas == self.rank:
yield order[i]
if window >= 2:
j = (i - rnd.randint(window)) % order.size
order[i], order[j] = order[j], order[i]
idx += 1
# Utilities for operating with torch.nn.Module parameters and buffers.
def params_and_buffers(module):
assert isinstance(module, torch.nn.Module)
return list(module.parameters()) + list(module.buffers())
def named_params_and_buffers(module):
assert isinstance(module, torch.nn.Module)
return list(module.named_parameters()) + list(module.named_buffers())
def copy_params_and_buffers(src_module, dst_module, require_all=False):
assert isinstance(src_module, torch.nn.Module)
assert isinstance(dst_module, torch.nn.Module)
src_tensors = dict(named_params_and_buffers(src_module))
for name, tensor in named_params_and_buffers(dst_module):
assert (name in src_tensors) or (not require_all)
if name in src_tensors:
tensor.copy_(src_tensors[name].detach()).requires_grad_(
tensor.requires_grad)
# Context manager for easily enabling/disabling DistributedDataParallel
# synchronization.
@contextlib.contextmanager
def ddp_sync(module, sync):
assert isinstance(module, torch.nn.Module)
if sync or not isinstance(module,
torch.nn.parallel.DistributedDataParallel):
yield
else:
with module.no_sync():
yield
# Check DistributedDataParallel consistency across processes.
def check_ddp_consistency(module, ignore_regex=None):
assert isinstance(module, torch.nn.Module)
for name, tensor in named_params_and_buffers(module):
fullname = type(module).__name__ + '.' + name
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
continue
tensor = tensor.detach()
if tensor.is_floating_point():
tensor = nan_to_num(tensor)
other = tensor.clone()
torch.distributed.broadcast(tensor=other, src=0)
assert (tensor == other).all(), fullname
# Print summary table of module hierarchy.
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
assert isinstance(module, torch.nn.Module)
assert not isinstance(module, torch.jit.ScriptModule)
assert isinstance(inputs, (tuple, list))
# Register hooks.
entries = []
nesting = [0]
def pre_hook(_mod, _inputs):
nesting[0] += 1
def post_hook(mod, _inputs, outputs):
nesting[0] -= 1
if nesting[0] <= max_nesting:
outputs = list(outputs) if isinstance(outputs,
(tuple,
list)) else [outputs]
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
hooks = [
mod.register_forward_pre_hook(pre_hook) for mod in module.modules()
]
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
# Run module.
outputs = module(*inputs)
for hook in hooks:
hook.remove()
# Identify unique outputs, parameters, and buffers.
tensors_seen = set()
for e in entries:
e.unique_params = [
t for t in e.mod.parameters() if id(t) not in tensors_seen
]
e.unique_buffers = [
t for t in e.mod.buffers() if id(t) not in tensors_seen
]
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
tensors_seen |= {
id(t)
for t in e.unique_params + e.unique_buffers + e.unique_outputs
}
# Filter out redundant entries.
if skip_redundant:
entries = [
e for e in entries if len(e.unique_params) or len(e.unique_buffers)
or len(e.unique_outputs)
]
# Construct table.
rows = [[
type(module).__name__, 'Parameters', 'Buffers', 'Output shape',
'Datatype'
]]
rows += [['---'] * len(rows[0])]
param_total = 0
buffer_total = 0
submodule_names = {mod: name for name, mod in module.named_modules()}
for e in entries:
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
param_size = sum(t.numel() for t in e.unique_params)
buffer_size = sum(t.numel() for t in e.unique_buffers)
output_shapes = [str(list(t.shape)) for t in e.outputs]
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
rows += [[
name + (':0' if len(e.outputs) >= 2 else ''),
str(param_size) if param_size else '-',
str(buffer_size) if buffer_size else '-',
(output_shapes + ['-'])[0],
(output_dtypes + ['-'])[0],
]]
for idx in range(1, len(e.outputs)):
rows += [[
name + f':{idx}', '-', '-', output_shapes[idx],
output_dtypes[idx]
]]
param_total += param_size
buffer_total += buffer_size
rows += [['---'] * len(rows[0])]
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
# Print table.
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
print()
for row in rows:
print(' '.join(cell + ' ' * (width - len(cell))
for cell, width in zip(row, widths)))
print()
return outputs

Some files were not shown because too many files have changed in this diff.