Fix/aigc weight (#1464)

(cherry picked from commit e802630865)
This commit is contained in:
Koko-ry
2025-08-13 23:31:59 +08:00
committed by tastelikefeet
parent 743304f3da
commit 5c13b66dbe
2 changed files with 72 additions and 2 deletions

View File

@@ -244,6 +244,8 @@ class HubApi:
if aigc_model is not None:
# Use AIGC model endpoint
path = f'{endpoint}/api/v1/models/aigc'
# Best-effort pre-upload weights so server recognizes sha256 (use existing cookies)
aigc_model.preupload_weights(cookies=cookies, headers=self.builder_headers(self.headers))
# Add AIGC-specific fields to body
body.update({
@@ -272,7 +274,7 @@ class HubApi:
raise_on_error(r.json())
model_repo_url = f'{endpoint}/models/{model_id}'
# TODO: due to server error, the upload function is not working
# TODO: to be aligned with the new api
# Upload model files for AIGC models
# if aigc_model is not None:
# aigc_model.upload_to_repo(self, model_id, token)

View File

@@ -1,9 +1,12 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
from typing import List, Optional
import requests
from tqdm.auto import tqdm
from modelscope.hub.utils.utils import MODELSCOPE_URL_SCHEME, get_domain
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -209,6 +212,71 @@ class AigcModel:
'You may need to upload the model manually after creation.')
return False
def preupload_weights(self,
*,
cookies: Optional[object] = None,
timeout: int = 300,
headers: Optional[dict] = None) -> None:
"""Pre-upload aigc model weights to the LFS server.
Server may require the sha256 of weights to be registered before creation.
This method streams the weight file so the sha gets registered.
Args:
cookies: Optional requests-style cookies (CookieJar/dict). If provided, preferred.
timeout: Request timeout seconds.
headers: Optional headers.
"""
domain: str = get_domain()
base_url: str = f'{MODELSCOPE_URL_SCHEME}lfs.{domain.lstrip("www.")}'
url: str = f'{base_url}/api/v1/models/aigc/weights'
file_path = getattr(self, 'target_file', None) or self.model_path
file_path = os.path.abspath(os.path.expanduser(file_path))
if not os.path.isfile(file_path):
raise ValueError(f'Pre-upload expects a file, got: {file_path}')
cookies = dict(cookies) if cookies else None
if cookies is None:
raise ValueError('Token does not exist, please login first.')
headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
file_size = os.path.getsize(file_path)
def read_in_chunks(file_object,
pbar,
chunk_size: int = 1 * 1024 * 1024):
while True:
ck = file_object.read(chunk_size)
if not ck:
break
pbar.update(len(ck))
yield ck
with tqdm(
total=file_size,
unit='B',
unit_scale=True,
dynamic_ncols=True,
desc='[Pre-uploading] ') as pbar:
with open(file_path, 'rb') as f:
r = requests.put(
url,
headers=headers,
data=read_in_chunks(f, pbar),
timeout=timeout,
)
try:
resp = r.json()
except requests.exceptions.JSONDecodeError:
r.raise_for_status()
return
# If JSON body returned, try best-effort check
if isinstance(resp, dict) and resp.get('Success') is False:
msg = resp.get('Message', 'unknown error')
raise RuntimeError(f'Pre-upload failed: {msg}')
def to_dict(self) -> dict:
"""Converts the AIGC parameters to a dictionary suitable for API calls."""
return {