[to #50538422]fix: support load from model id

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13092954
This commit is contained in:
suluyan.sly
2023-06-29 10:26:52 +08:00
parent e951598a82
commit 88de9f78aa

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import argparse
import os
from typing import Dict, Optional
import torch
@@ -112,6 +113,16 @@ def weight_diff(path_raw: str,
This function is given to present full transparency of how the weight diff was created.
"""
if not os.path.exists(path_raw):
logger.info(
f'Path `{path_raw}` not found. Try to load from cache or remote.')
path_raw = snapshot_download(path_raw)
if not os.path.exists(path_convert):
logger.info(
f'Path `{path_convert}` not found. Try to load from cache or remote.'
)
path_convert = snapshot_download(path_convert)
model_raw = Model.from_pretrained(path_raw, device=device)
model_convert = Model.from_pretrained(path_convert, device=device)