mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-14 15:27:42 +01:00
Feat: add set_repo_visibility in hub api (#1564)
* add set_repo_visibility * fix cr * fix lint
This commit is contained in:
@@ -418,10 +418,11 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
|
||||
|
||||
|
||||
class AscendSwiftImageBuilder(SwiftImageBuilder):
|
||||
|
||||
def init_args(self, args) -> Any:
|
||||
if not args.base_image:
|
||||
# other vision search for: https://hub.docker.com/r/ascendai/cann/tags
|
||||
args.base_image = "swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.3.rc1-a3-ubuntu22.04-py3.11"
|
||||
args.base_image = 'swr.cn-south-1.myhuaweicloud.com/ascendhub/cann:8.3.rc1-a3-ubuntu22.04-py3.11'
|
||||
return super().init_args(args)
|
||||
|
||||
def generate_dockerfile(self) -> str:
|
||||
@@ -442,8 +443,7 @@ RUN pip install --no-cache-dir -U icecream soundfile pybind11 py-spy
|
||||
def image(self) -> str:
|
||||
return (
|
||||
f'{docker_registry}:{self.args.base_image.split(":")[-1]}-torch2.7.1'
|
||||
f'-{self.args.modelscope_version}-ascend-swift-test'
|
||||
)
|
||||
f'-{self.args.modelscope_version}-ascend-swift-test')
|
||||
|
||||
def push(self):
|
||||
return 0
|
||||
|
||||
@@ -19,7 +19,8 @@ from http import HTTPStatus
|
||||
from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import (Any, BinaryIO, Dict, Iterable, List, Literal, Optional,
|
||||
Tuple, Union)
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import json
|
||||
@@ -457,7 +458,7 @@ class HubApi:
|
||||
model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
endpoint: Optional[str] = None
|
||||
) -> str:
|
||||
) -> dict:
|
||||
"""Get model information at ModelScope
|
||||
|
||||
Args:
|
||||
@@ -2868,6 +2869,80 @@ class HubApi:
|
||||
'total_files': len(to_delete)
|
||||
}
|
||||
|
||||
def set_repo_visibility(self,
|
||||
repo_id: str,
|
||||
repo_type: Literal['model', 'dataset'],
|
||||
visibility: Literal['private', 'public'],
|
||||
token: Union[str, None] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Set the visibility of a repo.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
||||
repo_type (Literal['model', 'dataset']): The repo type, `model` or `dataset`.
|
||||
visibility (Literal['private', 'public']): The visibility to set, `private` or `public`.
|
||||
token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
|
||||
See `https://modelscope.cn/my/myaccesstoken` to get your token.
|
||||
|
||||
Returns:
|
||||
dict: The response from the server.
|
||||
"""
|
||||
if not repo_id:
|
||||
raise ValueError('The arg `repo_id` cannot be empty!')
|
||||
|
||||
visibility_map: Dict[str, int] = {v: k for k, v in VisibilityMap.items()}
|
||||
visibility_code: int = visibility_map.get(visibility, 5)
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
model_info = self.get_model(model_id=repo_id)
|
||||
path = f'{self.endpoint}/api/v1/models/{repo_id}'
|
||||
payload = {
|
||||
'ChineseName': model_info.get('ChineseName', ''),
|
||||
'ModelFramework': model_info.get('ModelFramework', 'Pytorch'),
|
||||
'Visibility': visibility_code,
|
||||
'ProtectedMode': 2,
|
||||
'ApprovalMode': model_info.get('ApprovalMode', 2),
|
||||
'Description': model_info.get('Description', ''),
|
||||
'AigcType': model_info.get('AigcType', ''),
|
||||
'VisionFoundation': model_info.get('VisionFoundation', ''),
|
||||
'ModelCover': model_info.get('ModelCover', ''),
|
||||
'SubScientificField': model_info.get('SubScientificField', None),
|
||||
'ScientificField': model_info.get('NEXA', {}).get('ScientificField', ''),
|
||||
'Source': model_info.get('NEXA', {}).get('Source', ''),
|
||||
}
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
|
||||
repo_id_parts = repo_id.split('/')
|
||||
if len(repo_id_parts) != 2 or not all(repo_id_parts):
|
||||
raise ValueError(f'Invalid dataset repo_id: {repo_id}, should be in format of `owner/dataset_name`')
|
||||
|
||||
dataset_idx, _ = self.get_dataset_id_and_type(
|
||||
dataset_name=repo_id_parts[1],
|
||||
namespace=repo_id_parts[0],
|
||||
)
|
||||
|
||||
path = f'{self.endpoint}/api/v1/datasets/{dataset_idx}'
|
||||
payload = {
|
||||
'Visibility': visibility_code,
|
||||
'ProtectedMode': 2,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
r = self.session.put(
|
||||
path,
|
||||
json=payload,
|
||||
cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
|
||||
raise_for_http_status(r)
|
||||
resp = r.json()
|
||||
raise_on_error(resp)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)
|
||||
|
||||
Reference in New Issue
Block a user