Add third_party key (#546)

This commit is contained in:
tastelikefeet
2023-09-19 14:05:44 +08:00
committed by GitHub
parent ae039bbe02
commit 4cf7b1e737

View File

@@ -142,12 +142,8 @@ class EpochBasedTrainer(BaseTrainer):
self._samplers = samplers
if isinstance(model, str):
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
self.model_dir = self.get_or_download_model_dir(
model, model_revision, third_party)
model, model_revision, kwargs.pop(ThirdParty.KEY, None))
if cfg_file is None:
cfg_file = os.path.join(self.model_dir,
ModelFile.CONFIGURATION)
@@ -159,7 +155,10 @@ class EpochBasedTrainer(BaseTrainer):
if hasattr(model, 'model_dir'):
check_local_model_is_latest(
model.model_dir,
user_agent={Invoke.KEY: Invoke.LOCAL_TRAINER})
user_agent={
Invoke.KEY: Invoke.LOCAL_TRAINER,
ThirdParty.KEY: kwargs.pop(ThirdParty.KEY, None)
})
super().__init__(cfg_file, arg_parse_fn)
self.cfg_modify_fn = cfg_modify_fn