mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
tmp for hf wrapper
This commit is contained in:
@@ -76,7 +76,7 @@ def pipeline(task: str = None,
|
||||
config_file: str = None,
|
||||
pipeline_name: str = None,
|
||||
framework: str = None,
|
||||
device: str = 'gpu',
|
||||
device: str = None,
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
ignore_file_pattern: List[str] = None,
|
||||
**kwargs) -> Pipeline:
|
||||
@@ -174,9 +174,13 @@ def pipeline(task: str = None,
|
||||
if not hasattr(first_model, 'pipeline'):
|
||||
# model is instantiated by user, we should parse config again
|
||||
cfg = read_config(first_model.model_dir)
|
||||
check_config(cfg)
|
||||
first_model.pipeline = cfg.pipeline
|
||||
pipeline_props = first_model.pipeline
|
||||
try:
|
||||
check_config(cfg)
|
||||
first_model.pipeline = cfg.pipeline
|
||||
except AssertionError as e:
|
||||
logger.info(str(e))
|
||||
if first_model.__dict__.get('pipeline'):
|
||||
pipeline_props = first_model.pipeline
|
||||
else:
|
||||
pipeline_name, default_model_repo = get_default_pipeline_info(task)
|
||||
model = normalize_model_input(default_model_repo, model_revision)
|
||||
@@ -192,6 +196,8 @@ def pipeline(task: str = None,
|
||||
device=device,
|
||||
**kwargs)
|
||||
|
||||
if not device:
|
||||
device = 'gpu'
|
||||
pipeline_props['model'] = model
|
||||
pipeline_props['device'] = device
|
||||
cfg = ConfigDict(pipeline_props)
|
||||
|
||||
@@ -63,11 +63,13 @@ from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
PreTrainedTokenizerBase)
|
||||
from transformers import T5EncoderModel as T5EncoderModelHF
|
||||
from transformers import __version__ as transformers_version
|
||||
from transformers import pipeline as hf_pipeline
|
||||
from transformers import pipeline
|
||||
from transformers import Pipeline as PipelineHF
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from .logger import get_logger
|
||||
from ..pipelines.multi_modal.disco_guided_diffusion_pipeline.utils import NoneClass
|
||||
|
||||
try:
|
||||
from transformers import GPTQConfig as GPTQConfigHF
|
||||
@@ -78,6 +80,168 @@ except ImportError:
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
def _get_hf_device(device):
|
||||
if device
|
||||
return device
|
||||
|
||||
def _get_hf_pipeline_class():
    """Return the class that `hf_pipeline` hands to the wrapper factory.

    At present this is always the imported ``NoneClass``.
    """
    hf_pipeline_class = NoneClass
    return hf_pipeline_class
