mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Fix device_map for DiffusionForTextToImageSynthesis
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13115051 * add kwargs for device map
This commit is contained in:
@@ -114,9 +114,9 @@ class DiffusionModel(nn.Module):
|
||||
Tasks.text_to_image_synthesis, module_name=Models.diffusion)
|
||||
class DiffusionForTextToImageSynthesis(Model):
|
||||
|
||||
def __init__(self, model_dir, device='gpu'):
|
||||
def __init__(self, model_dir, device='gpu', **kwargs):
|
||||
device = 'gpu' if torch.cuda.is_available() else 'cpu'
|
||||
super().__init__(model_dir=model_dir, device=device)
|
||||
super().__init__(model_dir=model_dir, device=device, **kwargs)
|
||||
diffusion_model = DiffusionModel(model_dir=model_dir)
|
||||
pretrained_params = torch.load(
|
||||
osp.join(model_dir, ModelFile.TORCH_MODEL_BIN_FILE), 'cpu')
|
||||
|
||||
Reference in New Issue
Block a user