template and ollama in modelscope (#995)

Author: tastelikefeet
Date: 2024-09-29 19:13:50 +08:00
Committed by: GitHub
Parent: 058df0e34c
Commit: 834db59952
11 changed files with 4548 additions and 0 deletions


@@ -1,3 +1,5 @@
exclude: 'modelscope/preprocessors/templates/'
repos:
- repo: https://github.com/pycqa/flake8.git
rev: 4.0.0


@@ -1,3 +1,5 @@
exclude: 'modelscope/preprocessors/templates/'
repos:
- repo: /home/admin/pre-commit/flake8
rev: 4.0.0


@@ -0,0 +1,2 @@
from .base import Template, get_template
from .template import TemplateType

File diff suppressed because it is too large


@@ -0,0 +1,371 @@
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import requests
from modelscope import AutoTokenizer, get_logger, snapshot_download
from . import TemplateType
from .base import Template, get_template
logger = get_logger()
@dataclass
class TemplateInfo:
template: Optional[str] = None
template_regex: Optional[str] = None
modelfile_link: Optional[str] = None
def cases(*names):
ret = []
for name in names:
regex = ''
for letter in name:
if letter.upper() != letter.lower():
regex += f'[{letter.upper()}{letter.lower()}]'
else:
regex += letter
ret.append(regex)
if len(ret) > 1:
ret = '|'.join(ret)
ret = '(' + ret + ')'
else:
ret = ret[0]
return ret
chat_suffix = cases('instruct', 'chat', '-rl', '-it')
def no(*names):
return f'(?!.*{cases(*names)})'
def no_multi_modal():
return no('audio', 'video', 'vl', 'vision')
template_info = [
# llama
TemplateInfo(
template=TemplateType.llama3,
template_regex=
f'.*{cases("llama3", "llama-3")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/llama-3.modelfile',
),
TemplateInfo(
template=TemplateType.llama,
template_regex=
f'.*{cases("llama2", "llama-2", "mistral", "codestral", "mixtral")}{no_multi_modal()}.*{chat_suffix}.*'
),
# qwen
TemplateInfo(
template=TemplateType.qwen,
template_regex=f'.*{cases("qwen")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/qwen2.modelfile',
),
# codeqwen1.5
TemplateInfo(
template_regex=
f'.*{cases("codeqwen1.5", "codeqwen-1.5")}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/codeqwen1.5.modelfile',
),
# chatml
TemplateInfo(
template=TemplateType.chatml,
template_regex=
f'.*{cases("yi")}{no_multi_modal()}{no("coder")}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile',
),
# chatml
TemplateInfo(
template=TemplateType.chatml,
template_regex=f'.*{cases("minicpm")}{no("-v")}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile'
),
# chatglm
TemplateInfo(
template=TemplateType.chatglm2,
template_regex=f'.*{cases("chatglm2")}{no_multi_modal()}.*'),
TemplateInfo(
template=TemplateType.chatglm3,
template_regex=f'.*{cases("chatglm3")}{no_multi_modal()}.*'),
TemplateInfo(
template=TemplateType.chatglm4,
template_regex=f'.*{cases("glm4")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/glm4.modelfile',
),
# baichuan
TemplateInfo(
template=TemplateType.baichuan,
template_regex=
f'.*{cases("baichuan")}{no_multi_modal()}.*{chat_suffix}.*'),
# codegeex
TemplateInfo(
template=TemplateType.codegeex4,
template_regex=f'.*{cases("codegeex4")}{no_multi_modal()}.*'),
# idefics3
TemplateInfo(
template=TemplateType.idefics3,
template_regex=f'.*{cases("idefics3")}{no_multi_modal()}.*'),
# mistral-nemo
TemplateInfo(
template=TemplateType.mistral_nemo,
template_regex=f'.*{cases("Mistral-Nemo")}{no_multi_modal()}.*',
modelfile_link='https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/mistral-nemo.modelfile'),
# internlm
TemplateInfo(
template=TemplateType.internlm,
template_regex=
f'.*{cases("internlm")}{no("internlm2", "internlm3")}{no_multi_modal()}.*{chat_suffix}.*'
),
# internlm2
TemplateInfo(
template=TemplateType.internlm2,
template_regex=
f'.*{cases("internlm2")}{no_multi_modal()}.*{chat_suffix}.*'),
# yi-coder
TemplateInfo(
template=TemplateType.yi_coder,
template_regex=f'.*{cases("yi")}.*{cases("coder")}.*{chat_suffix}.*'),
# yuan
TemplateInfo(
template=TemplateType.yuan,
template_regex=f'.*{cases("Yuan")}{no_multi_modal()}.*'),
# xverse
TemplateInfo(
template=TemplateType.xverse,
template_regex=f'.*{cases("xverse")}{no_multi_modal()}.*{chat_suffix}.*'
),
# skywork
TemplateInfo(
template=TemplateType.skywork,
template_regex=
f'.*{cases("skywork")}{no_multi_modal()}.*{chat_suffix}.*'),
# bluelm
TemplateInfo(
template=TemplateType.bluelm,
template_regex=f'.*{cases("bluelm")}{no_multi_modal()}.*{chat_suffix}.*'
),
# zephyr
TemplateInfo(
template=TemplateType.zephyr,
template_regex=f'.*{cases("zephyr")}{no_multi_modal()}.*'),
# deepseek
TemplateInfo(
template=TemplateType.deepseek,
template_regex=
f'.*{cases("deepseek")}{no("v2", "v2.5", "coder")}{no_multi_modal()}.*{chat_suffix}.*'
),
# deepseek2
TemplateInfo(
template=TemplateType.deepseek2,
template_regex=
f'.*{cases("deepseek")}.*{cases("v2")}{no("v2.5")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepseek_v2.modelfile',
),
# deepseek_coder
TemplateInfo(
template=TemplateType.deepseek_coder,
template_regex=
f'.*{cases("deepseek")}{no("v2", "v2.5")}.*{cases("coder")}.*{chat_suffix}.*'
),
# deepseek v2.5
TemplateInfo(
template=TemplateType.deepseek2_5,
template_regex=
f'.*{cases("deepseek")}.*{cases("v2.5")}{no_multi_modal()}.*'),
# orion
TemplateInfo(
template=TemplateType.orion,
template_regex=f'.*{cases("orion")}{no_multi_modal()}.*{chat_suffix}.*'
),
# gemma
TemplateInfo(
template=TemplateType.gemma,
template_regex=
f'{no("pali")}.*{cases("gemma2", "gemma-2")}\\b.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/gemma2.modelfile',
),
# phi3
TemplateInfo(
template=TemplateType.phi3,
template_regex=
f'.*{cases("phi3", "phi-3")}{no_multi_modal()}.*{chat_suffix}.*',
modelfile_link=
'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3.modelfile',
),
# telechat
TemplateInfo(
template=TemplateType.telechat,
template_regex=f'.*{cases("TeleChat")}{no("v2")}.*'),
# telechat_v2
TemplateInfo(
template=TemplateType.telechat_v2,
template_regex=f'.*{cases("TeleChat")}.*{cases("v2")}.*'),
]
class TemplateLoader:
@staticmethod
def load_by_model_id(model_id: str, **kwargs) -> Optional[Template]:
"""Load a template by model-id
Args:
model_id: The model-id used to load the proper template
kwargs:
revision: the revision of the model, default is `master`
Returns:
The template instance, or None if no matching template is found
"""
ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$']
tokenizer = kwargs.get('tokenizer')
for _info in template_info:
if re.fullmatch(_info.template_regex, model_id):
if _info.template:
if tokenizer is None:
try:
model_dir = snapshot_download(
model_id,
revision=kwargs.pop('revision', 'master'),
ignore_file_pattern=ignore_file_pattern)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
except Exception:
pass
return TemplateLoader.load_by_template_name(
_info.template, tokenizer=tokenizer, **kwargs)
@staticmethod
def load_by_template_name(template_name: str, **kwargs) -> Template:
"""Load a template by model-id
Args:
template_name: The template name used to load the proper template
kwargs:
tokenizer: The tokenizer of the model
default_system: The extra default system info
max_length: The max_length for the sequence
truncation_strategy: 'delete' or 'truncation_left', applied when the sequence length exceeds the limit
Returns:
The template instance
"""
return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
@staticmethod
def replace_and_concat(template: Template, template_list: List,
placeholder: str, keyword: str):
final_str = ''
for t in template_list:
if isinstance(t, str):
final_str += t.replace(placeholder, keyword)
elif isinstance(t, (tuple, list)):
if isinstance(t[0], int):
final_str += template.tokenizer.decode(t)
else:
for attr in t:
if attr == 'bos_token_id':
final_str += template.tokenizer.bos_token
elif attr == 'eos_token_id':
final_str += template.tokenizer.eos_token
else:
raise ValueError(f'Unknown token: {attr}')
return final_str
@staticmethod
def to_ollama(model_id: Optional[str] = None,
template_name: Optional[str] = None,
gguf_file: Optional[str] = None,
gguf_meta: Optional[Dict[str, Any]] = None,
**kwargs) -> Optional[str]:
"""Export to ollama ModelFile
Args:
model_id: The model-id to use
template_name: An extra template name to use
gguf_file: An extra gguf_file path to use in the `FROM` field
gguf_meta: Extra gguf meta info
Returns:
The ModelFile content, or `None` if no template is found
"""
logger.info('Exporting to ollama:')
if model_id:
for _info in template_info:
if re.fullmatch(_info.template_regex, model_id):
if _info.modelfile_link:
return TemplateLoader._read_content_from_url(
_info.modelfile_link)
elif _info.template and not template_name:
template_name = _info.template
if template_name:
template = TemplateLoader.load_by_template_name(
template_name, **kwargs)
else:
raise ValueError(
f'Please make sure your model_id: {model_id} '
f'and template_name: {template_name} are supported.')
if template is None:
return None
content = ''
content += 'FROM {{gguf_file}}\n'
content += (
f'TEMPLATE """{{{{ if .System }}}}'
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}'
f'{{{{ end }}}}')
content += (
f'{{{{ if .Prompt }}}}'
f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
f'{{{{ end }}}}')
content += '{{ .Response }}'
content += TemplateLoader.replace_and_concat(template, template.suffix,
'', '') + '"""\n'
content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n'
return content
@staticmethod
def _read_content_from_url(url):
response = requests.get(url)
response.raise_for_status()
content = response.content
return content.decode('utf-8')
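A rough usage sketch for the loader above (the model ids come from the tests added in this PR; the output path and the follow-up `ollama create` step are assumptions about the intended workflow, not part of this change):

from modelscope.preprocessors.templates.loader import TemplateLoader

# Map a ModelScope model id onto one of the registered chat templates.
template = TemplateLoader.load_by_model_id('LLM-Research/Meta-Llama-3-8B-Instruct')
print(template.template_type)

# Produce the ollama Modelfile text for a GGUF repo.
modelfile = TemplateLoader.to_ollama('LLM-Research/Phi-3-128k-instruct-GGUF')
if modelfile is not None:
    with open('Modelfile', 'w', encoding='utf-8') as f:
        f.write(modelfile)
    # afterwards, outside Python (assumed step): ollama create my-model -f Modelfile

Depending on which entry of template_info matches, to_ollama either downloads a hosted .modelfile or renders one from the template's prefix/prompt/suffix pieces; in the rendered case the FROM line still contains the literal {{gguf_file}} placeholder, which has to be replaced with a local .gguf path before ollama can use it.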


@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
from typing import Dict, List, Optional, Tuple
from .utils import split_str_parts_by, split_parts_by_regex
def calculate_loss_scale(query: str,
response: str,
response_loss_scale_map: Optional[Dict[str, list]] = None,
query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
"""Calculate the loss scale by splitting the agent response.
This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf
Agent response format:
```text
Thought: you should always think about what to do
Action: the action to take, should be one of the above tools[fire_recognition,
fire_alert, call_police, call_fireman]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question
```
Returns:
A tuple of agent response parts and their weights.
"""
# query loss scale map
if query_loss_scale_map is not None:
for key in query_loss_scale_map.keys():
if key in query:
if isinstance(query_loss_scale_map[key], (float, int)):
query_loss_scale_map[key] = [query_loss_scale_map[key]]
loss_scale_value = query_loss_scale_map[key][0]
return [response], [float(loss_scale_value)]
delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
agent_parts = split_str_parts_by(response, delimiters)
regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
if len(regex_delimiters):
split_parts_by_regex(agent_parts, regex_delimiters)
weights = []
agent_content = []
for c in agent_parts:
if isinstance(c['key'], (float, int)):
weights += [c['key']]
agent_content.append(c['content'])
else:
if c['key'] in response_loss_scale_map:
weights += [response_loss_scale_map[c['key']][0]]
weights += [response_loss_scale_map[c['key']][1]]
agent_content.append(c['key'])
agent_content.append(c['content'])
else:
weights += [1.0]
agent_content.append(c['content'])
return agent_content, weights
def alpha_umi_loss_scale(query: str, response: str):
cwd = os.getcwd()
loss_scale_config_path = 'alpha_umi_loss_scale_config.json'
config_path = os.path.join(cwd, loss_scale_config_path)
with open(config_path, 'r') as json_file:
loss_scale_map = json.load(json_file)
return calculate_loss_scale(query, response, loss_scale_map)
def agentflan_loss_scale(query: str, response: str):
cwd = os.getcwd()
loss_scale_config_path = 'agentflan.json'
config_path = os.path.join(cwd, loss_scale_config_path)
with open(config_path, 'r') as json_file:
loss_scale_map = json.load(json_file)
query_loss_scale_map = loss_scale_map['query']
response_loss_scale_map = loss_scale_map['response']
return calculate_loss_scale(query, response, response_loss_scale_map, query_loss_scale_map)
def react_loss_scale(query: str, response: str):
cwd = os.getcwd()
loss_scale_config_path = 'default_loss_scale_config.json'
config_path = os.path.join(cwd, loss_scale_config_path)
with open(config_path, 'r') as json_file:
loss_scale_map = json.load(json_file)
return calculate_loss_scale(query, response, loss_scale_map)
def default_loss_scale(query: str, response: str):
return [response], [1.0]
loss_scale_map = {
'agentflan': agentflan_loss_scale,
'react': react_loss_scale,
'alpha_umi': alpha_umi_loss_scale,
'default': default_loss_scale,
}
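For orientation, a hand-rolled sketch of how calculate_loss_scale weights the pieces of a ReAct-style response (the loss-scale map and the response below are invented for illustration; the real maps live in the JSON config files read by the functions above, and calculate_loss_scale is assumed to be in scope):

# Delimiters whose value has two entries weight both the keyword and the text after it.
response_loss_scale_map = {
    'Thought:': [1.0, 1.0],
    'Action:': [2.0, 2.0],
    'Action Input:': [2.0, 2.0],
    'Observation:': [2.0, 0.0],
}
response = ('Thought: I need the weather tool.\n'
            'Action: get_weather\n'
            'Action Input: {"city": "Beijing"}\n')
parts, weights = calculate_loss_scale('What is the weather in Beijing?', response,
                                      response_loss_scale_map)
# parts holds the keywords and the text segments between them, and weights holds
# one loss scale per part, e.g. 2.0 for the Action / Action Input segments.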

File diff suppressed because it is too large


@@ -0,0 +1,107 @@
from typing import List, Dict, Union, Optional
def format_react_en(tool_names, tool_descs):
REACT_PROMPT = """Answer the following questions as best as you can. You have access to the following tools:
{tool_list}
Use the following format:
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Final Answer: the final answer to the original input question
Begin!
"""
return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
def format_react_zh(tool_names, tool_descs):
REACT_ZH_PROMPT = """尽你所能回答以下问题。你拥有如下工具:
{tool_list}
使用以下格式回答:
Thought: 思考你应该做什么
Action: 工具的名称,必须是[{tool_names}]之一
Action Input: 工具的输入
Observation: 工具返回的结果
... (Thought/Action/Action Input/Observation的过程可以重复零次或多次)
Final Answer: 对输入问题的最终答案
开始!
"""
return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))
def format_glm4(tool_names, tool_descs):
GLM4_PROMPT = '''你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。
# 可用工具
{tool_list}'''
tool_list = ''
for name, tool in zip(tool_names, tool_descs):
tool_list += f'## {name}\n\n{tool}\n\n'
return GLM4_PROMPT.format(tool_list=tool_list)
def format_toolbench(tool_names, tool_descs):
TOOLBENCH_PROMPT = '''You can use many tools(functions) to do the following task.
First I will give you the task description, and your task start.
At each step, you need to give your thought to analyze the status now and what to do next, \
with a function call to actually excute your step. Your output should follow this format:
Thought:
Action:
Action Input:
After the call, you will get the call result, and you are now in a new state.
Then you will analyze your status now, then decide what to do next...
After many (Thought-call) pairs, you finally perform the task, then you can give your finial answer.
Remember:
1.the state change is irreversible, you can't go back to one of the former state, if you want to restart the task, \
say \"I give up and restart\".
2.All the thought is short, at most in 5 sentence.
3.You can do more then one trys, so if your plan is to continusly try some conditions, \
you can do one of the conditions per try.
Let's Begin!
Task description: You should use functions to help handle the real time user querys. Remember:
1.ALWAYS call \"Finish\" function at the end of the task. And the final answer should contain enough information \
to show to the user,If you can't handle the task, \
or you find that function calls always fail(the function is not valid now), \
use function Finish->give_up_and_restart.
2.Do not use origin tool names, use only subfunctions' names.
Specifically, you have access to the following APIs: {tool_list}'''
return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))
tools_prompt = {
'react_en': format_react_en,
'react_zh': format_react_zh,
'glm4': format_glm4,
'toolbench': format_toolbench,
}
def get_tools_prompt(TOOLS: List[Dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]:
tool_descs = []
tool_names = []
for info in TOOLS: # info: Dict[str, Union[str, dict]]
try:
if 'function' in info:
info = info['function']
tool_names.append(info['name'])
tool_descs.append(str(info)) # info: dict
except KeyError:
print('invalid tools format, please check '
'https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/Agent-deployment-best-practice.md')
return None
prompt_format = tools_prompt.get(prompt_format) or format_toolbench
return prompt_format(tool_names, tool_descs)
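A quick sketch of feeding an OpenAI-style tool definition through get_tools_prompt (the weather tool below is invented for illustration):

tools = [{
    'type': 'function',
    'function': {
        'name': 'get_weather',
        'description': 'Query the current weather for a city.',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    },
}]

# Fills the ReAct English system prompt with the tool descriptions and names;
# unrecognized prompt_format values fall back to the ToolBench prompt.
system_prompt = get_tools_prompt(tools, prompt_format='react_en')
print(system_prompt)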


@@ -0,0 +1,542 @@
import base64
import hashlib
import math
import os
import re
from collections.abc import Mapping
from copy import deepcopy
from io import BytesIO
from typing import Any, Callable, List, TypeVar, Union, Tuple, Set, Dict, Type, Optional, Sequence
import numpy as np
import requests
import torch
from packaging import version
History = List[Union[Tuple[str, str], List[str]]]
Prompt = List[Union[str, List[int], List[str]]]
StopWords = Prompt
Context = Union[str, List[int]]
Messages = List[Dict[str, Union[str, List[Dict]]]]
# >>> internvl
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def split_str_parts_by(text: str, delimiters: List[str]):
"""Split the text field into parts.
Args:
text: A text to be split.
delimiters: The delimiters.
Returns:
The split text as a list of dicts.
"""
assert isinstance(text, str), f'text: {text}'
all_start_chars = [d[0] for d in delimiters]
all_length = [len(d) for d in delimiters]
text_list = []
last_words = ''
while len(text) > 0:
for char_idx, char in enumerate(text):
match_index = [idx for idx, start_char in enumerate(all_start_chars) if start_char == char]
is_delimiter = False
for index in match_index:
if text[char_idx:char_idx + all_length[index]] == delimiters[index]:
if text_list:
text_list[-1]['content'] = last_words
elif last_words:
text_list.append({'key': '', 'content': last_words})
last_words = ''
text_list.append({'key': delimiters[index]})
text = text[char_idx + all_length[index]:]
is_delimiter = True
break
if not is_delimiter:
last_words += char
else:
break
if last_words == text:
text = ''
if len(text_list):
text_list[-1]['content'] = last_words
else:
text_list.append({'key': '', 'content': last_words})
return text_list
def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> None:
import re
compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
for i in range(len(text_list) - 1, -1, -1):
item = text_list[i]
if item.get('key') == '':
res_text = item['content']
last_idx = 0
segments = []
for pattern, scale in compiled_patterns:
matches = list(re.finditer(pattern, res_text))
for match in matches:
if match.start() > last_idx:
segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
segments.append({'key': scale[0], 'content': match.group(0)})
last_idx = match.end()
if last_idx < len(res_text):
segments.insert(0, {'key': '', 'content': res_text[last_idx:]})
if segments:
text_list[i:i + 1] = segments
def _decode_prompt(prompt: str, tmp_dir: str = 'tmp') -> str:
pattern = r'<(?:img|audio|video)>(.+?)</(?:img|audio|video)>'
match_iter = re.finditer(pattern, prompt)
new_content = ''
idx = 0
for m in match_iter:
span = m.span(1)
img_base64 = m.group(1)
img_path = _from_base64(img_base64, tmp_dir)
new_content += prompt[idx:span[0]] + img_path
idx = span[1]
new_content += prompt[idx:]
return new_content
def _to_base64(img_path: Union[str, 'PIL.Image.Image', bytes]) -> str:
if isinstance(img_path, str) and not os.path.isfile(img_path):
# base64
return img_path
if isinstance(img_path, str):
# local_path
with open(img_path, 'rb') as f:
_bytes = f.read()
elif not isinstance(img_path, bytes): # PIL.Image.Image
bytes_io = BytesIO()
img_path.save(bytes_io, format='png')
_bytes = bytes_io.getvalue()
else:
_bytes = img_path
img_base64: str = base64.b64encode(_bytes).decode('utf-8')
return img_base64
def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp') -> str:
from PIL import Image
if not isinstance(img_base64, str): # PIL.Image.Image
img_base64 = _to_base64(img_base64)
if os.path.isfile(img_base64) or img_base64.startswith('http'):
return img_base64
sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
image = Image.open(BytesIO(base64.b64decode(img_base64)))
if not os.path.exists(img_path):
image.save(img_path)
return img_path
def decode_base64(*,
messages: Optional[Messages] = None,
prompt: Optional[str] = None,
images: Optional[List[str]] = None,
tmp_dir: str = 'tmp') -> Dict[str, Any]:
# base64 -> local_path
os.makedirs(tmp_dir, exist_ok=True)
res = {}
if messages is not None:
res_messages = []
for m in messages:
m_new = deepcopy(m)
m_new['content'] = _decode_prompt(m_new['content'], tmp_dir)
res_messages.append(m_new)
res['messages'] = res_messages
if prompt is not None:
prompt = _decode_prompt(prompt, tmp_dir)
res['prompt'] = prompt
if images is not None:
res_images = []
for image in images:
image = _from_base64(image, tmp_dir)
res_images.append(image)
res['images'] = res_images
return res
def to_device(inputs: Any, device: torch.device) -> Any:
"""Move inputs to a device"""
if callable(getattr(inputs, 'to', None)):
return inputs.to(device=device)
if isinstance(inputs, Mapping):
res = {}
for k, v in inputs.items():
res[k] = to_device(v, device)
elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
res = []
for b in inputs:
res.append(to_device(b, device))
else:
res = inputs
return res
def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
# The upper bound satisfying the condition "cond".
while lo < hi:
mid = (lo + hi + 1) >> 1 # lo + (hi-lo+1)>>1
if cond(mid):
lo = mid
else:
hi = mid - 1
return lo
def fetch_one(element: Union[Tuple, List, Set, Dict, Any], type: Type = None) -> Any:
if isinstance(element, (tuple, set, list)):
for ele in element:
out = fetch_one(ele)
if out and (type is None or isinstance(out, type)):
return out
elif isinstance(element, dict):
return fetch_one(list(element.values()))
else:
return element
def _build_transform(input_size):
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = _find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = ((i % (target_width // image_size)) * image_size, (i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size, ((i //
(target_width // image_size)) + 1) * image_size)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
# <<< internvl
def rescale_image(img: 'PIL.Image.Image', rescale_image: int = -1) -> 'PIL.Image.Image':
import torchvision.transforms as T
width = img.width
height = img.height
if rescale_image <= 0 or width * height <= rescale_image:
return img
ratio = width / height
height_scaled = math.pow(rescale_image / ratio, 0.5)
width_scaled = height_scaled * ratio
return T.Resize((int(width_scaled), int(height_scaled)))(img)
_T = TypeVar('_T')
def load_file(path: Union[str, _T]) -> Union[BytesIO, _T]:
res = path
if isinstance(path, str):
path = path.strip()
if path.startswith('http'):
request_kwargs = {}
timeout = float(os.getenv('TIMEOUT', '60'))
if timeout > 0:
request_kwargs['timeout'] = timeout
content = requests.get(path, **request_kwargs).content
res = BytesIO(content)
elif os.path.exists(path):
with open(path, 'rb') as f:
res = BytesIO(f.read())
else: # base64_str
import binascii
try:
data = base64.b64decode(path)
res = BytesIO(data)
except (ValueError, binascii.Error) as error:
if len(path) < 200:
raise ValueError(f'invalid image: "{path}"')
else:
raise ValueError(f'invalid image: {error}')
return res
def load_file_decorator(func):
def new_func(path, *args, **kwargs):
path = load_file(path)
res = func(path, *args, **kwargs)
return res
return new_func
@load_file_decorator
def load_image(image: Union['PIL.Image.Image', BytesIO]) -> 'PIL.Image.Image':
from PIL import Image
if isinstance(image, BytesIO):
image = Image.open(image)
if image.mode != 'RGB':
image = image.convert('RGB')
return image
def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
load_func: Callable[[Any], _T] = load_image) -> List[_T]:
res = []
assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
for path in path_list:
if path is None: # ignore None
continue
res.append(load_func(path))
return res
def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
if bound:
start, end = bound[0], bound[1]
else:
start, end = -100000, 100000
start_idx = max(first_idx, round(start * fps))
end_idx = min(round(end * fps), max_frame)
seg_size = float(end_idx - start_idx) / num_segments
frame_indices = np.array(
[int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
return frame_indices
def transform_image(image, input_size=448, max_num=12):
transform = _build_transform(input_size=input_size)
images = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
@load_file_decorator
def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
from decord import VideoReader, cpu
from PIL import Image
vr = VideoReader(video_io, ctx=cpu(0), num_threads=1)
max_frame = len(vr) - 1
fps = float(vr.get_avg_fps())
images = []
frame_indices = _get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
for frame_index in frame_indices:
images.append(Image.fromarray(vr[frame_index].asnumpy()).convert('RGB'))
return images
def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
from PIL import Image, ImageDraw
from swift.llm.template.template import Template
image = Image.open(img_dir)
objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
Template.normalize_bbox(objects, [image], 'real')
bbox = objects[0]['bbox']
draw = ImageDraw.Draw(image)
draw.rectangle(bbox, outline='red', width=2)
image.save(output_file)
@load_file_decorator
def load_video_cogvlm2(video_io: BytesIO) -> np.ndarray:
from decord import cpu, VideoReader, bridge
bridge.set_bridge('torch')
clip_end_sec = 60
clip_start_sec = 0
num_frames = 24
decord_vr = VideoReader(video_io, ctx=cpu(0))
duration = len(decord_vr) # duration in terms of frames
start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
end_frame = min(duration, int(clip_end_sec * decord_vr.get_avg_fps())) if \
clip_end_sec is not None else duration
frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
video_data = decord_vr.get_batch(frame_id_list)
video_data = video_data.permute(3, 0, 1, 2)
return video_data
@load_file_decorator
def load_video_llava(video_io: BytesIO) -> np.ndarray:
import av
container = av.open(video_io)
total_frames = container.streams.video[0].frames
indices = np.arange(0, total_frames, total_frames / 8).astype(int)
frames = []
container.seek(0)
start_index = indices[0]
end_index = indices[-1]
for i, frame in enumerate(container.decode(video=0)):
if i > end_index:
break
if i >= start_index and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format='rgb24') for x in frames])
@load_file_decorator
def load_video_minicpmv_mplug_owl3(video_io: BytesIO, max_num_frames):
from PIL import Image
from decord import VideoReader, cpu # pip install decord
def uniform_sample(_l, _n):
gap = len(_l) / _n
idxs = [int(i * gap + gap / 2) for i in range(_n)]
return [_l[i] for i in idxs]
vr = VideoReader(video_io, ctx=cpu(0))
sample_fps = round(vr.get_avg_fps() / 1) # FPS
frame_idx = [i for i in range(0, len(vr), sample_fps)]
if len(frame_idx) > max_num_frames:
frame_idx = uniform_sample(frame_idx, max_num_frames)
frames = vr.get_batch(frame_idx).asnumpy()
frames = [Image.fromarray(v.astype('uint8')) for v in frames]
return frames
@load_file_decorator
def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
import librosa
return librosa.load(audio_io, sr=sampling_rate)[0]
def load_video_qwen2(video_path: str):
from swift.llm.template.template import get_env_args
import torchvision
from torchvision import io, transforms
from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
VIDEO_MIN_PIXELS, VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS, smart_resize,
ceil_by_factor, floor_by_factor)
from torchvision.transforms import InterpolationMode
if version.parse(torchvision.__version__) >= version.parse('0.19'):
video_path = load_file(video_path)
video, _, info = io.read_video(
video_path,
pts_unit='sec',
output_format='TCHW',
)
nframes = get_env_args('nframes', int, None)
fps = get_env_args('fps', int, None)
size_factor = get_env_args('size_factor', int, FRAME_FACTOR)
assert not (fps and nframes), 'Only accept either `fps` or `nframes`'
if nframes is not None:
nframes = round_by_factor(nframes, size_factor)
else:
fps = FPS
nframes = video.size(0) / info['video_fps'] * fps
nframes = round_by_factor(nframes, size_factor)
min_frames = get_env_args('min_frames', int, FPS_MIN_FRAMES)
max_frames = get_env_args('max_frames', int, FPS_MAX_FRAMES)
if nframes < min_frames:
nframes = ceil_by_factor(min_frames, size_factor)
if nframes > max_frames:
nframes = floor_by_factor(max_frames, size_factor)
if not (size_factor <= nframes and nframes <= video.size(0)):
raise ValueError(f'nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.')
idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
height, width = video.shape[2:]
video = video[idx]
min_pixels = get_env_args('min_pixels', int, VIDEO_MIN_PIXELS)
total_pixels = get_env_args('total_pixels', int, VIDEO_TOTAL_PIXELS)
max_pixels = get_env_args('max_pixels', int, None)
if max_pixels is None:
max_pixels = VIDEO_MAX_PIXELS
max_pixels = max(min(max_pixels, total_pixels / nframes * size_factor), min_pixels * 1.05)
# resize
resized_height = get_env_args('resized_height', int, None)
resized_width = get_env_args('resized_width', int, None)
if resized_height and resized_width:
resized_height, resized_width = smart_resize(
resized_height,
resized_width,
factor=size_factor,
)
else:
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
video = transforms.functional.resize(
video,
[resized_height, resized_width],
interpolation=InterpolationMode.BICUBIC,
antialias=True,
).float()
return video
if __name__ == '__main__':
# A test main to draw bbox
draw_plot('man.jpg', [354, 462, 580, 738], 'norm_1000', 'man_bbox.jpg')
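As a pointer to how the splitting helper above behaves (a minimal sketch; the delimiters are the ReAct keywords used by the loss-scale code in this PR, and split_str_parts_by is assumed to be in scope):

text = ('Thought: check the tool list\n'
        'Action: get_weather\n'
        'Action Input: {"city": "Beijing"}')
parts = split_str_parts_by(text, ['Thought:', 'Action:', 'Action Input:'])
# Each dict carries the matched delimiter under 'key' and the trailing text
# under 'content', e.g. {'key': 'Action:', 'content': ' get_weather\n'}.
for part in parts:
    print(part)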

tests/tools/__init__.py (new file, no content shown)


@@ -0,0 +1,106 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.preprocessors.templates import TemplateType
from modelscope.preprocessors.templates.loader import TemplateLoader
from modelscope.utils.test_utils import test_level
class TestToOllama(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_load_template(self):
template = TemplateLoader.load_by_model_id(
'LLM-Research/Meta-Llama-3-8B-Instruct')
self.assertTrue(template.template_type == TemplateType.llama3)
template = TemplateLoader.load_by_model_id(
'swift/Meta-Llama-3-70B-Instruct-AWQ')
self.assertTrue(template.template_type == TemplateType.llama3)
template = TemplateLoader.load_by_model_id(
'deepseek-ai/DeepSeek-V2-Lite-Chat')
self.assertTrue(template.template_type == TemplateType.deepseek2)
template = TemplateLoader.load_by_model_id('deepseek-ai/DeepSeek-V2.5')
self.assertTrue(template.template_type == TemplateType.deepseek2_5)
template = TemplateLoader.load_by_model_id(
'deepseek-ai/deepseek-coder-1.3b-instruct')
self.assertTrue(template.template_type == TemplateType.deepseek_coder)
template = TemplateLoader.load_by_model_id(
'OpenBuddy/openbuddy-deepseek-67b-v15.2')
self.assertTrue(template is None)
template = TemplateLoader.load_by_model_id(
'deepseek-ai/deepseek-llm-67b-chat')
self.assertTrue(template.template_type == TemplateType.deepseek)
template = TemplateLoader.load_by_model_id(
'deepseek-ai/DeepSeek-Coder-V2-Instruct')
self.assertTrue(template.template_type == TemplateType.deepseek2)
template = TemplateLoader.load_by_model_id('01ai/Yi-1.5-9B-Chat')
self.assertTrue(template.template_type == TemplateType.chatml)
template = TemplateLoader.load_by_model_id('01ai/Yi-Coder-9B-Chat')
self.assertTrue(template.template_type == TemplateType.yi_coder)
template = TemplateLoader.load_by_model_id(
'LLM-Research/gemma-2-27b-it')
self.assertTrue(template.template_type == TemplateType.gemma)
template = TemplateLoader.load_by_model_id('AI-ModelScope/gemma-2b')
self.assertTrue(template is None)
template = TemplateLoader.load_by_model_id(
'AI-ModelScope/gemma-2b-instruct')
self.assertTrue(template is None)
template = TemplateLoader.load_by_model_id(
'AI-ModelScope/gemma2-2b-instruct')
self.assertTrue(template.template_type == TemplateType.gemma)
template = TemplateLoader.load_by_model_id(
'AI-ModelScope/paligemma-3b-mix-224')
self.assertTrue(template is None)
template = TemplateLoader.load_by_model_id(
'LLM-Research/Phi-3-vision-128k-instruct')
self.assertTrue(template is None)
template = TemplateLoader.load_by_model_id(
'LLM-Research/Phi-3-128k-instruct')
self.assertTrue(template.template_type == TemplateType.phi3)
template = TemplateLoader.load_by_model_id(
'LLM-Research/Phi-3-128k-instruct-GGUF')
self.assertTrue(template.template_type == TemplateType.phi3)
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_load_ollama(self):
ollama = TemplateLoader.to_ollama(
'LLM-Research/Meta-Llama-3.1-8B-Instruct-GGUF')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(
'QuantFactory/Gemma-2-Ataraxy-9B-Chat-GGUF')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama('Xorbits/Llama-2-7b-Chat-GGUF')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(
'AI-ModelScope/gemma2-2b-instruct-GGUF')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(
'LLM-Research/Phi-3-128k-instruct-GGUF')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(template_name='phi3')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(
'QuantFactory/Mistral-Nemo-Japanese-Instruct-2408-GGUF')
self.assertTrue(ollama is not None)
if __name__ == '__main__':
unittest.main()