update ckpt to general_v0.1 (#696)

This commit is contained in:
Firmament-cyou
2023-12-26 16:59:47 +08:00
committed by GitHub
parent fb46e18351
commit b473decba0
4 changed files with 7 additions and 6 deletions

View File

@@ -7,6 +7,7 @@ https://github.com/CompVis/taming-transformers
"""
import itertools
import os
from contextlib import contextmanager, nullcontext
from functools import partial
@@ -741,7 +742,8 @@ class LatentDiffusion(DDPM):
param.requires_grad = False
def instantiate_cond_stage(self, config):
config.params.model_dir = self.model_dir
config.params.model_path = os.path.join(self.model_dir,
config.params.model_path)
if not self.cond_stage_trainable:
if config == '__is_first_stage__':
print('Using first stage also as cond stage.')

View File

@@ -331,9 +331,8 @@ class FrozenDinoV2Encoder(AbstractEncoder):
Uses the DINOv2 encoder for image
"""
def __init__(self, model_dir, device='cuda', freeze=True):
DINOv2_weight_path = os.path.join(model_dir,
'dinov2_vitg14_pretrain.pth')
def __init__(self, model_path, device='cuda', freeze=True):
DINOv2_weight_path = model_path
super().__init__()
dinov2 = hubconf.dinov2_vitg14()

View File

@@ -52,7 +52,7 @@ class AnydoorPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
model_ckpt = os.path.join(self.model.model_dir,
'epoch=1-step=8687.ckpt')
self.cfg.model.model_path)
self.model.load_state_dict(
self._get_state_dict(model_ckpt, location='cuda'))
self.ddim_sampler = DDIMSampler(self.model)

View File

@@ -11,7 +11,7 @@ class AnydoorTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.image_to_image_generation
self.model_id = 'damo/AnyDoor'
self.model_id = 'damo/AnyDoor_models'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run(self):