diff --git a/tools/weight_diff.py b/tools/weight_diff.py index ba619887..8cfa1d50 100644 --- a/tools/weight_diff.py +++ b/tools/weight_diff.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import os from typing import Dict, Optional import torch @@ -112,6 +113,16 @@ def weight_diff(path_raw: str, This function is given to present full transparency of how the weight diff was created. """ + if not os.path.exists(path_raw): + logger.info( + f'Path `{path_raw}` not found. Try to load from cache or remote.') + path_raw = snapshot_download(path_raw) + if not os.path.exists(path_convert): + logger.info( + f'Path `{path_convert}` not found. Try to load from cache or remote.' + ) + path_convert = snapshot_download(path_convert) + model_raw = Model.from_pretrained(path_raw, device=device) model_convert = Model.from_pretrained(path_convert, device=device)