mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Add third_party key (#546)
This commit is contained in:
@@ -142,12 +142,8 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
self._samplers = samplers
|
||||
|
||||
if isinstance(model, str):
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
|
||||
self.model_dir = self.get_or_download_model_dir(
|
||||
model, model_revision, third_party)
|
||||
model, model_revision, kwargs.pop(ThirdParty.KEY, None))
|
||||
if cfg_file is None:
|
||||
cfg_file = os.path.join(self.model_dir,
|
||||
ModelFile.CONFIGURATION)
|
||||
@@ -159,7 +155,10 @@ class EpochBasedTrainer(BaseTrainer):
|
||||
if hasattr(model, 'model_dir'):
|
||||
check_local_model_is_latest(
|
||||
model.model_dir,
|
||||
user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
|
||||
user_agent={
|
||||
Invoke.KEY: Invoke.LOCAL_TRAINER,
|
||||
ThirdParty.KEY: kwargs.pop(ThirdParty.KEY, None)
|
||||
})
|
||||
|
||||
super().__init__(cfg_file, arg_parse_fn)
|
||||
self.cfg_modify_fn = cfg_modify_fn
|
||||
|
||||
Reference in New Issue
Block a user