From 88de9f78aa2a311f68e3158ae81c8b06cbde4e05 Mon Sep 17 00:00:00 2001 From: "suluyan.sly" Date: Thu, 29 Jun 2023 10:26:52 +0800 Subject: [PATCH] [to #50538422]fix: support load from model id Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13092954 --- tools/weight_diff.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tools/weight_diff.py b/tools/weight_diff.py index ba619887..8cfa1d50 100644 --- a/tools/weight_diff.py +++ b/tools/weight_diff.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import os from typing import Dict, Optional import torch @@ -112,6 +113,16 @@ def weight_diff(path_raw: str, This function is given to present full transparency of how the weight diff was created. """ + if not os.path.exists(path_raw): + logger.info( + f'Path `{path_raw}` not found. Try to load from cache or remote.') + path_raw = snapshot_download(path_raw) + if not os.path.exists(path_convert): + logger.info( + f'Path `{path_convert}` not found. Try to load from cache or remote.' + ) + path_convert = snapshot_download(path_convert) + model_raw = Model.from_pretrained(path_raw, device=device) model_convert = Model.from_pretrained(path_convert, device=device)