Feat: add set_repo_visibility in hub api (#1564)

* add set_repo_visibility

* fix cr

* fix lint
This commit is contained in:
Xingjun.Wang
2025-12-05 15:37:45 +08:00
committed by GitHub
parent a31c44cb5a
commit be6874ea8e
2 changed files with 80 additions and 5 deletions

View File

@@ -418,10 +418,11 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
class AscendSwiftImageBuilder(SwiftImageBuilder):
def init_args(self, args) -> Any:
if not args.base_image:
# other vision search for: https://hub.docker.com/r/ascendai/cann/tags
args.base_image = "swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.3.rc1-a3-ubuntu22.04-py3.11"
args.base_image = 'swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.3.rc1-a3-ubuntu22.04-py3.11'
return super().init_args(args)
def generate_dockerfile(self) -> str:
@@ -442,8 +443,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
def image(self) -> str:
return (
f'{docker_registry}:{self.args.base_image.split(":")[-1]}-torch2.7.1'
f'-{self.args.modelscope_version}-ascend-swift-test'
)
f'-{self.args.modelscope_version}-ascend-swift-test')
def push(self):
return 0

View File

@@ -19,7 +19,8 @@ from http import HTTPStatus
from http.cookiejar import CookieJar
from os.path import expanduser
from pathlib import Path
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
from typing import (Any, BinaryIO, Dict, Iterable, List, Literal, Optional,
Tuple, Union)
from urllib.parse import urlencode
import json
@@ -457,7 +458,7 @@ class HubApi:
model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION,
endpoint: Optional[str] = None
) -> str:
) -> dict:
"""Get model information at ModelScope
Args:
@@ -2868,6 +2869,80 @@ class HubApi:
'total_files': len(to_delete)
}
def set_repo_visibility(self,
repo_id: str,
repo_type: Literal['model', 'dataset'],
visibility: Literal['private', 'public'],
token: Union[str, None] = None
) -> dict:
"""
Set the visibility of a repo.
Args:
repo_id (str): The repo id in the format of `owner_name/repo_name`.
repo_type (Literal['model', 'dataset']): The repo type, `model` or `dataset`.
visibility (Literal['private', 'public']): The visibility to set, `private` or `public`.
token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
See `https://modelscope.cn/my/myaccesstoken` to get your token.
Returns:
dict: The response from the server.
"""
if not repo_id:
raise ValueError('The arg `repo_id` cannot be empty!')
visibility_map: Dict[str, int] = {v: k for k, v in VisibilityMap.items()}
visibility_code: int = visibility_map.get(visibility, 5)
cookies = self.get_cookies(access_token=token, cookies_required=True)
if repo_type == REPO_TYPE_MODEL:
model_info = self.get_model(model_id=repo_id)
path = f'{self.endpoint}/api/v1/models/{repo_id}'
payload = {
'ChineseName': model_info.get('ChineseName', ''),
'ModelFramework': model_info.get('ModelFramework', 'Pytorch'),
'Visibility': visibility_code,
'ProtectedMode': 2,
'ApprovalMode': model_info.get('ApprovalMode', 2),
'Description': model_info.get('Description', ''),
'AigcType': model_info.get('AigcType', ''),
'VisionFoundation': model_info.get('VisionFoundation', ''),
'ModelCover': model_info.get('ModelCover', ''),
'SubScientificField': model_info.get('SubScientificField', None),
'ScientificField': model_info.get('NEXA', {}).get('ScientificField', ''),
'Source': model_info.get('NEXA', {}).get('Source', ''),
}
elif repo_type == REPO_TYPE_DATASET:
repo_id_parts = repo_id.split('/')
if len(repo_id_parts) != 2 or not all(repo_id_parts):
raise ValueError(f'Invalid dataset repo_id: {repo_id}, should be in format of `owner/dataset_name`')
dataset_idx, _ = self.get_dataset_id_and_type(
dataset_name=repo_id_parts[1],
namespace=repo_id_parts[0],
)
path = f'{self.endpoint}/api/v1/datasets/{dataset_idx}'
payload = {
'Visibility': visibility_code,
'ProtectedMode': 2,
}
else:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
r = self.session.put(
path,
json=payload,
cookies=cookies,
headers=self.builder_headers(self.headers))
raise_for_http_status(r)
resp = r.json()
raise_on_error(resp)
return resp
class ModelScopeConfig:
path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)