mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Feat/add official tag (#1509)
This commit is contained in:
@@ -37,6 +37,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
|||||||
API_RESPONSE_FIELD_MESSAGE,
|
API_RESPONSE_FIELD_MESSAGE,
|
||||||
API_RESPONSE_FIELD_USERNAME,
|
API_RESPONSE_FIELD_USERNAME,
|
||||||
DEFAULT_MAX_WORKERS,
|
DEFAULT_MAX_WORKERS,
|
||||||
|
DEFAULT_MODELSCOPE_INTL_DOMAIN,
|
||||||
MODELSCOPE_CLOUD_ENVIRONMENT,
|
MODELSCOPE_CLOUD_ENVIRONMENT,
|
||||||
MODELSCOPE_CLOUD_USERNAME,
|
MODELSCOPE_CLOUD_USERNAME,
|
||||||
MODELSCOPE_CREDENTIALS_PATH,
|
MODELSCOPE_CREDENTIALS_PATH,
|
||||||
@@ -301,18 +302,26 @@ class HubApi:
|
|||||||
'WeightsSha256': aigc_model.weight_sha256,
|
'WeightsSha256': aigc_model.weight_sha256,
|
||||||
'WeightsSize': aigc_model.weight_size,
|
'WeightsSize': aigc_model.weight_size,
|
||||||
'ModelPath': aigc_model.model_path,
|
'ModelPath': aigc_model.model_path,
|
||||||
'TriggerWords': aigc_model.trigger_words
|
'TriggerWords': aigc_model.trigger_words,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if aigc_model.official_tags:
|
||||||
|
body['OfficialTags'] = aigc_model.official_tags
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Use regular model endpoint
|
# Use regular model endpoint
|
||||||
path = f'{endpoint}/api/v1/models'
|
path = f'{endpoint}/api/v1/models'
|
||||||
|
|
||||||
|
headers = self.builder_headers(self.headers)
|
||||||
|
|
||||||
|
intl_end = DEFAULT_MODELSCOPE_INTL_DOMAIN.split('.')[-1]
|
||||||
|
if endpoint.rstrip('/').endswith(f'.{intl_end}'):
|
||||||
|
headers['X-Modelscope-Accept-Language'] = 'en_US'
|
||||||
r = self.session.post(
|
r = self.session.post(
|
||||||
path,
|
path,
|
||||||
json=body,
|
json=body,
|
||||||
cookies=cookies,
|
cookies=cookies,
|
||||||
headers=self.builder_headers(self.headers))
|
headers=headers)
|
||||||
raise_for_http_status(r)
|
raise_for_http_status(r)
|
||||||
d = r.json()
|
d = r.json()
|
||||||
raise_on_error(d)
|
raise_on_error(d)
|
||||||
|
|||||||
@@ -58,6 +58,17 @@ class AigcModel:
|
|||||||
'WAN_VIDEO_2_1_FLF2V_14_B'
|
'WAN_VIDEO_2_1_FLF2V_14_B'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OFFICIAL_TAGS = {
|
||||||
|
'photography', 'illustration-design', 'e-commerce-design', 'dimension',
|
||||||
|
'3d', 'hand-drawn-style', 'logo', 'commodity', 'toy-figurines',
|
||||||
|
'flat-abstraction', 'character-enhancement', 'scenery', 'animal',
|
||||||
|
'art-style-strong', 'other-styles', 'architectural-design',
|
||||||
|
'classic-painting-style', 'cg-fantasy', 'artware', 'construction',
|
||||||
|
'man', 'woman', 'food', 'automobile-traffic', 'sci-fi-mecha',
|
||||||
|
'clothing', 'plant', 'other-functions', 'picture-control',
|
||||||
|
'main-strong', 'character-strong'
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
aigc_type: str,
|
aigc_type: str,
|
||||||
base_model_type: str,
|
base_model_type: str,
|
||||||
@@ -67,7 +78,8 @@ class AigcModel:
|
|||||||
description: Optional[str] = 'this is an aigc model',
|
description: Optional[str] = 'this is an aigc model',
|
||||||
cover_images: Optional[List[str]] = None,
|
cover_images: Optional[List[str]] = None,
|
||||||
path_in_repo: Optional[str] = '',
|
path_in_repo: Optional[str] = '',
|
||||||
trigger_words: Optional[List[str]] = None):
|
trigger_words: Optional[List[str]] = None,
|
||||||
|
official_tags: Optional[List[str]] = None):
|
||||||
"""
|
"""
|
||||||
Initializes the AigcModel helper.
|
Initializes the AigcModel helper.
|
||||||
|
|
||||||
@@ -81,6 +93,7 @@ class AigcModel:
|
|||||||
base_model_id (str, optional): Base model name. e.g., 'AI-ModelScope/FLUX.1-dev'.
|
base_model_id (str, optional): Base model name. e.g., 'AI-ModelScope/FLUX.1-dev'.
|
||||||
path_in_repo (str, optional): Path in the repository.
|
path_in_repo (str, optional): Path in the repository.
|
||||||
trigger_words (List[str], optional): Trigger words for the AIGC Lora model.
|
trigger_words (List[str], optional): Trigger words for the AIGC Lora model.
|
||||||
|
official_tags (List[str], optional): Official tags for the AIGC model. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.aigc_type = aigc_type
|
self.aigc_type = aigc_type
|
||||||
@@ -123,6 +136,12 @@ class AigcModel:
|
|||||||
self._validate_aigc_type()
|
self._validate_aigc_type()
|
||||||
self._validate_base_model_type()
|
self._validate_base_model_type()
|
||||||
|
|
||||||
|
if official_tags:
|
||||||
|
self.official_tags = official_tags
|
||||||
|
self._validate_official_tags()
|
||||||
|
else:
|
||||||
|
self.official_tags = None
|
||||||
|
|
||||||
# Process model path and calculate weights information
|
# Process model path and calculate weights information
|
||||||
self._process_model_path()
|
self._process_model_path()
|
||||||
|
|
||||||
@@ -143,6 +162,19 @@ class AigcModel:
|
|||||||
f'Recommended values: {supported_types}. '
|
f'Recommended values: {supported_types}. '
|
||||||
f'Custom values are allowed but may cause issues. ')
|
f'Custom values are allowed but may cause issues. ')
|
||||||
|
|
||||||
|
def _validate_official_tags(self):
|
||||||
|
"""Validate official tags and provide warning for unsupported tags."""
|
||||||
|
invalid_tags = {
|
||||||
|
tag
|
||||||
|
for tag in self.official_tags if tag not in self.OFFICIAL_TAGS
|
||||||
|
}
|
||||||
|
if invalid_tags:
|
||||||
|
supported_tags = ', '.join(self.OFFICIAL_TAGS)
|
||||||
|
invalid_tags_str = ', '.join(f'"{tag}"' for tag in invalid_tags)
|
||||||
|
logger.warning(
|
||||||
|
f'Your tag(s): {invalid_tags_str} may not be supported. '
|
||||||
|
f'Recommended values: {supported_tags}. ')
|
||||||
|
|
||||||
def _process_model_path(self):
|
def _process_model_path(self):
|
||||||
"""Process model_path to extract weight information"""
|
"""Process model_path to extract weight information"""
|
||||||
from modelscope.utils.file_utils import get_file_hash
|
from modelscope.utils.file_utils import get_file_hash
|
||||||
@@ -350,7 +382,8 @@ class AigcModel:
|
|||||||
'weight_filename': self.weight_filename,
|
'weight_filename': self.weight_filename,
|
||||||
'weight_sha256': self.weight_sha256,
|
'weight_sha256': self.weight_sha256,
|
||||||
'weight_size': self.weight_size,
|
'weight_size': self.weight_size,
|
||||||
'trigger_words': self.trigger_words
|
'trigger_words': self.trigger_words,
|
||||||
|
'official_tags': self.official_tags
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user