v2 inference

This commit is contained in:
Yuwei Guo
2023-09-10 21:27:27 +08:00
parent 108921965d
commit 1904a01117
4 changed files with 53 additions and 3 deletions

View File

@@ -34,7 +34,6 @@ def main(args):
time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"samples/{Path(args.config).stem}-{time_str}"
os.makedirs(savedir)
inference_config = OmegaConf.load(args.inference_config)
config = OmegaConf.load(args.config)
samples = []
@@ -45,7 +44,8 @@ def main(args):
motion_modules = model_config.motion_module
motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
for motion_module in motion_modules:
inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
### >>> create validation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
@@ -148,7 +148,7 @@ def main(args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
parser.add_argument("--inference_config", type=str, default="configs/inference/inference.yaml")
parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--L", type=int, default=16 )