[to #42322933] remove mpi4py dependence for ddpm-segmentation

This commit is contained in:
shuying.shu
2023-02-12 01:34:22 +00:00
committed by zhangzhicheng.zzc
parent 4215e67ba2
commit c43bdb2ad8

View File

@@ -88,7 +88,6 @@ class FeatureExtractorDDPM(FeatureExtractor):
def _load_pretrained_model(self, model_path, **kwargs):
import inspect
import ddpm_guided_diffusion.dist_util as dist_util
from ddpm_guided_diffusion.script_util import create_model_and_diffusion
# Needed to pass only expected args to the function
@@ -97,8 +96,9 @@ class FeatureExtractorDDPM(FeatureExtractor):
self.model, self.diffusion = create_model_and_diffusion(
**expected_args)
self.model.load_state_dict(
dist_util.load_state_dict(model_path, map_location='cpu'))
state_dict = torch.load(model_path, map_location='cpu')
self.model.load_state_dict(state_dict)
if kwargs['use_fp16']:
self.model.convert_to_fp16()