tmp for hf wrapper

This commit is contained in:
suluyan
2025-01-27 10:19:35 +08:00
parent defb668def
commit 816efec5d1
2 changed files with 175 additions and 5 deletions

View File

@@ -76,7 +76,7 @@ def pipeline(task: str = None,
              config_file: str = None,
              pipeline_name: str = None,
              framework: str = None,
-             device: str = 'gpu',
+             device: str = None,
              model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
              ignore_file_pattern: List[str] = None,
              **kwargs) -> Pipeline:
@@ -174,9 +174,13 @@ def pipeline(task: str = None,
         if not hasattr(first_model, 'pipeline'):
             # model is instantiated by user, we should parse config again
             cfg = read_config(first_model.model_dir)
-            check_config(cfg)
-            first_model.pipeline = cfg.pipeline
-        pipeline_props = first_model.pipeline
+            try:
+                check_config(cfg)
+                first_model.pipeline = cfg.pipeline
+            except AssertionError as e:
+                logger.info(str(e))
+        if first_model.__dict__.get('pipeline'):
+            pipeline_props = first_model.pipeline
     else:
         pipeline_name, default_model_repo = get_default_pipeline_info(task)
         model = normalize_model_input(default_model_repo, model_revision)
@@ -192,6 +196,8 @@ def pipeline(task: str = None,
             device=device,
             **kwargs)
+    if not device:
+        device = 'gpu'
     pipeline_props['model'] = model
     pipeline_props['device'] = device
     cfg = ConfigDict(pipeline_props)
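
The effect of the device change above: callers may now omit device, which stays None through pipeline resolution and only falls back to 'gpu' when the modelscope pipeline config is assembled. A minimal usage sketch (the task name is illustrative):

    from modelscope.pipelines import pipeline

    # device now defaults to None instead of 'gpu' at the signature level;
    # the 'gpu' fallback is applied just before the pipeline config is built
    pipe = pipeline(task='word-segmentation')                    # implicit device
    pipe_cpu = pipeline(task='word-segmentation', device='cpu')  # explicit device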

View File

@@ -63,11 +63,13 @@ from transformers import (PretrainedConfig, PreTrainedModel,
                           PreTrainedTokenizerBase)
 from transformers import T5EncoderModel as T5EncoderModelHF
 from transformers import __version__ as transformers_version
-from transformers import pipeline as hf_pipeline
+from transformers import pipeline
+from transformers import Pipeline as PipelineHF
+# needed by the push-to-hub helpers copied from transformers below
+# (`os` and `typing` are assumed to be imported earlier in this file)
+import warnings
+from huggingface_hub import (CommitOperationAdd, create_branch,
+                             create_commit, create_repo)
+from huggingface_hub.utils import HfHubHTTPError
+from transformers.utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT
 from modelscope import snapshot_download
 from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
 from .logger import get_logger
+from ..pipelines.multi_modal.disco_guided_diffusion_pipeline.utils import NoneClass

 try:
     from transformers import GPTQConfig as GPTQConfigHF
@@ -78,6 +80,168 @@ except ImportError:
 logger = get_logger()
+
+
+def _get_hf_device(device):
+    # keep the caller-supplied device; fall back to None so transformers
+    # chooses its own default placement
+    if device:
+        return device
+    return None
+
+
+def _get_hf_pipeline_class():
+    # temporary stub: returns a no-op class rather than a concrete HF pipeline class
+    return NoneClass
+
+
+def _wrapper_hf_pipeline_class(hf_pipeline_class: PipelineHF):

+    class HFPipelineWrapper(PipelineHF):
+
+        def save_pretrained(
+            self,
+            save_directory: Union[str, os.PathLike],
+            safe_serialization: bool = True,
+            **kwargs,
+        ):
+            # drop any push_to_hub request: the modelscope wrapper never
+            # pushes to the HF hub on save
+            kwargs.pop('push_to_hub', False)
+            super().save_pretrained(
+                save_directory=save_directory,
+                safe_serialization=safe_serialization,
+                push_to_hub=False,
+                **kwargs)
+
+        def _upload_modified_files(
+            self,
+            working_dir: Union[str, os.PathLike],
+            repo_id: str,
+            files_timestamps: Dict[str, float],
+            commit_message: Optional[str] = None,
+            token: Optional[Union[bool, str]] = None,
+            create_pr: bool = False,
+            revision: str = None,
+            commit_description: str = None,
+        ):
+            """
+            Uploads all modified files in `working_dir` to `repo_id`, based on `files_timestamps`.
+            """
+            if commit_message is None:
+                if "Model" in self.__class__.__name__:
+                    commit_message = "Upload model"
+                elif "Config" in self.__class__.__name__:
+                    commit_message = "Upload config"
+                elif "Tokenizer" in self.__class__.__name__:
+                    commit_message = "Upload tokenizer"
+                elif "FeatureExtractor" in self.__class__.__name__:
+                    commit_message = "Upload feature extractor"
+                elif "Processor" in self.__class__.__name__:
+                    commit_message = "Upload processor"
+                else:
+                    commit_message = f"Upload {self.__class__.__name__}"
+            modified_files = [
+                f for f in os.listdir(working_dir)
+                if f not in files_timestamps
+                or os.path.getmtime(os.path.join(working_dir, f)) > files_timestamps[f]
+            ]
+            # filter for actual files + folders at the root level
+            modified_files = [
+                f for f in modified_files
+                if os.path.isfile(os.path.join(working_dir, f))
+                or os.path.isdir(os.path.join(working_dir, f))
+            ]
+            operations = []
+            # upload standalone files
+            for file in modified_files:
+                if os.path.isdir(os.path.join(working_dir, file)):
+                    # go over individual files of folder
+                    for f in os.listdir(os.path.join(working_dir, file)):
+                        operations.append(
+                            CommitOperationAdd(
+                                path_or_fileobj=os.path.join(working_dir, file, f),
+                                path_in_repo=os.path.join(file, f)))
+                else:
+                    operations.append(
+                        CommitOperationAdd(
+                            path_or_fileobj=os.path.join(working_dir, file),
+                            path_in_repo=file))
+            if revision is not None and not revision.startswith("refs/pr"):
+                try:
+                    create_branch(repo_id=repo_id, branch=revision, token=token, exist_ok=True)
+                except HfHubHTTPError as e:
+                    if e.response.status_code == 403 and create_pr:
+                        # If we are creating a PR on a repo we don't have access to, we can't
+                        # create the branch, so assume it already exists. If not, an error
+                        # will be raised when calling `create_commit` below.
+                        pass
+                    else:
+                        raise
+            logger.info(f"Uploading the following files to {repo_id}: {','.join(modified_files)}")
+            return create_commit(
+                repo_id=repo_id,
+                operations=operations,
+                commit_message=commit_message,
+                commit_description=commit_description,
+                token=token,
+                create_pr=create_pr,
+                revision=revision,
+            )
+
+        def _create_repo(
+            self,
+            repo_id: str,
+            private: Optional[bool] = None,
+            token: Optional[Union[bool, str]] = None,
+            repo_url: Optional[str] = None,
+            organization: Optional[str] = None,
+        ) -> str:
+            """
+            Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url`
+            and `organization`, retrieves the token.
+            """
+            if repo_url is not None:
+                warnings.warn(
+                    "The `repo_url` argument is deprecated and will be removed in v5 of "
+                    "Transformers. Use `repo_id` instead.")
+                if repo_id is not None:
+                    raise ValueError(
+                        "`repo_id` and `repo_url` are both specified. Please set only the "
+                        "argument `repo_id`.")
+                repo_id = repo_url.replace(f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/", "")
+            if organization is not None:
+                warnings.warn(
+                    "The `organization` argument is deprecated and will be removed in v5 of "
+                    "Transformers. Set your organization directly in the `repo_id` passed "
+                    "instead (`repo_id={organization}/{model_id}`).")
+                if not repo_id.startswith(organization):
+                    if "/" in repo_id:
+                        repo_id = repo_id.split("/")[-1]
+                    repo_id = f"{organization}/{repo_id}"
+            url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
+            return url.repo_id
+
+    return HFPipelineWrapper
+
+
+def hf_pipeline(
+        task: str = None,
+        model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
+        framework: Optional[str] = None,
+        device: Optional[Union[int, str, "torch.device"]] = None,
+        **kwargs,
+) -> PipelineHF:
+    # resolve a bare model id against the modelscope hub; local paths are used as-is
+    if isinstance(model, str):
+        if not os.path.exists(model):
+            model = snapshot_download(model)
+    # modelscope names the framework 'pytorch'; transformers expects 'pt'
+    framework = 'pt' if framework == 'pytorch' else framework
+    device = _get_hf_device(device)
+    hf_pipeline_class = _get_hf_pipeline_class()
+    wrapped_pipeline_class = _wrapper_hf_pipeline_class(hf_pipeline_class)
+    return pipeline(task=task,
+                    model=model,
+                    framework=framework,
+                    device=device,
+                    pipeline_class=wrapped_pipeline_class,
+                    **kwargs)
+
+
 class UnsupportedAutoClass:
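
For reference, a minimal usage sketch of the new wrapper, assuming this module is modelscope/utils/hf_util.py and that hf_pipeline is exported from it (the model id below is illustrative):

    from modelscope.utils.hf_util import hf_pipeline

    # downloads the model from the modelscope hub if the id is not a local path,
    # maps framework/device names, and delegates to transformers.pipeline with
    # the wrapped pipeline class
    pipe = hf_pipeline(task='text-generation', model='qwen/Qwen-7B-Chat')
    print(pipe('hello'))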