mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
[to #42322933] interface refine with doc
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9159678
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import imp
|
||||
import os
|
||||
import pickle
|
||||
import subprocess
|
||||
|
||||
@@ -6,3 +6,16 @@ DEFAULT_MODELSCOPE_GROUP = 'damo'
|
||||
MODEL_ID_SEPARATOR = '/'
|
||||
|
||||
LOGGER_NAME = 'ModelScopeHub'
|
||||
|
||||
|
||||
class Licenses(object):
|
||||
APACHE_V2 = 'Apache License 2.0'
|
||||
GPL = 'GPL'
|
||||
LGPL = 'LGPL'
|
||||
MIT = 'MIT'
|
||||
|
||||
|
||||
class ModelVisibility(object):
|
||||
PRIVATE = 1
|
||||
INTERNAL = 3
|
||||
PUBLIC = 5
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os.path as osp
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Union
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models.builder import build_model
|
||||
@@ -42,13 +42,18 @@ class Model(ABC):
|
||||
return input
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name_or_path: str, *model_args, **kwargs):
|
||||
""" Instantiate a model from local directory or remote model repo
|
||||
def from_pretrained(cls,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str] = 'master',
|
||||
*model_args,
|
||||
**kwargs):
|
||||
""" Instantiate a model from local directory or remote model repo. Note
|
||||
that when loading from remote, the model revision can be specified.
|
||||
"""
|
||||
if osp.exists(model_name_or_path):
|
||||
local_model_dir = model_name_or_path
|
||||
else:
|
||||
local_model_dir = snapshot_download(model_name_or_path)
|
||||
local_model_dir = snapshot_download(model_name_or_path, revision)
|
||||
logger.info(f'initialize model from {local_model_dir}')
|
||||
cfg = Config.from_file(
|
||||
osp.join(local_model_dir, ModelFile.CONFIGURATION))
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.utils.config import Config
|
||||
@@ -16,8 +17,8 @@ def create_model_if_not_exist(
|
||||
api,
|
||||
model_id: str,
|
||||
chinese_name: str,
|
||||
visibility: Optional[int] = 5, # 1-private, 5-public
|
||||
license: Optional[str] = 'apache-2.0',
|
||||
visibility: Optional[int] = ModelVisibility.PUBLIC,
|
||||
license: Optional[str] = Licenses.APACHE_V2,
|
||||
revision: Optional[str] = 'master'):
|
||||
exists = True
|
||||
try:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import unittest
|
||||
|
||||
from maas_hub.maas_api import MaasApi
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.utils.hub import create_model_if_not_exist
|
||||
|
||||
# note this is temporary before official account management is ready
|
||||
USER_NAME = 'maasadmin'
|
||||
PASSWORD = '12345678'
|
||||
|
||||
@@ -11,8 +11,7 @@ PASSWORD = '12345678'
|
||||
class HubExampleTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.api = MaasApi()
|
||||
# note this is temporary before official account management is ready
|
||||
self.api = HubApi()
|
||||
self.api.login(USER_NAME, PASSWORD)
|
||||
|
||||
@unittest.skip('to be used for local test only')
|
||||
@@ -22,7 +21,6 @@ class HubExampleTest(unittest.TestCase):
|
||||
model_chinese_name = '达摩卡通化模型'
|
||||
model_org = 'damo'
|
||||
model_id = '%s/%s' % (model_org, model_name)
|
||||
|
||||
created = create_model_if_not_exist(self.api, model_id,
|
||||
model_chinese_name)
|
||||
if not created:
|
||||
|
||||
@@ -4,7 +4,8 @@ import tempfile
|
||||
import unittest
|
||||
import uuid
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.constants import Licenses, ModelVisibility
|
||||
from modelscope.hub.file_download import model_file_download
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
@@ -31,8 +32,8 @@ class HubOperationTest(unittest.TestCase):
|
||||
self.api.create_model(
|
||||
model_id=self.model_id,
|
||||
chinese_name=model_chinese_name,
|
||||
visibility=5, # 1-private, 5-public
|
||||
license='apache-2.0')
|
||||
visibility=ModelVisibility.PUBLIC,
|
||||
license=Licenses.APACHE_V2)
|
||||
temporary_dir = tempfile.mkdtemp()
|
||||
self.model_dir = os.path.join(temporary_dir, self.model_name)
|
||||
repo = Repository(self.model_dir, clone_from=self.model_id)
|
||||
|
||||
Reference in New Issue
Block a user