mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
add llm-docker dependencies in build_image.sh
This commit is contained in:
@@ -159,7 +159,7 @@ docker_file_content=`cat docker/Dockerfile.ubuntu`
|
||||
|
||||
BUILD_HASH_ID=$(git rev-parse HEAD)
|
||||
# install thrid part library
|
||||
docker_file_content="${docker_file_content} \nRUN export COMMIT_ID=$BUILD_HASH_ID && pip install --no-cache-dir -U adaseq pai-easycv && pip install --no-cache-dir -U 'ms-swift' 'decord' 'qwen_vl_utils' 'pyav' 'librosa' 'funasr' 'timm>0.9.5' 'accelerate' 'gradio' 'peft' 'optimum' 'trl' 'transformers'"
|
||||
docker_file_content="${docker_file_content} \nRUN export COMMIT_ID=$BUILD_HASH_ID && pip install --no-cache-dir -U adaseq pai-easycv && pip install --no-cache-dir -U 'git+https://github.com/modelscope/ms-swift.git@release/2.5' 'decord' 'qwen_vl_utils' 'pyav' 'librosa' 'funasr' 'timm>0.9.5' 'transformers' 'accelerate' 'gradio' 'peft' 'optimum' 'trl'"
|
||||
|
||||
docker_file_content="${docker_file_content} \nRUN pip uninstall modelscope -y && export COMMIT_ID=$BUILD_HASH_ID && cd /tmp && GIT_LFS_SKIP_SMUDGE=1 git clone -b $build_branch --single-branch $REPO_URL && cd modelscope && pip install . && cd / && rm -fr /tmp/modelscope && pip cache purge;"
|
||||
|
||||
|
||||
10
.github/ISSUE_TEMPLATE/bug_report.md
vendored
10
.github/ISSUE_TEMPLATE/bug_report.md
vendored
@@ -3,7 +3,7 @@ name: Bug report
|
||||
about: Create a bug report to help us improve
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: Firmament-cyou, tastelikefeet, wangxingjun778, wenmengzhou, zzclynn
|
||||
assignees: tastelikefeet, wangxingjun778, yingdachen
|
||||
|
||||
---
|
||||
|
||||
@@ -36,14 +36,14 @@ A clear and concise description of what the bug is.
|
||||
|
||||
Please @ corresponding people according to your problem:
|
||||
|
||||
Model related: @wenmengzhou @tastelikefeet
|
||||
Model related: @tastelikefeet
|
||||
|
||||
Model hub related: @liuyhwangyh
|
||||
Model hub related: @liuyhwangyh @tastelikefeet @wangxingjun778
|
||||
|
||||
Dataset releated: @wangxingjun778
|
||||
|
||||
Finetune related: @tastelikefeet @Jintao-Huang
|
||||
|
||||
Pipeline related: @Firmament-cyou @wenmengzhou
|
||||
Pipeline related: @tastelikefeet @wangxingjun778
|
||||
|
||||
Contribute your model: @zzclynn
|
||||
Contribute your model: @yingdachen
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
2
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -3,7 +3,7 @@ name: Feature request
|
||||
about: Suggest an idea for this project
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: tastelikefeet, wangxingjun778, wenmengzhou, yingdachen, zzclynn
|
||||
assignees: yingdachen, wangxingjun778, tastelikefeet
|
||||
|
||||
---
|
||||
|
||||
|
||||
8
.github/ISSUE_TEMPLATE/question.md
vendored
8
.github/ISSUE_TEMPLATE/question.md
vendored
@@ -3,7 +3,7 @@ name: Question
|
||||
about: Describe this issue template's purpose here.
|
||||
title: ''
|
||||
labels: ''
|
||||
assignees: zzclynn,wenmengzhou
|
||||
assignees: tastelikefeet, wangxingjun778, yingdachen
|
||||
|
||||
---
|
||||
|
||||
@@ -18,7 +18,7 @@ Before asking a question, make sure you have:
|
||||
|
||||
Please @ corresponding people according to your problem:
|
||||
|
||||
Model related: @wenmengzhou @tastelikefeet
|
||||
Model related: @tastelikefeet
|
||||
|
||||
Model hub related: @liuyhwangyh
|
||||
|
||||
@@ -26,6 +26,6 @@ Dataset releated: @wangxingjun778
|
||||
|
||||
Finetune related: @tastelikefeet @Jintao-Huang
|
||||
|
||||
Pipeline related: @Firmament-cyou @wenmengzhou
|
||||
Pipeline related: @tastelikefeet @wangxingjun778
|
||||
|
||||
Contribute your model: @zzclynn
|
||||
Contribute your model: @yingdachen
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
exclude: 'modelscope/preprocessors/templates/'
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pycqa/flake8.git
|
||||
rev: 4.0.0
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
exclude: 'modelscope/preprocessors/templates/'
|
||||
|
||||
repos:
|
||||
- repo: /home/admin/pre-commit/flake8
|
||||
rev: 4.0.0
|
||||
|
||||
107
modelscope/cli/clearcache.py
Normal file
107
modelscope/cli/clearcache.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.constants import TEMPORARY_FOLDER_NAME
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Function which will be called for a specific sub parser.
|
||||
"""
|
||||
return ClearCacheCMD(args)
|
||||
|
||||
|
||||
class ClearCacheCMD(CLICommand):
|
||||
name = 'clear-cache'
|
||||
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self.cache_dir = os.getenv(
|
||||
'MODELSCOPE_CACHE',
|
||||
Path.home().joinpath('.cache', 'modelscope'))
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: ArgumentParser):
|
||||
""" define args for clear-cache command.
|
||||
"""
|
||||
parser = parsers.add_parser(ClearCacheCMD.name)
|
||||
group = parser.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
'--model',
|
||||
type=str,
|
||||
help=
|
||||
'The id of the model whose cache will be cleared. For clear-cache, '
|
||||
'if neither model or dataset id is provided, entire cache will be cleared.'
|
||||
)
|
||||
group.add_argument(
|
||||
'--dataset',
|
||||
type=str,
|
||||
help=
|
||||
'The id of the dataset whose cache will be cleared. For clear-cache, '
|
||||
'if neither model or dataset id is provided, entire cache will be cleared.'
|
||||
)
|
||||
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
self._execute_with_confirmation()
|
||||
|
||||
def _execute_with_confirmation(self):
|
||||
all = False
|
||||
single_model = False
|
||||
prompt = '\nYou are about to delete '
|
||||
|
||||
if self.args.model or self.args.dataset:
|
||||
if self.args.model:
|
||||
id = self.args.model
|
||||
single_model = True
|
||||
prompt = prompt + f'local cache for model {id}. '
|
||||
else:
|
||||
id = self.args.dataset
|
||||
prompt = prompt + f'local cache for dataset {id}. '
|
||||
else:
|
||||
prompt = prompt + f'entire ModelScope cache at {self.cache_dir}, including ALL models and dataset.\n'
|
||||
all = True
|
||||
user_input = input(
|
||||
prompt
|
||||
+ '\nPlease press Y or y to proceed, any other key to abort.\n'
|
||||
).strip().upper()
|
||||
|
||||
if user_input == 'Y':
|
||||
if all:
|
||||
self._remove_directory(self.cache_dir)
|
||||
print('Cache cleared.')
|
||||
else:
|
||||
entity_directory = os.path.join(
|
||||
self.cache_dir, 'hub' if single_model else 'datasets', id)
|
||||
temp_directory = os.path.join(
|
||||
self.cache_dir, 'hub' if single_model else 'datasets',
|
||||
TEMPORARY_FOLDER_NAME, id)
|
||||
entity_removed = self._remove_directory(entity_directory)
|
||||
temp_removed = self._remove_directory(temp_directory)
|
||||
if (not entity_removed) and (not temp_removed):
|
||||
if single_model:
|
||||
print(
|
||||
f'Cache for Model {id} not found. Nothing to do.')
|
||||
else:
|
||||
print(
|
||||
f'Cache for Dataset {id} not found. Nothing to do.'
|
||||
)
|
||||
else:
|
||||
print('Cache cleared.')
|
||||
else:
|
||||
print('Operation aborted.')
|
||||
return
|
||||
|
||||
def _remove_directory(self, path):
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
shutil.rmtree(path)
|
||||
print(f'Cache folder {path} removed.')
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f'An error occurred while clearing cache at {path}: {e}')
|
||||
return False
|
||||
@@ -3,6 +3,7 @@
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
from modelscope.cli.clearcache import ClearCacheCMD
|
||||
from modelscope.cli.download import DownloadCMD
|
||||
from modelscope.cli.login import LoginCMD
|
||||
from modelscope.cli.modelcard import ModelCardCMD
|
||||
@@ -23,6 +24,7 @@ def run_cmd():
|
||||
subparsers = parser.add_subparsers(help='modelscope commands helpers')
|
||||
|
||||
DownloadCMD.define_args(subparsers)
|
||||
ClearCacheCMD.define_args(subparsers)
|
||||
PluginsCMD.define_args(subparsers)
|
||||
PipelineCMD.define_args(subparsers)
|
||||
ModelCardCMD.define_args(subparsers)
|
||||
|
||||
@@ -555,7 +555,7 @@ def get_module_without_script(self) -> DatasetModule:
|
||||
|
||||
download_config = self.download_config.copy()
|
||||
if download_config.download_desc is None:
|
||||
download_config.download_desc = 'Downloading readme'
|
||||
download_config.download_desc = 'Downloading [README.md]'
|
||||
try:
|
||||
url_or_filename = _ms_api.get_dataset_file_url(
|
||||
file_name='README.md',
|
||||
@@ -989,7 +989,6 @@ class DatasetsWrapperHF:
|
||||
download_config=download_config,
|
||||
download_mode=download_mode,
|
||||
verification_mode=verification_mode,
|
||||
try_from_hf_gcs=False,
|
||||
num_proc=num_proc,
|
||||
storage_options=storage_options,
|
||||
# base_path=builder_instance.base_path,
|
||||
|
||||
@@ -5,27 +5,138 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import copy
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
import inspect
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets import config
|
||||
from datasets.utils.file_utils import hash_url_to_filename, get_authentication_headers_for_url, ftp_head, fsspec_head, \
|
||||
http_head, _raise_if_offline_mode_is_enabled, ftp_get, fsspec_get, http_get
|
||||
from datasets.utils.file_utils import hash_url_to_filename, \
|
||||
get_authentication_headers_for_url, fsspec_head, fsspec_get
|
||||
from filelock import FileLock
|
||||
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
|
||||
from modelscope import __version__
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def get_datasets_user_agent_ms(user_agent: Optional[Union[str, dict]] = None) -> str:
|
||||
ua = f'datasets/{__version__}'
|
||||
ua += f'; python/{config.PY_VERSION}'
|
||||
ua += f'; pyarrow/{config.PYARROW_VERSION}'
|
||||
if config.TORCH_AVAILABLE:
|
||||
ua += f'; torch/{config.TORCH_VERSION}'
|
||||
if config.TF_AVAILABLE:
|
||||
ua += f'; tensorflow/{config.TF_VERSION}'
|
||||
if config.JAX_AVAILABLE:
|
||||
ua += f'; jax/{config.JAX_VERSION}'
|
||||
if isinstance(user_agent, dict):
|
||||
ua += f"; {'; '.join(f'{k}/{v}' for k, v in user_agent.items())}"
|
||||
elif isinstance(user_agent, str):
|
||||
ua += '; ' + user_agent
|
||||
return ua
|
||||
|
||||
|
||||
def _request_with_retry_ms(
|
||||
method: str,
|
||||
url: str,
|
||||
max_retries: int = 2,
|
||||
base_wait_time: float = 0.5,
|
||||
max_wait_time: float = 2,
|
||||
timeout: float = 10.0,
|
||||
**params,
|
||||
) -> requests.Response:
|
||||
"""Wrapper around requests to retry in case it fails with a ConnectTimeout, with exponential backoff.
|
||||
|
||||
Note that if the environment variable HF_DATASETS_OFFLINE is set to 1, then a OfflineModeIsEnabled error is raised.
|
||||
|
||||
Args:
|
||||
method (str): HTTP method, such as 'GET' or 'HEAD'.
|
||||
url (str): The URL of the resource to fetch.
|
||||
max_retries (int): Maximum number of retries, defaults to 0 (no retries).
|
||||
base_wait_time (float): Duration (in seconds) to wait before retrying the first time. Wait time between
|
||||
retries then grows exponentially, capped by max_wait_time.
|
||||
max_wait_time (float): Maximum amount of time between two retries, in seconds.
|
||||
**params (additional keyword arguments): Params to pass to :obj:`requests.request`.
|
||||
"""
|
||||
tries, success = 0, False
|
||||
response = None
|
||||
while not success:
|
||||
tries += 1
|
||||
try:
|
||||
response = requests.request(method=method.upper(), url=url, timeout=timeout, **params)
|
||||
success = True
|
||||
except (requests.exceptions.ConnectTimeout, requests.exceptions.ConnectionError) as err:
|
||||
if tries > max_retries:
|
||||
raise err
|
||||
else:
|
||||
logger.info(f'{method} request to {url} timed out, retrying... [{tries/max_retries}]')
|
||||
sleep_time = min(max_wait_time, base_wait_time * 2 ** (tries - 1)) # Exponential backoff
|
||||
time.sleep(sleep_time)
|
||||
return response
|
||||
|
||||
|
||||
def http_head_ms(
|
||||
url, proxies=None, headers=None, cookies=None, allow_redirects=True, timeout=10.0, max_retries=0
|
||||
) -> requests.Response:
|
||||
headers = copy.deepcopy(headers) or {}
|
||||
headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
|
||||
response = _request_with_retry_ms(
|
||||
method='HEAD',
|
||||
url=url,
|
||||
proxies=proxies,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
allow_redirects=allow_redirects,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def http_get_ms(
|
||||
url, temp_file, proxies=None, resume_size=0, headers=None, cookies=None, timeout=100.0, max_retries=0, desc=None
|
||||
) -> Optional[requests.Response]:
|
||||
headers = dict(headers) if headers is not None else {}
|
||||
headers['user-agent'] = get_datasets_user_agent_ms(user_agent=headers.get('user-agent'))
|
||||
if resume_size > 0:
|
||||
headers['Range'] = f'bytes={resume_size:d}-'
|
||||
response = _request_with_retry_ms(
|
||||
method='GET',
|
||||
url=url,
|
||||
stream=True,
|
||||
proxies=proxies,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout,
|
||||
)
|
||||
if temp_file is None:
|
||||
return response
|
||||
if response.status_code == 416: # Range not satisfiable
|
||||
return
|
||||
content_length = response.headers.get('Content-Length')
|
||||
total = resume_size + int(content_length) if content_length is not None else None
|
||||
|
||||
progress = tqdm(total=total, initial=resume_size, unit_scale=True, unit='B', desc=desc or 'Downloading')
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
progress.update(len(chunk))
|
||||
temp_file.write(chunk)
|
||||
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache_ms(
|
||||
url,
|
||||
cache_dir=None,
|
||||
@@ -42,7 +153,7 @@ def get_from_cache_ms(
|
||||
ignore_url_params=False,
|
||||
storage_options=None,
|
||||
download_desc=None,
|
||||
disable_tqdm=False,
|
||||
disable_tqdm=None,
|
||||
) -> str:
|
||||
"""
|
||||
Given a URL, look for the corresponding file in the local cache.
|
||||
@@ -88,6 +199,8 @@ def get_from_cache_ms(
|
||||
# if we don't ask for 'force_download' then we spare a request
|
||||
filename = hash_url_to_filename(cached_url, etag=None)
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
if download_desc is None:
|
||||
download_desc = 'Downloading [' + filename + ']'
|
||||
|
||||
if os.path.exists(cache_path) and not force_download and not use_etag:
|
||||
return cache_path
|
||||
@@ -100,16 +213,14 @@ def get_from_cache_ms(
|
||||
# We don't have the file locally or we need an eTag
|
||||
if not local_files_only:
|
||||
scheme = urlparse(url).scheme
|
||||
if scheme == 'ftp':
|
||||
connected = ftp_head(url)
|
||||
elif scheme not in ('http', 'https'):
|
||||
if scheme not in ('http', 'https'):
|
||||
response = fsspec_head(url, storage_options=storage_options)
|
||||
# s3fs uses "ETag", gcsfs uses "etag"
|
||||
etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
|
||||
connected = True
|
||||
try:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
response = http_head(
|
||||
response = http_head_ms(
|
||||
url,
|
||||
allow_redirects=True,
|
||||
proxies=proxies,
|
||||
@@ -166,7 +277,6 @@ def get_from_cache_ms(
|
||||
)
|
||||
elif response is not None and response.status_code == 404:
|
||||
raise FileNotFoundError(f"Couldn't find file at {url}")
|
||||
_raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
|
||||
if head_error is not None:
|
||||
raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
|
||||
elif response is not None:
|
||||
@@ -205,48 +315,21 @@ def get_from_cache_ms(
|
||||
# Download to temporary file, then copy to cache path once finished.
|
||||
# Otherwise, you get corrupt cache entries if the download gets interrupted.
|
||||
with temp_file_manager() as temp_file:
|
||||
logger.info(f'Downloading to {temp_file.name}')
|
||||
|
||||
# GET file object
|
||||
if scheme == 'ftp':
|
||||
ftp_get(url, temp_file)
|
||||
elif scheme not in ('http', 'https'):
|
||||
fsspec_get_sig = inspect.signature(fsspec_get)
|
||||
if 'disable_tqdm' in fsspec_get_sig.parameters:
|
||||
fsspec_get(url,
|
||||
temp_file,
|
||||
storage_options=storage_options,
|
||||
desc=download_desc,
|
||||
disable_tqdm=disable_tqdm
|
||||
)
|
||||
else:
|
||||
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
|
||||
if scheme not in ('http', 'https'):
|
||||
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
|
||||
else:
|
||||
http_get_sig = inspect.signature(http_get)
|
||||
|
||||
if 'disable_tqdm' in http_get_sig.parameters:
|
||||
http_get(
|
||||
url,
|
||||
temp_file=temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
max_retries=max_retries,
|
||||
desc=download_desc,
|
||||
disable_tqdm=disable_tqdm,
|
||||
)
|
||||
else:
|
||||
http_get(
|
||||
url,
|
||||
temp_file=temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
max_retries=max_retries,
|
||||
desc=download_desc,
|
||||
)
|
||||
http_get_ms(
|
||||
url,
|
||||
temp_file=temp_file,
|
||||
proxies=proxies,
|
||||
resume_size=resume_size,
|
||||
headers=headers,
|
||||
cookies=cookies,
|
||||
max_retries=max_retries,
|
||||
desc=download_desc,
|
||||
)
|
||||
|
||||
logger.info(f'storing {url} in cache at {cache_path}')
|
||||
shutil.move(temp_file.name, cache_path)
|
||||
|
||||
2
modelscope/preprocessors/templates/__init__.py
Normal file
2
modelscope/preprocessors/templates/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .base import Template, get_template
|
||||
from .template import TemplateType
|
||||
1041
modelscope/preprocessors/templates/base.py
Normal file
1041
modelscope/preprocessors/templates/base.py
Normal file
File diff suppressed because it is too large
Load Diff
371
modelscope/preprocessors/templates/loader.py
Normal file
371
modelscope/preprocessors/templates/loader.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from modelscope import AutoTokenizer, get_logger, snapshot_download
|
||||
from . import TemplateType
|
||||
from .base import Template, get_template
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TemplateInfo:
|
||||
|
||||
template: str = None
|
||||
template_regex: str = None
|
||||
modelfile_link: str = None
|
||||
|
||||
|
||||
def cases(*names):
|
||||
ret = []
|
||||
for name in names:
|
||||
regex = ''
|
||||
for letter in name:
|
||||
if letter.upper() != letter.lower():
|
||||
regex += f'[{letter.upper()}{letter.lower()}]'
|
||||
else:
|
||||
regex += letter
|
||||
ret.append(regex)
|
||||
if len(ret) > 1:
|
||||
ret = '|'.join(ret)
|
||||
ret = '(' + ret + ')'
|
||||
else:
|
||||
ret = ret[0]
|
||||
return ret
|
||||
|
||||
|
||||
chat_suffix = cases('instruct', 'chat', '-rl', '-it')
|
||||
|
||||
|
||||
def no(*names):
|
||||
return f'(?!.*{cases(*names)})'
|
||||
|
||||
|
||||
def no_multi_modal():
|
||||
return no('audio', 'video', 'vl', 'vision')
|
||||
|
||||
|
||||
template_info = [
|
||||
# llama
|
||||
TemplateInfo(
|
||||
template=TemplateType.llama3,
|
||||
template_regex=
|
||||
f'.*{cases("llama3", "llama-3")}{no_multi_modal()}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/llama-3.modelfile',
|
||||
),
|
||||
TemplateInfo(
|
||||
template=TemplateType.llama,
|
||||
template_regex=
|
||||
f'.*{cases("llama2", "llama-2", "mistral", "codestral", "mixtral")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# qwen
|
||||
TemplateInfo(
|
||||
template=TemplateType.qwen,
|
||||
template_regex=f'.*{cases("qwen")}{no_multi_modal()}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/qwen2.modelfile',
|
||||
),
|
||||
|
||||
# codeqwen1.5
|
||||
TemplateInfo(
|
||||
template_regex=
|
||||
f'.*{cases("codeqwen1.5", "codeqwen-1.5")}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/codeqwen1.5.modelfile',
|
||||
),
|
||||
|
||||
# chatml
|
||||
TemplateInfo(
|
||||
template=TemplateType.chatml,
|
||||
template_regex=
|
||||
f'.*{cases("yi")}{no_multi_modal()}{no("coder")}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile',
|
||||
),
|
||||
|
||||
# chatml
|
||||
TemplateInfo(
|
||||
template=TemplateType.chatml,
|
||||
template_regex=f'.*{cases("minicpm")}{no("-v")}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile'
|
||||
),
|
||||
|
||||
# chatglm
|
||||
TemplateInfo(
|
||||
template=TemplateType.chatglm2,
|
||||
template_regex=f'.*{cases("chatglm2")}{no_multi_modal()}.*'),
|
||||
TemplateInfo(
|
||||
template=TemplateType.chatglm3,
|
||||
template_regex=f'.*{cases("chatglm3")}{no_multi_modal()}.*'),
|
||||
TemplateInfo(
|
||||
template=TemplateType.chatglm4,
|
||||
template_regex=f'.*{cases("glm4")}{no_multi_modal()}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/glm4.modelfile',
|
||||
),
|
||||
|
||||
# baichuan
|
||||
TemplateInfo(
|
||||
template=TemplateType.baichuan,
|
||||
template_regex=
|
||||
f'.*{cases("baichuan")}{no_multi_modal()}.*{chat_suffix}.*'),
|
||||
|
||||
# codegeex
|
||||
TemplateInfo(
|
||||
template=TemplateType.codegeex4,
|
||||
template_regex=f'.*{cases("codegeex4")}{no_multi_modal()}.*'),
|
||||
|
||||
# idefics3
|
||||
TemplateInfo(
|
||||
template=TemplateType.idefics3,
|
||||
template_regex=f'.*{cases("idefics3")}{no_multi_modal()}.*'),
|
||||
|
||||
# mistral-nemo
|
||||
TemplateInfo(
|
||||
template=TemplateType.mistral_nemo,
|
||||
template_regex=f'.*{cases("Mistral-Nemo")}{no_multi_modal()}.*',
|
||||
modelfile_link='https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/mistral-nemo.modelfile'),
|
||||
|
||||
# internlm
|
||||
TemplateInfo(
|
||||
template=TemplateType.internlm,
|
||||
template_regex=
|
||||
f'.*{cases("internlm")}{no("internlm2", "internlm3")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# internlm2
|
||||
TemplateInfo(
|
||||
template=TemplateType.internlm2,
|
||||
template_regex=
|
||||
f'.*{cases("internlm2")}{no_multi_modal()}.*{chat_suffix}.*'),
|
||||
|
||||
# yi-coder
|
||||
TemplateInfo(
|
||||
template=TemplateType.yi_coder,
|
||||
template_regex=f'.*{cases("yi")}.*{cases("coder")}.*{chat_suffix}.*'),
|
||||
|
||||
# yuan
|
||||
TemplateInfo(
|
||||
template=TemplateType.yuan,
|
||||
template_regex=f'.*{cases("Yuan")}{no_multi_modal()}.*'),
|
||||
|
||||
# xverse
|
||||
TemplateInfo(
|
||||
template=TemplateType.xverse,
|
||||
template_regex=f'.*{cases("xverse")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# skywork
|
||||
TemplateInfo(
|
||||
template=TemplateType.skywork,
|
||||
template_regex=
|
||||
f'.*{cases("skywork")}{no_multi_modal()}.*{chat_suffix}.*'),
|
||||
|
||||
# bluelm
|
||||
TemplateInfo(
|
||||
template=TemplateType.bluelm,
|
||||
template_regex=f'.*{cases("bluelm")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# zephyr
|
||||
TemplateInfo(
|
||||
template=TemplateType.zephyr,
|
||||
template_regex=f'.*{cases("zephyr")}{no_multi_modal()}.*'),
|
||||
|
||||
# deepseek
|
||||
TemplateInfo(
|
||||
template=TemplateType.deepseek,
|
||||
template_regex=
|
||||
f'.*{cases("deepseek")}{no("v2", "v2.5", "coder")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# deepseek2
|
||||
TemplateInfo(
|
||||
template=TemplateType.deepseek2,
|
||||
template_regex=
|
||||
f'.*{cases("deepseek")}.*{cases("v2")}{no("v2.5")}{no_multi_modal()}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepseek_v2.modelfile',
|
||||
),
|
||||
|
||||
# deepseek_coder
|
||||
TemplateInfo(
|
||||
template=TemplateType.deepseek_coder,
|
||||
template_regex=
|
||||
f'.*{cases("deepseek")}{no("v2", "v2.5")}.*{cases("coder")}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# deepseek v2.5
|
||||
TemplateInfo(
|
||||
template=TemplateType.deepseek2_5,
|
||||
template_regex=
|
||||
f'.*{cases("deepseek")}.*{cases("v2.5")}{no_multi_modal()}.*'),
|
||||
|
||||
# orion
|
||||
TemplateInfo(
|
||||
template=TemplateType.orion,
|
||||
template_regex=f'.*{cases("orion")}{no_multi_modal()}.*{chat_suffix}.*'
|
||||
),
|
||||
|
||||
# gemma
|
||||
TemplateInfo(
|
||||
template=TemplateType.gemma,
|
||||
template_regex=
|
||||
f'{no("pali")}.*{cases("gemma2", "gemma-2")}\\b.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/gemma2.modelfile',
|
||||
),
|
||||
|
||||
# phi3
|
||||
TemplateInfo(
|
||||
template=TemplateType.phi3,
|
||||
template_regex=
|
||||
f'.*{cases("phi3", "phi-3")}{no_multi_modal()}.*{chat_suffix}.*',
|
||||
modelfile_link=
|
||||
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3.modelfile',
|
||||
),
|
||||
|
||||
# telechat
|
||||
TemplateInfo(
|
||||
template=TemplateType.telechat,
|
||||
template_regex=f'.*{cases("TeleChat")}{no("v2")}.*'),
|
||||
|
||||
# telechat_v2
|
||||
TemplateInfo(
|
||||
template=TemplateType.telechat_v2,
|
||||
template_regex=f'.*{cases("TeleChat")}.*{cases("v2")}.*'),
|
||||
]
|
||||
|
||||
|
||||
class TemplateLoader:
|
||||
|
||||
@staticmethod
|
||||
def load_by_model_id(model_id: str, **kwargs) -> Template:
|
||||
"""Load a template by model-id
|
||||
|
||||
Args:
|
||||
model_id: The model-id used to load the proper template
|
||||
kwargs:
|
||||
revision: the revision of the model, default is `master`
|
||||
Returns:
|
||||
The template instance
|
||||
"""
|
||||
ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$']
|
||||
tokenizer = kwargs.get('tokenizer')
|
||||
for _info in template_info:
|
||||
if re.fullmatch(_info.template_regex, model_id):
|
||||
if _info.template:
|
||||
if tokenizer is None:
|
||||
try:
|
||||
model_dir = snapshot_download(
|
||||
model_id,
|
||||
revision=kwargs.pop('revision', 'master'),
|
||||
ignore_file_pattern=ignore_file_pattern)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, trust_remote_code=True)
|
||||
except Exception:
|
||||
pass
|
||||
return TemplateLoader.load_by_template_name(
|
||||
_info.template, tokenizer=tokenizer, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load_by_template_name(template_name: str, **kwargs) -> Template:
|
||||
"""Load a template by model-id
|
||||
|
||||
Args:
|
||||
template_name: The template name used to load the proper template
|
||||
kwargs:
|
||||
tokenizer: The tokenizer of the model
|
||||
default_system: The extra default system info
|
||||
max_length: The max_length for the sequence
|
||||
truncation_strategy: 'delete' or 'truncation_left' the sequence of the length exceeds the limit
|
||||
Returns:
|
||||
The template instance
|
||||
"""
|
||||
return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def replace_and_concat(template: Template, template_list: List,
|
||||
placeholder: str, keyword: str):
|
||||
final_str = ''
|
||||
for t in template_list:
|
||||
if isinstance(t, str):
|
||||
final_str += t.replace(placeholder, keyword)
|
||||
elif isinstance(t, (tuple, list)):
|
||||
if isinstance(t[0], int):
|
||||
final_str += template.tokenizer.decode(t)
|
||||
else:
|
||||
for attr in t:
|
||||
if attr == 'bos_token_id':
|
||||
final_str += template.tokenizer.bos_token
|
||||
elif attr == 'eos_token_id':
|
||||
final_str += template.tokenizer.eos_token
|
||||
else:
|
||||
raise ValueError(f'Unknown token: {attr}')
|
||||
return final_str
|
||||
|
||||
@staticmethod
|
||||
def to_ollama(model_id: str = None,
|
||||
template_name: str = None,
|
||||
gguf_file: str = None,
|
||||
gguf_meta: Dict[str, Any] = None,
|
||||
**kwargs) -> str:
|
||||
"""Export to ollama ModelFile
|
||||
|
||||
Args:
|
||||
model_id: The model-id to use
|
||||
template_name: An extra template name to use
|
||||
gguf_file: An extra gguf_file path to use in the `FROM` field
|
||||
gguf_meta: An gguf extra meta info
|
||||
Returns:
|
||||
The ModelFile content, returns `None` if no template found
|
||||
"""
|
||||
logger.info('Exporting to ollama:')
|
||||
if model_id:
|
||||
for _info in template_info:
|
||||
if re.fullmatch(_info.template_regex, model_id):
|
||||
if _info.modelfile_link:
|
||||
return TemplateLoader._read_content_from_url(
|
||||
_info.modelfile_link)
|
||||
elif _info.template and not template_name:
|
||||
template_name = _info.template
|
||||
if template_name:
|
||||
template = TemplateLoader.load_by_template_name(
|
||||
template_name, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Please make sure you model_id: {model_id} '
|
||||
f'and template_name: {template_name} is supported.')
|
||||
|
||||
if template is None:
|
||||
return None
|
||||
|
||||
content = ''
|
||||
content += 'FROM {{gguf_file}}\n'
|
||||
content += (
|
||||
f'TEMPLATE """{{{{ if .System }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
|
||||
f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}'
|
||||
f'{{{{ end }}}}')
|
||||
content += (
|
||||
f'{{{{ if .Prompt }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
|
||||
f'{{{{ end }}}}')
|
||||
content += '{{ .Response }}'
|
||||
content += TemplateLoader.replace_and_concat(template, template.suffix,
|
||||
'', '') + '"""\n'
|
||||
content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n'
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
def _read_content_from_url(url):
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
content = response.content
|
||||
return content.decode('utf-8')
|
||||
101
modelscope/preprocessors/templates/loss_scale.py
Normal file
101
modelscope/preprocessors/templates/loss_scale.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from .utils import split_str_parts_by, split_parts_by_regex
|
||||
|
||||
|
||||
def calculate_loss_scale(query: str,
|
||||
response: str,
|
||||
response_loss_scale_map: Optional[Dict[str, list]] = None,
|
||||
query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
|
||||
"""Calculate the loss scale by splitting the agent response.
|
||||
|
||||
This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
|
||||
|
||||
Agent response format:
|
||||
|
||||
```text
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of the above tools[fire_recognition,
|
||||
fire_alert, call_police, call_fireman]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Thought: I now know the final answer
|
||||
Final Answer: the final answer to the original input question
|
||||
```
|
||||
Returns:
|
||||
A tuple of agent response parts and their weights.
|
||||
"""
|
||||
# query loss scale map
|
||||
if query_loss_scale_map is not None:
|
||||
for key in query_loss_scale_map.keys():
|
||||
if key in query:
|
||||
if isinstance(query_loss_scale_map[key], (float, int)):
|
||||
query_loss_scale_map[key] = [query_loss_scale_map[key]]
|
||||
loss_scale_value = query_loss_scale_map[key][0]
|
||||
return [response], [float(loss_scale_value)]
|
||||
delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
|
||||
agent_parts = split_str_parts_by(response, delimiters)
|
||||
regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
|
||||
if len(regex_delimiters):
|
||||
split_parts_by_regex(agent_parts, regex_delimiters)
|
||||
weights = []
|
||||
agent_content = []
|
||||
for c in agent_parts:
|
||||
if isinstance(c['key'], (float, int)):
|
||||
weights += [c['key']]
|
||||
agent_content.append(c['content'])
|
||||
else:
|
||||
if c['key'] in response_loss_scale_map:
|
||||
weights += [response_loss_scale_map[c['key']][0]]
|
||||
weights += [response_loss_scale_map[c['key']][1]]
|
||||
agent_content.append(c['key'])
|
||||
agent_content.append(c['content'])
|
||||
else:
|
||||
weights += [1.0]
|
||||
agent_content.append(c['content'])
|
||||
return agent_content, weights
|
||||
|
||||
|
||||
def alpha_umi_loss_scale(query: str, response: str):
|
||||
cwd = os.getcwd()
|
||||
loss_scale_config_path = 'alpha_umi_loss_scale_config.json'
|
||||
config_path = os.path.join(cwd, loss_scale_config_path)
|
||||
with open(config_path, 'r') as json_file:
|
||||
loss_scale_map = json.load(json_file)
|
||||
return calculate_loss_scale(query, response, loss_scale_map)
|
||||
|
||||
|
||||
def agentflan_loss_scale(query: str, response: str):
|
||||
cwd = os.getcwd()
|
||||
loss_scale_config_path = 'agentflan.json'
|
||||
config_path = os.path.join(cwd, loss_scale_config_path)
|
||||
with open(config_path, 'r') as json_file:
|
||||
loss_scale_map = json.load(json_file)
|
||||
query_loss_scale_map = loss_scale_map['query']
|
||||
response_loss_scale_map = loss_scale_map['response']
|
||||
return calculate_loss_scale(query, response, response_loss_scale_map, query_loss_scale_map)
|
||||
|
||||
|
||||
def react_loss_scale(query: str, response: str):
|
||||
cwd = os.getcwd()
|
||||
loss_scale_config_path = 'default_loss_scale_config.json'
|
||||
config_path = os.path.join(cwd, loss_scale_config_path)
|
||||
with open(config_path, 'r') as json_file:
|
||||
loss_scale_map = json.load(json_file)
|
||||
return calculate_loss_scale(query, response, loss_scale_map)
|
||||
|
||||
|
||||
def default_loss_scale(query: str, response: str):
|
||||
return [response], [1.0]
|
||||
|
||||
|
||||
loss_scale_map = {
|
||||
'agentflan': agentflan_loss_scale,
|
||||
'react': react_loss_scale,
|
||||
'alpha_umi': alpha_umi_loss_scale,
|
||||
'default': default_loss_scale,
|
||||
}
|
||||
2274
modelscope/preprocessors/templates/template.py
Normal file
2274
modelscope/preprocessors/templates/template.py
Normal file
File diff suppressed because it is too large
Load Diff
107
modelscope/preprocessors/templates/tools_prompt.py
Normal file
107
modelscope/preprocessors/templates/tools_prompt.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import List, Dict, Union, Optional
|
||||
|
||||
|
||||
def format_react_en(tool_names, tool_descs):
|
||||
REACT_PROMPT = """Answer the following questions as best as you can. You have access to the following tools:
|
||||
|
||||
{tool_list}
|
||||
|
||||
Use the following format:
|
||||
|
||||
Thought: you should always think about what to do
|
||||
Action: the action to take, should be one of [{tool_names}]
|
||||
Action Input: the input to the action
|
||||
Observation: the result of the action
|
||||
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
||||
Final Answer: the final answer to the original input question
|
||||
|
||||
Begin!
|
||||
"""
|
||||
return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
|
||||
|
||||
|
||||
def format_react_zh(tool_names, tool_descs):
|
||||
REACT_ZH_PROMPT = """尽你所能回答以下问题。你拥有如下工具:
|
||||
|
||||
{tool_list}
|
||||
|
||||
使用以下格式回答:
|
||||
|
||||
Thought: 思考你应该做什么
|
||||
Action: 工具的名称,必须是[{tool_names}]之一
|
||||
Action Input: 工具的输入
|
||||
Observation: 工具返回的结果
|
||||
... (Thought/Action/Action Input/Observation的过程可以重复零次或多次)
|
||||
Final Answer: 对输入问题的最终答案
|
||||
|
||||
开始!
|
||||
"""
|
||||
return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
|
||||
|
||||
|
||||
def format_glm4(tool_names, tool_descs):
|
||||
GLM4_PROMPT = '''你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。
|
||||
|
||||
# 可用工具
|
||||
|
||||
{tool_list}'''
|
||||
tool_list = ''
|
||||
for name, tool in zip(tool_names, tool_descs):
|
||||
tool_list += f'## {name}\n\n{tool}\n\n'
|
||||
return GLM4_PROMPT.format(tool_list=tool_list)
|
||||
|
||||
|
||||
def format_toolbench(tool_names, tool_descs):
|
||||
TOOLBENCH_PROMPT = '''You can use many tools(functions) to do the following task.
|
||||
First I will give you the task description, and your task start.
|
||||
At each step, you need to give your thought to analyze the status now and what to do next, \
|
||||
with a function call to actually excute your step. Your output should follow this format:
|
||||
Thought:
|
||||
Action:
|
||||
Action Input:
|
||||
|
||||
After the call, you will get the call result, and you are now in a new state.
|
||||
Then you will analyze your status now, then decide what to do next...
|
||||
After many (Thought-call) pairs, you finally perform the task, then you can give your finial answer.
|
||||
Remember:
|
||||
1.the state change is irreversible, you can't go back to one of the former state, if you want to restart the task, \
|
||||
say \"I give up and restart\".
|
||||
2.All the thought is short, at most in 5 sentence.
|
||||
3.You can do more then one trys, so if your plan is to continusly try some conditions, \
|
||||
you can do one of the conditions per try.
|
||||
Let's Begin!
|
||||
Task description: You should use functions to help handle the real time user querys. Remember:
|
||||
1.ALWAYS call \"Finish\" function at the end of the task. And the final answer should contain enough information \
|
||||
to show to the user,If you can't handle the task, \
|
||||
or you find that function calls always fail(the function is not valid now), \
|
||||
use function Finish->give_up_and_restart.
|
||||
2.Do not use origin tool names, use only subfunctions' names.
|
||||
Specifically, you have access to the following APIs: {tool_list}'''
|
||||
return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))
|
||||
|
||||
|
||||
tools_prompt = {
|
||||
'react_en': format_react_en,
|
||||
'react_zh': format_react_zh,
|
||||
'glm4': format_glm4,
|
||||
'toolbench': format_toolbench,
|
||||
}
|
||||
|
||||
|
||||
def get_tools_prompt(TOOLS: List[Dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]:
|
||||
tool_descs = []
|
||||
tool_names = []
|
||||
for info in TOOLS: # info: Dict[str, Union[str, dict]]
|
||||
try:
|
||||
if 'function' in info:
|
||||
info = info['function']
|
||||
tool_names.append(info['name'])
|
||||
tool_descs.append(str(info)) # info: dict
|
||||
except KeyError:
|
||||
print('invalid tools format, please check'
|
||||
'https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/Agent-deployment-best-practice.md')
|
||||
return None
|
||||
prompt_format = tools_prompt.get(prompt_format) or format_toolbench
|
||||
return prompt_format(tool_names, tool_descs)
|
||||
|
||||
|
||||
542
modelscope/preprocessors/templates/utils.py
Normal file
542
modelscope/preprocessors/templates/utils.py
Normal file
@@ -0,0 +1,542 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from copy import deepcopy
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, List, TypeVar, Union, Tuple, Set, Dict, Type, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
|
||||
History = List[Union[Tuple[str, str], List[str]]]
|
||||
Prompt = List[Union[str, List[int], List[str]]]
|
||||
StopWords = Prompt
|
||||
Context = Union[str, List[int]]
|
||||
Messages = List[Dict[str, Union[str, List[Dict]]]]
|
||||
|
||||
|
||||
# >>> internvl
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
|
||||
def split_str_parts_by(text: str, delimiters: List[str]):
|
||||
"""Split the text field into parts.
|
||||
|
||||
Args:
|
||||
text: A text to be split.
|
||||
delimiters: The delimiters.
|
||||
|
||||
Returns:
|
||||
The split text in list of dicts.
|
||||
"""
|
||||
assert isinstance(text, str), f'text: {text}'
|
||||
all_start_chars = [d[0] for d in delimiters]
|
||||
all_length = [len(d) for d in delimiters]
|
||||
|
||||
text_list = []
|
||||
last_words = ''
|
||||
|
||||
while len(text) > 0:
|
||||
for char_idx, char in enumerate(text):
|
||||
match_index = [idx for idx, start_char in enumerate(all_start_chars) if start_char == char]
|
||||
is_delimiter = False
|
||||
for index in match_index:
|
||||
if text[char_idx:char_idx + all_length[index]] == delimiters[index]:
|
||||
if text_list:
|
||||
text_list[-1]['content'] = last_words
|
||||
elif last_words:
|
||||
text_list.append({'key': '', 'content': last_words})
|
||||
last_words = ''
|
||||
text_list.append({'key': delimiters[index]})
|
||||
text = text[char_idx + all_length[index]:]
|
||||
is_delimiter = True
|
||||
break
|
||||
if not is_delimiter:
|
||||
last_words += char
|
||||
else:
|
||||
break
|
||||
if last_words == text:
|
||||
text = ''
|
||||
|
||||
if len(text_list):
|
||||
text_list[-1]['content'] = last_words
|
||||
else:
|
||||
text_list.append({'key': '', 'content': last_words})
|
||||
return text_list
|
||||
|
||||
|
||||
def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> None:
|
||||
import re
|
||||
compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
|
||||
for i in range(len(text_list) - 1, -1, -1):
|
||||
item = text_list[i]
|
||||
if item.get('key') == '':
|
||||
res_text = item['content']
|
||||
last_idx = 0
|
||||
segments = []
|
||||
|
||||
for pattern, scale in compiled_patterns:
|
||||
matches = list(re.finditer(pattern, res_text))
|
||||
for match in matches:
|
||||
if match.start() > last_idx:
|
||||
segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
|
||||
segments.append({'key': scale[0], 'content': match.group(0)})
|
||||
last_idx = match.end()
|
||||
|
||||
if last_idx < len(res_text):
|
||||
segments.insert(0, {'key': '', 'content': res_text[last_idx:]})
|
||||
|
||||
if segments:
|
||||
text_list[i:i + 1] = segments
|
||||
|
||||
|
||||
def _decode_prompt(prompt: str, tmp_dir: str = 'tmp') -> str:
|
||||
pattern = r'<(?:img|audio|video)>(.+?)</(?:img|audio|video)>'
|
||||
match_iter = re.finditer(pattern, prompt)
|
||||
new_content = ''
|
||||
idx = 0
|
||||
for m in match_iter:
|
||||
span = m.span(1)
|
||||
img_base64 = m.group(1)
|
||||
img_path = _from_base64(img_base64, tmp_dir)
|
||||
new_content += prompt[idx:span[0]] + img_path
|
||||
idx = span[1]
|
||||
new_content += prompt[idx:]
|
||||
return new_content
|
||||
|
||||
|
||||
def _to_base64(img_path: Union[str, 'PIL.Image.Image', bytes]) -> str:
|
||||
if isinstance(img_path, str) and not os.path.isfile(img_path):
|
||||
# base64
|
||||
return img_path
|
||||
if isinstance(img_path, str):
|
||||
# local_path
|
||||
with open(img_path, 'rb') as f:
|
||||
_bytes = f.read()
|
||||
elif not isinstance(img_path, bytes): # PIL.Image.Image
|
||||
bytes_io = BytesIO()
|
||||
img_path.save(bytes_io, format='png')
|
||||
_bytes = bytes_io.getvalue()
|
||||
else:
|
||||
_bytes = img_path
|
||||
img_base64: str = base64.b64encode(_bytes).decode('utf-8')
|
||||
return img_base64
|
||||
|
||||
|
||||
def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp') -> str:
|
||||
from PIL import Image
|
||||
if not isinstance(img_base64, str): # PIL.Image.Image
|
||||
img_base64 = _to_base64(img_base64)
|
||||
if os.path.isfile(img_base64) or img_base64.startswith('http'):
|
||||
return img_base64
|
||||
sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
|
||||
img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
|
||||
image = Image.open(BytesIO(base64.b64decode(img_base64)))
|
||||
if not os.path.exists(img_path):
|
||||
image.save(img_path)
|
||||
return img_path
|
||||
|
||||
|
||||
def decode_base64(*,
|
||||
messages: Optional[Messages] = None,
|
||||
prompt: Optional[str] = None,
|
||||
images: Optional[List[str]] = None,
|
||||
tmp_dir: str = 'tmp') -> Dict[str, Any]:
|
||||
# base64 -> local_path
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
res = {}
|
||||
if messages is not None:
|
||||
res_messages = []
|
||||
for m in messages:
|
||||
m_new = deepcopy(m)
|
||||
m_new['content'] = _decode_prompt(m_new['content'], tmp_dir)
|
||||
res_messages.append(m_new)
|
||||
res['messages'] = res_messages
|
||||
if prompt is not None:
|
||||
prompt = _decode_prompt(prompt, tmp_dir)
|
||||
res['prompt'] = prompt
|
||||
if images is not None:
|
||||
res_images = []
|
||||
for image in images:
|
||||
image = _from_base64(image, tmp_dir)
|
||||
res_images.append(image)
|
||||
res['images'] = res_images
|
||||
return res
|
||||
|
||||
|
||||
def to_device(inputs: Any, device: torch.device) -> Any:
|
||||
"""Move inputs to a device"""
|
||||
if callable(getattr(inputs, 'to', None)):
|
||||
return inputs.to(device=device)
|
||||
|
||||
if isinstance(inputs, Mapping):
|
||||
res = {}
|
||||
for k, v in inputs.items():
|
||||
res[k] = to_device(v, device)
|
||||
elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
|
||||
res = []
|
||||
for b in inputs:
|
||||
res.append(to_device(b, device))
|
||||
else:
|
||||
res = inputs
|
||||
return res
|
||||
|
||||
|
||||
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
|
||||
# The upper bound satisfying the condition "cond".
|
||||
while lo < hi:
|
||||
mid = (lo + hi + 1) >> 1 # lo + (hi-lo+1)>>1
|
||||
if cond(mid):
|
||||
lo = mid
|
||||
else:
|
||||
hi = mid - 1
|
||||
return lo
|
||||
|
||||
|
||||
def fetch_one(element: Union[Tuple, List, Set, Dict, Any], type: Type = None) -> Any:
|
||||
if isinstance(element, (tuple, set, list)):
|
||||
for ele in element:
|
||||
out = fetch_one(ele)
|
||||
if out and (type is None or isinstance(out, type)):
|
||||
return out
|
||||
elif isinstance(element, dict):
|
||||
return fetch_one(list(element.values()))
|
||||
else:
|
||||
return element
|
||||
|
||||
|
||||
def _build_transform(input_size):
|
||||
import torchvision.transforms as T
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
|
||||
def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = _find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size, ((i //
|
||||
(target_width // image_size)) + 1) * image_size)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
|
||||
# <<< internvl
|
||||
|
||||
|
||||
def rescale_image(img: 'PIL.Image.Image', rescale_image: int = -1) -> 'PIL.Image.Image':
|
||||
import torchvision.transforms as T
|
||||
width = img.width
|
||||
height = img.height
|
||||
if rescale_image <= 0 or width * height <= rescale_image:
|
||||
return img
|
||||
|
||||
ratio = width / height
|
||||
height_scaled = math.pow(rescale_image / ratio, 0.5)
|
||||
width_scaled = height_scaled * ratio
|
||||
return T.Resize((int(width_scaled), int(height_scaled)))(img)
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def load_file(path: Union[str, _T]) -> Union[BytesIO, _T]:
|
||||
res = path
|
||||
if isinstance(path, str):
|
||||
path = path.strip()
|
||||
if path.startswith('http'):
|
||||
request_kwargs = {}
|
||||
timeout = float(os.getenv('TIMEOUT', '60'))
|
||||
if timeout > 0:
|
||||
request_kwargs['timeout'] = timeout
|
||||
content = requests.get(path, **request_kwargs).content
|
||||
res = BytesIO(content)
|
||||
elif os.path.exists(path):
|
||||
with open(path, 'rb') as f:
|
||||
res = BytesIO(f.read())
|
||||
else: # base64_str
|
||||
import binascii
|
||||
try:
|
||||
data = base64.b64decode(path)
|
||||
res = BytesIO(data)
|
||||
except (ValueError, binascii.Error) as error:
|
||||
if len(path) < 200:
|
||||
raise ValueError(f'invalid image: "{path}"')
|
||||
else:
|
||||
raise ValueError(f'invalid image: {error}')
|
||||
return res
|
||||
|
||||
|
||||
def load_file_decorator(func):
|
||||
|
||||
def new_func(path, *args, **kwargs):
|
||||
path = load_file(path)
|
||||
res = func(path, *args, **kwargs)
|
||||
return res
|
||||
|
||||
return new_func
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_image(image: Union['PIL.Image.Image', BytesIO]) -> 'PIL.Image.Image':
|
||||
from PIL import Image
|
||||
if isinstance(image, BytesIO):
|
||||
image = Image.open(image)
|
||||
if image.mode != 'RGB':
|
||||
image = image.convert('RGB')
|
||||
return image
|
||||
|
||||
|
||||
def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
|
||||
load_func: Callable[[Any], _T] = load_image) -> List[_T]:
|
||||
res = []
|
||||
assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
|
||||
for path in path_list:
|
||||
if path is None: # ignore None
|
||||
continue
|
||||
res.append(load_func(path))
|
||||
return res
|
||||
|
||||
|
||||
def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
|
||||
if bound:
|
||||
start, end = bound[0], bound[1]
|
||||
else:
|
||||
start, end = -100000, 100000
|
||||
start_idx = max(first_idx, round(start * fps))
|
||||
end_idx = min(round(end * fps), max_frame)
|
||||
seg_size = float(end_idx - start_idx) / num_segments
|
||||
frame_indices = np.array(
|
||||
[int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
|
||||
return frame_indices
|
||||
|
||||
|
||||
def transform_image(image, input_size=448, max_num=12):
|
||||
transform = _build_transform(input_size=input_size)
|
||||
images = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
vr = VideoReader(video_io, ctx=cpu(0), num_threads=1)
|
||||
max_frame = len(vr) - 1
|
||||
fps = float(vr.get_avg_fps())
|
||||
|
||||
images = []
|
||||
frame_indices = _get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
|
||||
for frame_index in frame_indices:
|
||||
images.append(Image.fromarray(vr[frame_index].asnumpy()).convert('RGB'))
|
||||
return images
|
||||
|
||||
|
||||
def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
|
||||
from PIL import Image, ImageDraw
|
||||
from swift.llm.template.template import Template
|
||||
image = Image.open(img_dir)
|
||||
|
||||
objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
|
||||
Template.normalize_bbox(objects, [image], 'real')
|
||||
bbox = objects[0]['bbox']
|
||||
draw = ImageDraw.Draw(image)
|
||||
draw.rectangle(bbox, outline='red', width=2)
|
||||
image.save(output_file)
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_video_cogvlm2(video_io: BytesIO) -> np.ndarray:
|
||||
from decord import cpu, VideoReader, bridge
|
||||
bridge.set_bridge('torch')
|
||||
clip_end_sec = 60
|
||||
clip_start_sec = 0
|
||||
num_frames = 24
|
||||
decord_vr = VideoReader(video_io, ctx=cpu(0))
|
||||
duration = len(decord_vr) # duration in terms of frames
|
||||
start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
|
||||
end_frame = min(duration, int(clip_end_sec * decord_vr.get_avg_fps())) if \
|
||||
clip_end_sec is not None else duration
|
||||
frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
|
||||
video_data = decord_vr.get_batch(frame_id_list)
|
||||
video_data = video_data.permute(3, 0, 1, 2)
|
||||
return video_data
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_video_llava(video_io: BytesIO) -> np.ndarray:
|
||||
import av
|
||||
container = av.open(video_io)
|
||||
total_frames = container.streams.video[0].frames
|
||||
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
|
||||
frames = []
|
||||
container.seek(0)
|
||||
start_index = indices[0]
|
||||
end_index = indices[-1]
|
||||
for i, frame in enumerate(container.decode(video=0)):
|
||||
if i > end_index:
|
||||
break
|
||||
if i >= start_index and i in indices:
|
||||
frames.append(frame)
|
||||
return np.stack([x.to_ndarray(format='rgb24') for x in frames])
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_video_minicpmv_mplug_owl3(video_io: BytesIO, max_num_frames):
|
||||
from PIL import Image
|
||||
from decord import VideoReader, cpu # pip install decord
|
||||
|
||||
def uniform_sample(_l, _n):
|
||||
gap = len(_l) / _n
|
||||
idxs = [int(i * gap + gap / 2) for i in range(_n)]
|
||||
return [_l[i] for i in idxs]
|
||||
|
||||
vr = VideoReader(video_io, ctx=cpu(0))
|
||||
sample_fps = round(vr.get_avg_fps() / 1) # FPS
|
||||
frame_idx = [i for i in range(0, len(vr), sample_fps)]
|
||||
|
||||
if len(frame_idx) > max_num_frames:
|
||||
frame_idx = uniform_sample(frame_idx, max_num_frames)
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
|
||||
return frames
|
||||
|
||||
|
||||
@load_file_decorator
|
||||
def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
|
||||
import librosa
|
||||
return librosa.load(audio_io, sr=sampling_rate)[0]
|
||||
|
||||
|
||||
def load_video_qwen2(video_path: str):
|
||||
from swift.llm.template.template import get_env_args
|
||||
import torchvision
|
||||
from torchvision import io, transforms
|
||||
from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
|
||||
VIDEO_MIN_PIXELS, VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS, smart_resize,
|
||||
ceil_by_factor, floor_by_factor)
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
if version.parse(torchvision.__version__) >= version.parse('0.19'):
|
||||
video_path = load_file(video_path)
|
||||
video, _, info = io.read_video(
|
||||
video_path,
|
||||
pts_unit='sec',
|
||||
output_format='TCHW',
|
||||
)
|
||||
nframes = get_env_args('nframes', int, None)
|
||||
fps = get_env_args('fps', int, None)
|
||||
size_factor = get_env_args('size_factor', int, FRAME_FACTOR)
|
||||
assert not (fps and nframes), 'Only accept either `fps` or `nframes`'
|
||||
if nframes is not None:
|
||||
nframes = round_by_factor(nframes, size_factor)
|
||||
else:
|
||||
fps = FPS
|
||||
nframes = video.size(0) / info['video_fps'] * fps
|
||||
nframes = round_by_factor(nframes, size_factor)
|
||||
min_frames = get_env_args('min_frames', int, FPS_MIN_FRAMES)
|
||||
max_frames = get_env_args('max_frames', int, FPS_MAX_FRAMES)
|
||||
if nframes < min_frames:
|
||||
nframes = ceil_by_factor(min_frames, size_factor)
|
||||
if nframes > max_frames:
|
||||
nframes = floor_by_factor(max_frames, size_factor)
|
||||
|
||||
if not (size_factor <= nframes and nframes <= video.size(0)):
|
||||
raise ValueError(f'nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.')
|
||||
|
||||
idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
|
||||
height, width = video.shape[2:]
|
||||
video = video[idx]
|
||||
|
||||
min_pixels = get_env_args('min_pixels', int, VIDEO_MIN_PIXELS)
|
||||
total_pixels = get_env_args('total_pixels', int, VIDEO_TOTAL_PIXELS)
|
||||
max_pixels = get_env_args('max_pixels', int, None)
|
||||
if max_pixels is None:
|
||||
max_pixels = VIDEO_MAX_PIXELS
|
||||
max_pixels = max(min(max_pixels, total_pixels / nframes * size_factor), min_pixels * 1.05)
|
||||
# resize
|
||||
resized_height = get_env_args('resized_height', int, None)
|
||||
resized_width = get_env_args('resized_width', int, None)
|
||||
if resized_height and resized_width:
|
||||
resized_height, resized_width = smart_resize(
|
||||
resized_height,
|
||||
resized_width,
|
||||
factor=size_factor,
|
||||
)
|
||||
else:
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=size_factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
|
||||
video = transforms.functional.resize(
|
||||
video,
|
||||
[resized_height, resized_width],
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
antialias=True,
|
||||
).float()
|
||||
return video
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# A test main to draw bbox
|
||||
draw_plot('man.jpg', [354, 462, 580, 738], 'norm_1000', 'man_bbox.jpg')
|
||||
@@ -127,7 +127,9 @@ def _patch_pretrained_class():
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
@@ -143,14 +145,18 @@ def _patch_pretrained_class():
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
|
||||
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
|
||||
ignore_file_pattern = [
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
]
|
||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||
ignore_file_pattern, **kwargs)
|
||||
return ori_get_config_dict(cls, model_dir, **kwargs)
|
||||
@@ -242,11 +248,20 @@ AutoModelForTokenClassification = get_wrapped_class(
|
||||
AutoModelForTokenClassificationHF)
|
||||
|
||||
AutoTokenizer = get_wrapped_class(
|
||||
AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
|
||||
AutoTokenizerHF,
|
||||
ignore_file_pattern=[
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
])
|
||||
AutoConfig = get_wrapped_class(
|
||||
AutoConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
|
||||
AutoConfigHF,
|
||||
ignore_file_pattern=[
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
])
|
||||
GenerationConfig = get_wrapped_class(
|
||||
GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
|
||||
GenerationConfigHF,
|
||||
ignore_file_pattern=[
|
||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt'
|
||||
])
|
||||
GPTQConfig = GPTQConfigHF
|
||||
AwqConfig = AwqConfigHF
|
||||
BitsAndBytesConfig = BitsAndBytesConfigHF
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Make sure to modify __release_datetime__ to release time when making official release.
|
||||
__version__ = '2.0.0'
|
||||
__version__ = '1.19.0'
|
||||
# default release datetime for branches under active development is set
|
||||
# to be a time far-far-away-into-the-future
|
||||
__release_datetime__ = '2099-09-06 00:00:00'
|
||||
__release_datetime__ = '2024-10-08 23:59:59'
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=2.18.0,<3.0.0
|
||||
datasets>=3.0.0
|
||||
einops
|
||||
oss2
|
||||
Pillow
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=2.18.0,<3.0.0
|
||||
datasets>=3.0.0
|
||||
einops
|
||||
oss2
|
||||
Pillow
|
||||
|
||||
@@ -44,6 +44,15 @@ class TestStreamLoad(unittest.TestCase):
|
||||
|
||||
assert sample['question'], f'Failed to load sample from {repo_id}'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_stream_swift_jsonl(self):
|
||||
repo_id: str = 'iic/MSAgent-MultiRole'
|
||||
ds = MsDataset.load(repo_id, split='train', use_streaming=True)
|
||||
sample = next(iter(ds))
|
||||
logger.info(sample)
|
||||
|
||||
assert sample['id'], f'Failed to load sample from {repo_id}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
0
tests/tools/__init__.py
Normal file
0
tests/tools/__init__.py
Normal file
106
tests/tools/test_to_ollama.py
Normal file
106
tests/tools/test_to_ollama.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
from modelscope.preprocessors.templates import TemplateType
|
||||
from modelscope.preprocessors.templates.loader import TemplateLoader
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TestToOllama(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_load_template(self):
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'LLM-Research/Meta-Llama-3-8B-Instruct')
|
||||
self.assertTrue(template.template_type == TemplateType.llama3)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'swift/Meta-Llama-3-70B-Instruct-AWQ')
|
||||
self.assertTrue(template.template_type == TemplateType.llama3)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'deepseek-ai/DeepSeek-V2-Lite-Chat')
|
||||
self.assertTrue(template.template_type == TemplateType.deepseek2)
|
||||
|
||||
template = TemplateLoader.load_by_model_id('deepseek-ai/DeepSeek-V2.5')
|
||||
self.assertTrue(template.template_type == TemplateType.deepseek2_5)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'deepseek-ai/deepseek-coder-1.3b-instruct')
|
||||
self.assertTrue(template.template_type == TemplateType.deepseek_coder)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'OpenBuddy/openbuddy-deepseek-67b-v15.2')
|
||||
self.assertTrue(template is None)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'deepseek-ai/deepseek-llm-67b-chat')
|
||||
self.assertTrue(template.template_type == TemplateType.deepseek)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'deepseek-ai/DeepSeek-Coder-V2-Instruct')
|
||||
self.assertTrue(template.template_type == TemplateType.deepseek2)
|
||||
|
||||
template = TemplateLoader.load_by_model_id('01ai/Yi-1.5-9B-Chat')
|
||||
self.assertTrue(template.template_type == TemplateType.chatml)
|
||||
|
||||
template = TemplateLoader.load_by_model_id('01ai/Yi-Coder-9B-Chat')
|
||||
self.assertTrue(template.template_type == TemplateType.yi_coder)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'LLM-Research/gemma-2-27b-it')
|
||||
self.assertTrue(template.template_type == TemplateType.gemma)
|
||||
|
||||
template = TemplateLoader.load_by_model_id('AI-ModelScope/gemma-2b')
|
||||
self.assertTrue(template is None)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'AI-ModelScope/gemma-2b-instruct')
|
||||
self.assertTrue(template is None)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'AI-ModelScope/gemma2-2b-instruct')
|
||||
self.assertTrue(template.template_type == TemplateType.gemma)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'AI-ModelScope/paligemma-3b-mix-224')
|
||||
self.assertTrue(template is None)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'LLM-Research/Phi-3-vision-128k-instruct')
|
||||
self.assertTrue(template is None)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'LLM-Research/Phi-3-128k-instruct')
|
||||
self.assertTrue(template.template_type == TemplateType.phi3)
|
||||
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
'LLM-Research/Phi-3-128k-instruct-GGUF')
|
||||
self.assertTrue(template.template_type == TemplateType.phi3)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_load_ollama(self):
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'LLM-Research/Meta-Llama-3.1-8B-Instruct-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'QuantFactory/Gemma-2-Ataraxy-9B-Chat-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama('Xorbits/Llama-2-7b-Chat-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'AI-ModelScope/gemma2-2b-instruct-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'LLM-Research/Phi-3-128k-instruct-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(template_name='phi3')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'QuantFactory/Mistral-Nemo-Japanese-Instruct-2408-GGUF')
|
||||
self.assertTrue(ollama is not None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user