mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Merge remote-tracking branch 'origin/master' into ofa/finetune
This commit is contained in:
@@ -12,7 +12,6 @@ from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import attrs
|
||||
import requests
|
||||
|
||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
@@ -22,14 +21,9 @@ from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_CREDENTIALS_PATH, Licenses,
|
||||
ModelVisibility)
|
||||
from modelscope.hub.deploy import (DeleteServiceParameters,
|
||||
DeployServiceParameters,
|
||||
GetServiceParameters, ListServiceParameters,
|
||||
ServiceParameters, ServiceResourceConfig,
|
||||
Vendor)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, NotSupportError,
|
||||
RequestError, datahub_raise_on_error,
|
||||
NotLoginException, RequestError,
|
||||
datahub_raise_on_error,
|
||||
handle_http_post_error,
|
||||
handle_http_response, is_ok, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
@@ -312,169 +306,6 @@ class HubApi:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def deploy_model(self, model_id: str, revision: str, instance_name: str,
|
||||
resource: ServiceResourceConfig,
|
||||
provider: ServiceParameters):
|
||||
"""Deploy model to cloud, current we only support PAI EAS, this is asynchronous
|
||||
call , please check instance status through the console or query the instance status.
|
||||
At the same time, this call may take a long time.
|
||||
|
||||
Args:
|
||||
model_id (str): The deployed model id
|
||||
revision (str): The model revision
|
||||
instance_name (str): The deployed model instance name.
|
||||
resource (DeployResource): The resource information.
|
||||
provider (CreateParameter): The cloud service provider parameter
|
||||
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
NotSupportError: Not supported platform.
|
||||
RequestError: The server return error.
|
||||
|
||||
Returns:
|
||||
InstanceInfo: The instance information.
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login first.')
|
||||
if provider.vendor != Vendor.EAS:
|
||||
raise NotSupportError(
|
||||
'Not support vendor: %s ,only support EAS current.' %
|
||||
(provider.vendor))
|
||||
create_params = DeployServiceParameters(
|
||||
instance_name=instance_name,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
resource=resource,
|
||||
provider=provider)
|
||||
path = f'{self.endpoint}/api/v1/deployer/endpoint'
|
||||
body = attrs.asdict(create_params)
|
||||
r = requests.post(
|
||||
path,
|
||||
json=body,
|
||||
cookies=cookies,
|
||||
)
|
||||
handle_http_response(r, logger, cookies, 'create_eas_instance')
|
||||
if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def list_deployed_model_instances(self,
|
||||
provider: ServiceParameters,
|
||||
skip: int = 0,
|
||||
limit: int = 100):
|
||||
"""List deployed model instances.
|
||||
|
||||
Args:
|
||||
provider (ListServiceParameter): The cloud service provider parameter,
|
||||
for eas, need access_key_id and access_key_secret.
|
||||
skip: start of the list, current not support.
|
||||
limit: maximum number of instances return, current not support
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
RequestError: The request is failed from server.
|
||||
|
||||
Returns:
|
||||
List: List of instance information
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login first.')
|
||||
params = ListServiceParameters(
|
||||
provider=provider, skip=skip, limit=limit)
|
||||
path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint,
|
||||
params.to_query_str())
|
||||
r = requests.get(path, cookies=cookies)
|
||||
handle_http_response(r, logger, cookies, 'list_deployed_model')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def get_deployed_model_instance(self, instance_name: str,
|
||||
provider: ServiceParameters):
|
||||
"""Query the specified instance information.
|
||||
|
||||
Args:
|
||||
instance_name (str): The deployed instance name.
|
||||
provider (GetParameter): The cloud provider information, for eas
|
||||
need region(eg: ch-hangzhou), access_key_id and access_key_secret.
|
||||
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
RequestError: The request is failed from server.
|
||||
|
||||
Returns:
|
||||
Dict: The request instance information
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login first.')
|
||||
params = GetServiceParameters(provider=provider)
|
||||
path = '%s/api/v1/deployer/endpoint/%s?%s' % (
|
||||
self.endpoint, instance_name, params.to_query_str())
|
||||
r = requests.get(path, cookies=cookies)
|
||||
handle_http_response(r, logger, cookies, 'get_deployed_model')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def delete_deployed_model_instance(self, instance_name: str,
|
||||
provider: ServiceParameters):
|
||||
"""Delete deployed model, this api send delete command and return, it will take
|
||||
some to delete, please check through the cloud console.
|
||||
|
||||
Args:
|
||||
instance_name (str): The instance name you want to delete.
|
||||
provider (DeleteParameter): The cloud provider information, for eas
|
||||
need region(eg: ch-hangzhou), access_key_id and access_key_secret.
|
||||
|
||||
Raises:
|
||||
NotLoginException: To call this api, you need login first.
|
||||
RequestError: The request is failed.
|
||||
|
||||
Returns:
|
||||
Dict: The deleted instance information.
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login first.')
|
||||
params = DeleteServiceParameters(provider=provider)
|
||||
path = '%s/api/v1/deployer/endpoint/%s?%s' % (
|
||||
self.endpoint, instance_name, params.to_query_str())
|
||||
r = requests.delete(path, cookies=cookies)
|
||||
handle_http_response(r, logger, cookies, 'delete_deployed_model')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def _check_cookie(self,
|
||||
use_cookies: Union[bool,
|
||||
CookieJar] = False) -> CookieJar:
|
||||
|
||||
@@ -1,11 +1,25 @@
|
||||
import urllib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Union
|
||||
from abc import ABC
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
|
||||
import attrs
|
||||
import json
|
||||
from attr import fields
|
||||
import requests
|
||||
from attrs import asdict, define, field, validators
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_MESSAGE)
|
||||
from modelscope.hub.errors import (NotLoginException, NotSupportError,
|
||||
RequestError, handle_http_response, is_ok)
|
||||
from modelscope.hub.utils.utils import get_endpoint
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
# yapf: enable
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class Accelerator(object):
|
||||
CPU = 'cpu'
|
||||
@@ -76,12 +90,12 @@ class ServiceResourceConfig(object):
|
||||
|
||||
|
||||
@define
|
||||
class ServiceParameters(ABC):
|
||||
class ServiceProviderParameters(ABC):
|
||||
pass
|
||||
|
||||
|
||||
@define
|
||||
class EASDeployParameters(ServiceParameters):
|
||||
class EASDeployParameters(ServiceProviderParameters):
|
||||
"""Parameters for EAS Deployment.
|
||||
|
||||
Args:
|
||||
@@ -97,29 +111,10 @@ class EASDeployParameters(ServiceParameters):
|
||||
resource_group: Optional[str] = None
|
||||
vendor: str = field(
|
||||
default=Vendor.EAS, validator=validators.in_([Vendor.EAS]))
|
||||
"""
|
||||
def __init__(self,
|
||||
instance_name: str,
|
||||
access_key_id: str,
|
||||
access_key_secret: str,
|
||||
region = EASRegion.beijing,
|
||||
instance_type: str = EASCpuInstances.small,
|
||||
accelerator: str = Accelerator.CPU,
|
||||
resource_group: Optional[str] = None,
|
||||
scaling: Optional[str] = None):
|
||||
self.instance_name=instance_name
|
||||
self.access_key_id=self.access_key_id
|
||||
self.access_key_secret = access_key_secret
|
||||
self.region = region
|
||||
self.instance_type = instance_type
|
||||
self.accelerator = accelerator
|
||||
self.resource_group = resource_group
|
||||
self.scaling = scaling
|
||||
"""
|
||||
|
||||
|
||||
@define
|
||||
class EASListParameters(ServiceParameters):
|
||||
class EASListParameters(ServiceProviderParameters):
|
||||
"""EAS instance list parameters.
|
||||
|
||||
Args:
|
||||
@@ -152,7 +147,7 @@ class DeployServiceParameters(object):
|
||||
model_id: str
|
||||
revision: str
|
||||
resource: ServiceResourceConfig
|
||||
provider: ServiceParameters
|
||||
provider: ServiceProviderParameters
|
||||
|
||||
|
||||
class AttrsToQueryString(ABC):
|
||||
@@ -174,16 +169,173 @@ class AttrsToQueryString(ABC):
|
||||
|
||||
@define
|
||||
class ListServiceParameters(AttrsToQueryString):
|
||||
provider: ServiceParameters
|
||||
provider: ServiceProviderParameters
|
||||
skip: int = 0
|
||||
limit: int = 100
|
||||
|
||||
|
||||
@define
|
||||
class GetServiceParameters(AttrsToQueryString):
|
||||
provider: ServiceParameters
|
||||
provider: ServiceProviderParameters
|
||||
|
||||
|
||||
@define
|
||||
class DeleteServiceParameters(AttrsToQueryString):
|
||||
provider: ServiceParameters
|
||||
provider: ServiceProviderParameters
|
||||
|
||||
|
||||
class ServiceDeployer(object):
|
||||
|
||||
def __init__(self, endpoint=None):
|
||||
self.endpoint = endpoint if endpoint is not None else get_endpoint()
|
||||
self.cookies = ModelScopeConfig.get_cookies()
|
||||
if self.cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login with HubApi first.')
|
||||
|
||||
# deploy_model
|
||||
def create(self, model_id: str, revision: str, instance_name: str,
|
||||
resource: ServiceResourceConfig,
|
||||
provider: ServiceProviderParameters):
|
||||
"""Deploy model to cloud, current we only support PAI EAS, this is an async API ,
|
||||
and the deployment could take a while to finish remotely. Please check deploy instance
|
||||
status separately via checking the status.
|
||||
|
||||
Args:
|
||||
model_id (str): The deployed model id
|
||||
revision (str): The model revision
|
||||
instance_name (str): The deployed model instance name.
|
||||
resource (ServiceResourceConfig): The service resource information.
|
||||
provider (ServiceProviderParameters): The service provider parameter
|
||||
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
NotSupportError: Not supported platform.
|
||||
RequestError: The server return error.
|
||||
|
||||
Returns:
|
||||
ServiceInstanceInfo: The information of the deployed service instance.
|
||||
"""
|
||||
if provider.vendor != Vendor.EAS:
|
||||
raise NotSupportError(
|
||||
'Not support vendor: %s ,only support EAS current.' %
|
||||
(provider.vendor))
|
||||
create_params = DeployServiceParameters(
|
||||
instance_name=instance_name,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
resource=resource,
|
||||
provider=provider)
|
||||
path = f'{self.endpoint}/api/v1/deployer/endpoint'
|
||||
body = attrs.asdict(create_params)
|
||||
r = requests.post(
|
||||
path,
|
||||
json=body,
|
||||
cookies=self.cookies,
|
||||
)
|
||||
handle_http_response(r, logger, self.cookies, 'create_service')
|
||||
if r.status_code >= HTTPStatus.OK and r.status_code < HTTPStatus.MULTIPLE_CHOICES:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def get(self, instance_name: str, provider: ServiceProviderParameters):
|
||||
"""Query the specified instance information.
|
||||
|
||||
Args:
|
||||
instance_name (str): The deployed instance name.
|
||||
provider (ServiceProviderParameters): The cloud provider information, for eas
|
||||
need region(eg: ch-hangzhou), access_key_id and access_key_secret.
|
||||
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
RequestError: The request is failed from server.
|
||||
|
||||
Returns:
|
||||
Dict: The information of the requested service instance.
|
||||
"""
|
||||
params = GetServiceParameters(provider=provider)
|
||||
path = '%s/api/v1/deployer/endpoint/%s?%s' % (
|
||||
self.endpoint, instance_name, params.to_query_str())
|
||||
r = requests.get(path, cookies=self.cookies)
|
||||
handle_http_response(r, logger, self.cookies, 'get_service')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def delete(self, instance_name: str, provider: ServiceProviderParameters):
|
||||
"""Delete deployed model, this api send delete command and return, it will take
|
||||
some to delete, please check through the cloud console.
|
||||
|
||||
Args:
|
||||
instance_name (str): The instance name you want to delete.
|
||||
provider (ServiceProviderParameters): The cloud provider information, for eas
|
||||
need region(eg: ch-hangzhou), access_key_id and access_key_secret.
|
||||
|
||||
Raises:
|
||||
NotLoginException: To call this api, you need login first.
|
||||
RequestError: The request is failed.
|
||||
|
||||
Returns:
|
||||
Dict: The deleted instance information.
|
||||
"""
|
||||
params = DeleteServiceParameters(provider=provider)
|
||||
path = '%s/api/v1/deployer/endpoint/%s?%s' % (
|
||||
self.endpoint, instance_name, params.to_query_str())
|
||||
r = requests.delete(path, cookies=self.cookies)
|
||||
handle_http_response(r, logger, self.cookies, 'delete_service')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
def list(self,
|
||||
provider: ServiceProviderParameters,
|
||||
skip: int = 0,
|
||||
limit: int = 100):
|
||||
"""List deployed model instances.
|
||||
|
||||
Args:
|
||||
provider (ServiceProviderParameters): The cloud service provider parameter,
|
||||
for eas, need access_key_id and access_key_secret.
|
||||
skip: start of the list, current not support.
|
||||
limit: maximum number of instances return, current not support
|
||||
Raises:
|
||||
NotLoginException: To use this api, you need login first.
|
||||
RequestError: The request is failed from server.
|
||||
|
||||
Returns:
|
||||
List: List of instance information
|
||||
"""
|
||||
|
||||
params = ListServiceParameters(
|
||||
provider=provider, skip=skip, limit=limit)
|
||||
path = '%s/api/v1/deployer/endpoint?%s' % (self.endpoint,
|
||||
params.to_query_str())
|
||||
r = requests.get(path, cookies=self.cookies)
|
||||
handle_http_response(r, logger, self.cookies, 'list_service_instances')
|
||||
if r.status_code == HTTPStatus.OK:
|
||||
if is_ok(r.json()):
|
||||
data = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return data
|
||||
else:
|
||||
raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
|
||||
else:
|
||||
r.raise_for_status()
|
||||
return None
|
||||
|
||||
@@ -49,7 +49,7 @@ class FaceRecognitionPipeline(Pipeline):
|
||||
# face detect pipeline
|
||||
det_model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
|
||||
self.face_detection = pipeline(
|
||||
Tasks.face_detection, model=det_model_id, model_revision='v2')
|
||||
Tasks.face_detection, model=det_model_id)
|
||||
|
||||
def _choose_face(self,
|
||||
det_result,
|
||||
|
||||
@@ -17,6 +17,9 @@ from modelscope.preprocessors.space_T_cn.fields.database import Database
|
||||
from modelscope.preprocessors.space_T_cn.fields.struct import (Constant,
|
||||
SQLQuery)
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['TableQuestionAnsweringPipeline']
|
||||
|
||||
@@ -309,7 +312,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
'header_name': header_names,
|
||||
'rows': rows
|
||||
}
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
tabledata = {'header_id': [], 'header_name': [], 'rows': []}
|
||||
else:
|
||||
tabledata = {'header_id': [], 'header_name': [], 'rows': []}
|
||||
|
||||
@@ -17,7 +17,8 @@ class Database:
|
||||
self.tokenizer = tokenizer
|
||||
self.is_use_sqlite = is_use_sqlite
|
||||
if self.is_use_sqlite:
|
||||
self.connection_obj = sqlite3.connect(':memory:')
|
||||
self.connection_obj = sqlite3.connect(
|
||||
':memory:', check_same_thread=False)
|
||||
self.type_dict = {'text': 'TEXT', 'number': 'INT', 'date': 'TEXT'}
|
||||
self.tables = self.init_tables(table_file_path=table_file_path)
|
||||
self.syn_dict = self.init_syn_dict(
|
||||
|
||||
@@ -28,8 +28,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
input_location = ['data/test/images/face_detection2.jpeg']
|
||||
|
||||
dataset = MsDataset.load(input_location, target='image')
|
||||
face_detection = pipeline(
|
||||
Tasks.face_detection, model=self.model_id, model_revision='v2')
|
||||
face_detection = pipeline(Tasks.face_detection, model=self.model_id)
|
||||
# note that for dataset output, the inference-output is a Generator that can be iterated.
|
||||
result = face_detection(dataset)
|
||||
result = next(result)
|
||||
@@ -37,8 +36,7 @@ class FaceDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run_modelhub(self):
|
||||
face_detection = pipeline(
|
||||
Tasks.face_detection, model=self.model_id, model_revision='v2')
|
||||
face_detection = pipeline(Tasks.face_detection, model=self.model_id)
|
||||
img_path = 'data/test/images/face_detection2.jpeg'
|
||||
|
||||
result = face_detection(img_path)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import unittest
|
||||
from threading import Thread
|
||||
from typing import List
|
||||
|
||||
import json
|
||||
@@ -108,8 +109,6 @@ class TableQuestionAnswering(unittest.TestCase):
|
||||
self.task = Tasks.table_question_answering
|
||||
self.model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
|
||||
|
||||
model_id = 'damo/nlp_convai_text2sql_pretrain_cn'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_by_direct_model_download(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
@@ -122,6 +121,27 @@ class TableQuestionAnswering(unittest.TestCase):
|
||||
]
|
||||
tableqa_tracking_and_print_results_with_history(pipelines)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_run_by_direct_model_download_with_multithreads(self):
|
||||
cache_path = snapshot_download(self.model_id)
|
||||
pl = pipeline(Tasks.table_question_answering, model=cache_path)
|
||||
|
||||
def print_func(pl, i):
|
||||
result = pl({
|
||||
'question': '长江流域的小(2)型水库的库容总量是多少?',
|
||||
'table_id': 'reservoir',
|
||||
'history_sql': None
|
||||
})
|
||||
print(i, json.dumps(result))
|
||||
|
||||
procs = []
|
||||
for i in range(5):
|
||||
proc = Thread(target=print_func, args=(pl, i))
|
||||
procs.append(proc)
|
||||
proc.start()
|
||||
for proc in procs:
|
||||
proc.join()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_with_model_from_modelhub(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
|
||||
@@ -28,7 +28,7 @@ def _setup():
|
||||
val_root = val_dir + '/' + os.listdir(val_dir)[0] + '/'
|
||||
max_epochs = 1 # run epochs in unit test
|
||||
|
||||
cache_path = snapshot_download(model_id, revision='v2')
|
||||
cache_path = snapshot_download(model_id)
|
||||
|
||||
tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(tmp_dir):
|
||||
|
||||
@@ -34,14 +34,14 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
|
||||
'SIDD',
|
||||
namespace='huizheng',
|
||||
subset_name='default',
|
||||
split='validation',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
|
||||
split='test',
|
||||
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
|
||||
dataset_val = MsDataset.load(
|
||||
'SIDD',
|
||||
namespace='huizheng',
|
||||
subset_name='default',
|
||||
split='test',
|
||||
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS)._hf_ds
|
||||
download_mode=DownloadMode.FORCE_REDOWNLOAD)._hf_ds
|
||||
self.dataset_train = SiddImageDenoisingDataset(
|
||||
dataset_train, self.config.dataset, is_train=True)
|
||||
self.dataset_val = SiddImageDenoisingDataset(
|
||||
@@ -51,7 +51,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
|
||||
shutil.rmtree(self.tmp_dir, ignore_errors=True)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_trainer(self):
|
||||
kwargs = dict(
|
||||
model=self.model_id,
|
||||
@@ -65,7 +65,7 @@ class ImageDenoiseTrainerTest(unittest.TestCase):
|
||||
for i in range(2):
|
||||
self.assertIn(f'epoch_{i+1}.pth', results_files)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_trainer_with_model_and_args(self):
|
||||
model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
|
||||
kwargs = dict(
|
||||
|
||||
Reference in New Issue
Block a user