From 1944416d3d883f89396fc9a7ae142fcdaa7f6d09 Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Thu, 29 Jun 2023 16:49:20 +0800 Subject: [PATCH] Fix device_map for DiffusionForTextToImageSynthesis Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13115051 * add kwargs for device map --- modelscope/models/multi_modal/diffusion/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelscope/models/multi_modal/diffusion/model.py b/modelscope/models/multi_modal/diffusion/model.py index d979cc7f..34d8d342 100644 --- a/modelscope/models/multi_modal/diffusion/model.py +++ b/modelscope/models/multi_modal/diffusion/model.py @@ -114,9 +114,9 @@ class DiffusionModel(nn.Module): Tasks.text_to_image_synthesis, module_name=Models.diffusion) class DiffusionForTextToImageSynthesis(Model): - def __init__(self, model_dir, device='gpu'): + def __init__(self, model_dir, device='gpu', **kwargs): device = 'gpu' if torch.cuda.is_available() else 'cpu' - super().__init__(model_dir=model_dir, device=device) + super().__init__(model_dir=model_dir, device=device, **kwargs) diffusion_model = DiffusionModel(model_dir=model_dir) pretrained_params = torch.load( osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')