|
||||
|
||||
def _wrapper_hf_pipeline_class(hf_pipeline_class: PipelineHF):
    """Define a `PipelineHF` subclass with overridden save/upload helpers.

    The local class `HFPipelineWrapper` overrides `save_pretrained` so that
    it never pushes to the hub, and carries its own `_upload_modified_files`
    and `_create_repo` implementations.

    NOTE(review): `hf_pipeline_class` is not used and the wrapper class is
    not returned, so callers currently receive ``None``. This looks
    unfinished; it is left as-is here because returning the class would
    change what `hf_pipeline` passes on as `pipeline_class` -- confirm the
    intended design before wiring it up.

    Args:
        hf_pipeline_class: The pipeline class to wrap (currently unused).
    """

    class HFPipelineWrapper(PipelineHF):

        def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            safe_serialization: bool = True,
            **kwargs,
        ):
            """Save the pipeline locally, always with `push_to_hub=False`.

            Args:
                save_directory: Target directory for the saved files.
                safe_serialization: Forwarded unchanged to the parent.
                **kwargs: Forwarded to the parent implementation; any
                    `push_to_hub` entry is discarded.
            """
            # Always discard a caller-supplied flag (truthy or not) so it
            # cannot collide with the explicit push_to_hub=False below.
            kwargs.pop('push_to_hub', None)

            # Zero-argument super() already binds `self`; passing it again
            # would bind `save_directory` twice and raise TypeError.
            super().save_pretrained(
                save_directory=save_directory,
                safe_serialization=safe_serialization,
                push_to_hub=False,
                **kwargs)

        def _upload_modified_files(
            self,
            working_dir: Union[str, os.PathLike],
            repo_id: str,
            files_timestamps: Dict[str, float],
            commit_message: Optional[str] = None,
            token: Optional[Union[bool, str]] = None,
            create_pr: bool = False,
            revision: str = None,
            commit_description: str = None,
        ):
            """
            Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.

            A root-level entry counts as modified when it is absent from
            `files_timestamps` or its mtime is newer than the recorded one.
            Directories are uploaded one level deep (their direct children).

            Returns:
                Whatever `create_commit` returns.

            Raises:
                HfHubHTTPError: If creating the branch fails, except for a
                    403 response while `create_pr` is set.
            """
            if commit_message is None:
                # Derive a default message from the class name.
                if "Model" in self.__class__.__name__:
                    commit_message = "Upload model"
                elif "Config" in self.__class__.__name__:
                    commit_message = "Upload config"
                elif "Tokenizer" in self.__class__.__name__:
                    commit_message = "Upload tokenizer"
                elif "FeatureExtractor" in self.__class__.__name__:
                    commit_message = "Upload feature extractor"
                elif "Processor" in self.__class__.__name__:
                    commit_message = "Upload processor"
                else:
                    commit_message = f"Upload {self.__class__.__name__}"
            modified_files = [
                f
                for f in os.listdir(working_dir)
                if f not in files_timestamps or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
            ]

            # filter for actual files + folders at the root level
            modified_files = [
                f
                for f in modified_files
                if os.path.isfile(os.path.join(working_dir, f)) or os.path.isdir(os.path.join(working_dir, f))
            ]

            operations = []
            # upload standalone files
            for file in modified_files:
                if os.path.isdir(os.path.join(working_dir, file)):
                    # go over individual files of folder
                    for f in os.listdir(os.path.join(working_dir, file)):
                        operations.append(
                            CommitOperationAdd(
                                path_or_fileobj=os.path.join(working_dir, file, f), path_in_repo=os.path.join(file, f)
                            )
                        )
                else:
                    operations.append(
                        CommitOperationAdd(path_or_fileobj=os.path.join(working_dir, file), path_in_repo=file)
                    )

            if revision is not None and not revision.startswith("refs/pr"):
                try:
                    create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
                except HfHubHTTPError as e:
                    if e.response.status_code == 403 and create_pr:
                        # If we are creating a PR on a repo we don't have access to, we can't create the branch.
                        # so let's assume the branch already exists. If it's not the case, an error will be raised when
                        # calling `create_commit` below.
                        pass
                    else:
                        raise

            logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
            return create_commit(
                repo_id=repo_id,
                operations=operations,
                commit_message=commit_message,
                commit_description=commit_description,
                token=token,
                create_pr=create_pr,
                revision=revision,
            )

        def _create_repo(
            self,
            repo_id: str,
            private: Optional[bool] = None,
            token: Optional[Union[bool, str]] = None,
            repo_url: Optional[str] = None,
            organization: Optional[str] = None,
        ) -> str:
            """
            Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, retrieves
            the token.

            Returns:
                The `repo_id` attribute of the object returned by `create_repo`.

            Raises:
                ValueError: If both `repo_id` and `repo_url` are given.
            """
            if repo_url is not None:
                warnings.warn(
                    "The `repo_url` argument is deprecated and will be removed in v5 of Transformers. Use `repo_id` "
                    "instead."
                )
                if repo_id is not None:
                    raise ValueError(
                        "`repo_id` and `repo_url` are both specified. Please set only the argument `repo_id`."
                    )
                repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
            if organization is not None:
                warnings.warn(
                    "The `organization` argument is deprecated and will be removed in v5 of Transformers. Set your "
                    "organization directly in the `repo_id` passed instead (`repo_id={organization}/{model_id}`)."
                )
                if not repo_id.startswith(organization):
                    # Keep only the repo name, then prefix the organization.
                    if "/" in repo_id:
                        repo_id = repo_id.split("/")[-1]
                    repo_id = f"{organization}/{repo_id}"

            url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
            return url.repo_id
|
||||
|
||||
def hf_pipeline(
    task: str = None,
    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
    framework: Optional[str] = None,
    device: Optional[Union[int, str, "torch.device"]] = None,
    **kwargs,
) -> PipelineHF:
    """Build a pipeline through the module-level `pipeline` factory.

    Args:
        task: Task name forwarded to `pipeline`.
        model: A model object, or a string. A string that is not an existing
            local path is resolved with `snapshot_download` first.
        framework: Framework name; ``'pytorch'`` is rewritten to ``'pt'``,
            anything else is forwarded unchanged.
        device: Device spec, passed through `_get_hf_device` before use.
        **kwargs: Extra keyword arguments forwarded to `pipeline`.

    Returns:
        Whatever `pipeline` returns for the resolved arguments.
    """
    model_is_remote_id = isinstance(model, str) and not os.path.exists(model)
    if model_is_remote_id:
        model = snapshot_download(model)

    if framework == 'pytorch':
        framework = 'pt'

    hf_device = _get_hf_device(device)
    pipeline_cls = _wrapper_hf_pipeline_class(_get_hf_pipeline_class())

    call_args = dict(
        task=task,
        model=model,
        framework=framework,
        device=hf_device,
        pipeline_class=pipeline_cls,
    )
    return pipeline(**call_args, **kwargs)
|
||||
|
||||
class UnsupportedAutoClass:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user