diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index 65c238da..a3707918 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -142,12 +142,8 @@ class EpochBasedTrainer(BaseTrainer): self._samplers = samplers if isinstance(model, str): - third_party = kwargs.get(ThirdParty.KEY) - if third_party is not None: - kwargs.pop(ThirdParty.KEY) - self.model_dir = self.get_or_download_model_dir( - model, model_revision, third_party) + model, model_revision, kwargs.pop(ThirdParty.KEY, None)) if cfg_file is None: cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION) @@ -159,7 +155,10 @@ class EpochBasedTrainer(BaseTrainer): if hasattr(model, 'model_dir'): check_local_model_is_latest( model.model_dir, - user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER}) + user_agent={ + Invoke.KEY: Invoke.LOCAL_TRAINER, + ThirdParty.KEY: kwargs.pop(ThirdParty.KEY, None) + }) super().__init__(cfg_file, arg_parse_fn) self.cfg_modify_fn = cfg_modify_fn