mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-15 15:57:42 +01:00
Feat/add official tag (#1509)
This commit is contained in:
@@ -37,6 +37,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
API_RESPONSE_FIELD_MESSAGE,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_MAX_WORKERS,
|
||||
DEFAULT_MODELSCOPE_INTL_DOMAIN,
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT,
|
||||
MODELSCOPE_CLOUD_USERNAME,
|
||||
MODELSCOPE_CREDENTIALS_PATH,
|
||||
@@ -301,18 +302,26 @@ class HubApi:
|
||||
'WeightsSha256': aigc_model.weight_sha256,
|
||||
'WeightsSize': aigc_model.weight_size,
|
||||
'ModelPath': aigc_model.model_path,
|
||||
'TriggerWords': aigc_model.trigger_words
|
||||
'TriggerWords': aigc_model.trigger_words,
|
||||
})
|
||||
|
||||
if aigc_model.official_tags:
|
||||
body['OfficialTags'] = aigc_model.official_tags
|
||||
|
||||
else:
|
||||
# Use regular model endpoint
|
||||
path = f'{endpoint}/api/v1/models'
|
||||
|
||||
headers = self.builder_headers(self.headers)
|
||||
|
||||
intl_end = DEFAULT_MODELSCOPE_INTL_DOMAIN.split('.')[-1]
|
||||
if endpoint.rstrip('/').endswith(f'.{intl_end}'):
|
||||
headers['X-Modelscope-Accept-Language'] = 'en_US'
|
||||
r = self.session.post(
|
||||
path,
|
||||
json=body,
|
||||
cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
headers=headers)
|
||||
raise_for_http_status(r)
|
||||
d = r.json()
|
||||
raise_on_error(d)
|
||||
|
||||
@@ -58,6 +58,17 @@ class AigcModel:
|
||||
'WAN_VIDEO_2_1_FLF2V_14_B'
|
||||
}
|
||||
|
||||
OFFICIAL_TAGS = {
|
||||
'photography', 'illustration-design', 'e-commerce-design', 'dimension',
|
||||
'3d', 'hand-drawn-style', 'logo', 'commodity', 'toy-figurines',
|
||||
'flat-abstraction', 'character-enhancement', 'scenery', 'animal',
|
||||
'art-style-strong', 'other-styles', 'architectural-design',
|
||||
'classic-painting-style', 'cg-fantasy', 'artware', 'construction',
|
||||
'man', 'woman', 'food', 'automobile-traffic', 'sci-fi-mecha',
|
||||
'clothing', 'plant', 'other-functions', 'picture-control',
|
||||
'main-strong', 'character-strong'
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
aigc_type: str,
|
||||
base_model_type: str,
|
||||
@@ -67,7 +78,8 @@ class AigcModel:
|
||||
description: Optional[str] = 'this is an aigc model',
|
||||
cover_images: Optional[List[str]] = None,
|
||||
path_in_repo: Optional[str] = '',
|
||||
trigger_words: Optional[List[str]] = None):
|
||||
trigger_words: Optional[List[str]] = None,
|
||||
official_tags: Optional[List[str]] = None):
|
||||
"""
|
||||
Initializes the AigcModel helper.
|
||||
|
||||
@@ -81,6 +93,7 @@ class AigcModel:
|
||||
base_model_id (str, optional): Base model name. e.g., 'AI-ModelScope/FLUX.1-dev'.
|
||||
path_in_repo (str, optional): Path in the repository.
|
||||
trigger_words (List[str], optional): Trigger words for the AIGC Lora model.
|
||||
official_tags (List[str], optional): Official tags for the AIGC model. Defaults to None.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.aigc_type = aigc_type
|
||||
@@ -123,6 +136,12 @@ class AigcModel:
|
||||
self._validate_aigc_type()
|
||||
self._validate_base_model_type()
|
||||
|
||||
if official_tags:
|
||||
self.official_tags = official_tags
|
||||
self._validate_official_tags()
|
||||
else:
|
||||
self.official_tags = None
|
||||
|
||||
# Process model path and calculate weights information
|
||||
self._process_model_path()
|
||||
|
||||
@@ -143,6 +162,19 @@ class AigcModel:
|
||||
f'Recommended values: {supported_types}. '
|
||||
f'Custom values are allowed but may cause issues. ')
|
||||
|
||||
def _validate_official_tags(self):
|
||||
"""Validate official tags and provide warning for unsupported tags."""
|
||||
invalid_tags = {
|
||||
tag
|
||||
for tag in self.official_tags if tag not in self.OFFICIAL_TAGS
|
||||
}
|
||||
if invalid_tags:
|
||||
supported_tags = ', '.join(self.OFFICIAL_TAGS)
|
||||
invalid_tags_str = ', '.join(f'"{tag}"' for tag in invalid_tags)
|
||||
logger.warning(
|
||||
f'Your tag(s): {invalid_tags_str} may not be supported. '
|
||||
f'Recommended values: {supported_tags}. ')
|
||||
|
||||
def _process_model_path(self):
|
||||
"""Process model_path to extract weight information"""
|
||||
from modelscope.utils.file_utils import get_file_hash
|
||||
@@ -350,7 +382,8 @@ class AigcModel:
|
||||
'weight_filename': self.weight_filename,
|
||||
'weight_sha256': self.weight_sha256,
|
||||
'weight_size': self.weight_size,
|
||||
'trigger_words': self.trigger_words
|
||||
'trigger_words': self.trigger_words,
|
||||
'official_tags': self.official_tags
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user