template and ollama in modelscope (#995)
@@ -1,3 +1,5 @@
exclude: 'modelscope/preprocessors/templates/'

repos:
  - repo: https://github.com/pycqa/flake8.git
    rev: 4.0.0

@@ -1,3 +1,5 @@
exclude: 'modelscope/preprocessors/templates/'

repos:
  - repo: /home/admin/pre-commit/flake8
    rev: 4.0.0
2 modelscope/preprocessors/templates/__init__.py Normal file
@@ -0,0 +1,2 @@
from .base import Template, get_template
from .template import TemplateType
1041 modelscope/preprocessors/templates/base.py Normal file
File diff suppressed because it is too large
371 modelscope/preprocessors/templates/loader.py Normal file
@@ -0,0 +1,371 @@
import re
from dataclasses import dataclass
from typing import Any, Dict, List

import requests

from modelscope import AutoTokenizer, get_logger, snapshot_download
from . import TemplateType
from .base import Template, get_template

logger = get_logger()


@dataclass
class TemplateInfo:

    template: str = None
    template_regex: str = None
    modelfile_link: str = None


def cases(*names):
    ret = []
    for name in names:
        regex = ''
        for letter in name:
            if letter.upper() != letter.lower():
                regex += f'[{letter.upper()}{letter.lower()}]'
            else:
                regex += letter
        ret.append(regex)
    if len(ret) > 1:
        ret = '|'.join(ret)
        ret = '(' + ret + ')'
    else:
        ret = ret[0]
    return ret


chat_suffix = cases('instruct', 'chat', '-rl', '-it')


def no(*names):
    return f'(?!.*{cases(*names)})'


def no_multi_modal():
    return no('audio', 'video', 'vl', 'vision')


template_info = [
    # llama
    TemplateInfo(
        template=TemplateType.llama3,
        template_regex=
        f'.*{cases("llama3", "llama-3")}{no_multi_modal()}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/llama-3.modelfile',
    ),
    TemplateInfo(
        template=TemplateType.llama,
        template_regex=
        f'.*{cases("llama2", "llama-2", "mistral", "codestral", "mixtral")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # qwen
    TemplateInfo(
        template=TemplateType.qwen,
        template_regex=f'.*{cases("qwen")}{no_multi_modal()}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/qwen2.modelfile',
    ),

    # codeqwen1.5
    TemplateInfo(
        template_regex=
        f'.*{cases("codeqwen1.5", "codeqwen-1.5")}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/codeqwen1.5.modelfile',
    ),

    # chatml
    TemplateInfo(
        template=TemplateType.chatml,
        template_regex=
        f'.*{cases("yi")}{no_multi_modal()}{no("coder")}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile',
    ),

    # chatml
    TemplateInfo(
        template=TemplateType.chatml,
        template_regex=f'.*{cases("minicpm")}{no("-v")}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/yi-1.5.modelfile'
    ),

    # chatglm
    TemplateInfo(
        template=TemplateType.chatglm2,
        template_regex=f'.*{cases("chatglm2")}{no_multi_modal()}.*'),
    TemplateInfo(
        template=TemplateType.chatglm3,
        template_regex=f'.*{cases("chatglm3")}{no_multi_modal()}.*'),
    TemplateInfo(
        template=TemplateType.chatglm4,
        template_regex=f'.*{cases("glm4")}{no_multi_modal()}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/glm4.modelfile',
    ),

    # baichuan
    TemplateInfo(
        template=TemplateType.baichuan,
        template_regex=
        f'.*{cases("baichuan")}{no_multi_modal()}.*{chat_suffix}.*'),

    # codegeex
    TemplateInfo(
        template=TemplateType.codegeex4,
        template_regex=f'.*{cases("codegeex4")}{no_multi_modal()}.*'),

    # idefics3
    TemplateInfo(
        template=TemplateType.idefics3,
        template_regex=f'.*{cases("idefics3")}{no_multi_modal()}.*'),

    # mistral-nemo
    TemplateInfo(
        template=TemplateType.mistral_nemo,
        template_regex=f'.*{cases("Mistral-Nemo")}{no_multi_modal()}.*',
        modelfile_link='https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/mistral-nemo.modelfile'),

    # internlm
    TemplateInfo(
        template=TemplateType.internlm,
        template_regex=
        f'.*{cases("internlm")}{no("internlm2", "internlm3")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # internlm2
    TemplateInfo(
        template=TemplateType.internlm2,
        template_regex=
        f'.*{cases("internlm2")}{no_multi_modal()}.*{chat_suffix}.*'),

    # yi-coder
    TemplateInfo(
        template=TemplateType.yi_coder,
        template_regex=f'.*{cases("yi")}.*{cases("coder")}.*{chat_suffix}.*'),

    # yuan
    TemplateInfo(
        template=TemplateType.yuan,
        template_regex=f'.*{cases("Yuan")}{no_multi_modal()}.*'),

    # xverse
    TemplateInfo(
        template=TemplateType.xverse,
        template_regex=f'.*{cases("xverse")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # skywork
    TemplateInfo(
        template=TemplateType.skywork,
        template_regex=
        f'.*{cases("skywork")}{no_multi_modal()}.*{chat_suffix}.*'),

    # bluelm
    TemplateInfo(
        template=TemplateType.bluelm,
        template_regex=f'.*{cases("bluelm")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # zephyr
    TemplateInfo(
        template=TemplateType.zephyr,
        template_regex=f'.*{cases("zephyr")}{no_multi_modal()}.*'),

    # deepseek
    TemplateInfo(
        template=TemplateType.deepseek,
        template_regex=
        f'.*{cases("deepseek")}{no("v2", "v2.5", "coder")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # deepseek2
    TemplateInfo(
        template=TemplateType.deepseek2,
        template_regex=
        f'.*{cases("deepseek")}.*{cases("v2")}{no("v2.5")}{no_multi_modal()}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/deepseek_v2.modelfile',
    ),

    # deepseek_coder
    TemplateInfo(
        template=TemplateType.deepseek_coder,
        template_regex=
        f'.*{cases("deepseek")}{no("v2", "v2.5")}.*{cases("coder")}.*{chat_suffix}.*'
    ),

    # deepseek v2.5
    TemplateInfo(
        template=TemplateType.deepseek2_5,
        template_regex=
        f'.*{cases("deepseek")}.*{cases("v2.5")}{no_multi_modal()}.*'),

    # orion
    TemplateInfo(
        template=TemplateType.orion,
        template_regex=f'.*{cases("orion")}{no_multi_modal()}.*{chat_suffix}.*'
    ),

    # gemma
    TemplateInfo(
        template=TemplateType.gemma,
        template_regex=
        f'{no("pali")}.*{cases("gemma2", "gemma-2")}\\b.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/gemma2.modelfile',
    ),

    # phi3
    TemplateInfo(
        template=TemplateType.phi3,
        template_regex=
        f'.*{cases("phi3", "phi-3")}{no_multi_modal()}.*{chat_suffix}.*',
        modelfile_link=
        'https://modelscope.oss-cn-beijing.aliyuncs.com/llm_template/ollama/phi3.modelfile',
    ),

    # telechat
    TemplateInfo(
        template=TemplateType.telechat,
        template_regex=f'.*{cases("TeleChat")}{no("v2")}.*'),

    # telechat_v2
    TemplateInfo(
        template=TemplateType.telechat_v2,
        template_regex=f'.*{cases("TeleChat")}.*{cases("v2")}.*'),
]

class TemplateLoader:

    @staticmethod
    def load_by_model_id(model_id: str, **kwargs) -> Template:
        """Load a template by model-id

        Args:
            model_id: The model-id used to load the proper template
            kwargs:
                revision: the revision of the model, default is `master`

        Returns:
            The template instance
        """
        ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$']
        tokenizer = kwargs.get('tokenizer')
        for _info in template_info:
            if re.fullmatch(_info.template_regex, model_id):
                if _info.template:
                    if tokenizer is None:
                        try:
                            model_dir = snapshot_download(
                                model_id,
                                revision=kwargs.pop('revision', 'master'),
                                ignore_file_pattern=ignore_file_pattern)
                            tokenizer = AutoTokenizer.from_pretrained(
                                model_dir, trust_remote_code=True)
                        except Exception:
                            pass
                    return TemplateLoader.load_by_template_name(
                        _info.template, tokenizer=tokenizer, **kwargs)

    @staticmethod
    def load_by_template_name(template_name: str, **kwargs) -> Template:
        """Load a template by template name

        Args:
            template_name: The template name used to load the proper template
            kwargs:
                tokenizer: The tokenizer of the model
                default_system: The extra default system info
                max_length: The max_length for the sequence
                truncation_strategy: 'delete' or 'truncation_left' when the sequence length exceeds the limit

        Returns:
            The template instance
        """
        return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)

    @staticmethod
    def replace_and_concat(template: Template, template_list: List,
                           placeholder: str, keyword: str):
        final_str = ''
        for t in template_list:
            if isinstance(t, str):
                final_str += t.replace(placeholder, keyword)
            elif isinstance(t, (tuple, list)):
                if isinstance(t[0], int):
                    final_str += template.tokenizer.decode(t)
                else:
                    for attr in t:
                        if attr == 'bos_token_id':
                            final_str += template.tokenizer.bos_token
                        elif attr == 'eos_token_id':
                            final_str += template.tokenizer.eos_token
                        else:
                            raise ValueError(f'Unknown token: {attr}')
        return final_str

    @staticmethod
    def to_ollama(model_id: str = None,
                  template_name: str = None,
                  gguf_file: str = None,
                  gguf_meta: Dict[str, Any] = None,
                  **kwargs) -> str:
        """Export to ollama ModelFile

        Args:
            model_id: The model-id to use
            template_name: An extra template name to use
            gguf_file: An extra gguf_file path to use in the `FROM` field
            gguf_meta: Extra gguf meta info

        Returns:
            The ModelFile content, returns `None` if no template found
        """
        logger.info('Exporting to ollama:')
        if model_id:
            for _info in template_info:
                if re.fullmatch(_info.template_regex, model_id):
                    if _info.modelfile_link:
                        return TemplateLoader._read_content_from_url(
                            _info.modelfile_link)
                    elif _info.template and not template_name:
                        template_name = _info.template
        if template_name:
            template = TemplateLoader.load_by_template_name(
                template_name, **kwargs)
        else:
            raise ValueError(
                f'Please make sure your model_id: {model_id} '
                f'and template_name: {template_name} are supported.')

        if template is None:
            return None

        content = ''
        content += 'FROM {{gguf_file}}\n'
        content += (
            f'TEMPLATE """{{{{ if .System }}}}'
            f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
            f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}'
            f'{{{{ end }}}}')
        content += (
            f'{{{{ if .Prompt }}}}'
            f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
            f'{{{{ end }}}}')
        content += '{{ .Response }}'
        content += TemplateLoader.replace_and_concat(template, template.suffix,
                                                     '', '') + '"""\n'
        content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n'
        return content

    @staticmethod
    def _read_content_from_url(url):
        response = requests.get(url)
        response.raise_for_status()
        content = response.content
        return content.decode('utf-8')
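Not part of the diff: a minimal usage sketch of the loader above, assuming the model id matches one of the `template_regex` entries and ModelScope is reachable for the tokenizer download.

```python
# Hypothetical usage sketch (not part of this commit).
from modelscope.preprocessors.templates.loader import TemplateLoader

# Resolves to the llama3 template via the regex table above; only tokenizer
# files are downloaded, since weights are excluded by ignore_file_pattern.
template = TemplateLoader.load_by_model_id('LLM-Research/Meta-Llama-3-8B-Instruct')

# Returns the ollama Modelfile text, either fetched from a prebuilt
# modelfile_link or rendered from the template's prefix/prompt/suffix pieces.
modelfile = TemplateLoader.to_ollama('LLM-Research/Meta-Llama-3.1-8B-Instruct-GGUF')
print(modelfile)
```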
101 modelscope/preprocessors/templates/loss_scale.py Normal file
@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
from typing import Dict, List, Optional, Tuple

from .utils import split_str_parts_by, split_parts_by_regex


def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Optional[Dict[str, list]] = None,
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```
    Returns:
        A tuple of agent response parts and their weights.
    """
    # query loss scale map
    if query_loss_scale_map is not None:
        for key in query_loss_scale_map.keys():
            if key in query:
                if isinstance(query_loss_scale_map[key], (float, int)):
                    query_loss_scale_map[key] = [query_loss_scale_map[key]]
                loss_scale_value = query_loss_scale_map[key][0]
                return [response], [float(loss_scale_value)]
    delimiters = list(k for k in response_loss_scale_map.keys() if len(response_loss_scale_map[k]) == 2)
    agent_parts = split_str_parts_by(response, delimiters)
    regex_delimiters = {k: v for k, v in response_loss_scale_map.items() if len(v) == 1}
    if len(regex_delimiters):
        split_parts_by_regex(agent_parts, regex_delimiters)
    weights = []
    agent_content = []
    for c in agent_parts:
        if isinstance(c['key'], (float, int)):
            weights += [c['key']]
            agent_content.append(c['content'])
        else:
            if c['key'] in response_loss_scale_map:
                weights += [response_loss_scale_map[c['key']][0]]
                weights += [response_loss_scale_map[c['key']][1]]
                agent_content.append(c['key'])
                agent_content.append(c['content'])
            else:
                weights += [1.0]
                agent_content.append(c['content'])
    return agent_content, weights


def alpha_umi_loss_scale(query: str, response: str):
    cwd = os.getcwd()
    loss_scale_config_path = 'alpha_umi_loss_scale_config.json'
    config_path = os.path.join(cwd, loss_scale_config_path)
    with open(config_path, 'r') as json_file:
        loss_scale_map = json.load(json_file)
    return calculate_loss_scale(query, response, loss_scale_map)


def agentflan_loss_scale(query: str, response: str):
    cwd = os.getcwd()
    loss_scale_config_path = 'agentflan.json'
    config_path = os.path.join(cwd, loss_scale_config_path)
    with open(config_path, 'r') as json_file:
        loss_scale_map = json.load(json_file)
    query_loss_scale_map = loss_scale_map['query']
    response_loss_scale_map = loss_scale_map['response']
    return calculate_loss_scale(query, response, response_loss_scale_map, query_loss_scale_map)


def react_loss_scale(query: str, response: str):
    cwd = os.getcwd()
    loss_scale_config_path = 'default_loss_scale_config.json'
    config_path = os.path.join(cwd, loss_scale_config_path)
    with open(config_path, 'r') as json_file:
        loss_scale_map = json.load(json_file)
    return calculate_loss_scale(query, response, loss_scale_map)


def default_loss_scale(query: str, response: str):
    return [response], [1.0]


loss_scale_map = {
    'agentflan': agentflan_loss_scale,
    'react': react_loss_scale,
    'alpha_umi': alpha_umi_loss_scale,
    'default': default_loss_scale,
}
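Not part of the diff: a small sketch of how `calculate_loss_scale` splits a ReAct-style response. The loss-scale map below is made up for illustration; the real maps live in the JSON config files referenced above.

```python
# Hypothetical example (not part of this commit).
from modelscope.preprocessors.templates.loss_scale import calculate_loss_scale

response = ('Thought: I should call a tool.\n'
            'Action: fire_recognition\n'
            'Action Input: {"image": "img_0"}\n')

# Two-element lists mean [weight of the delimiter itself, weight of the text
# that follows it]; plain-text spans that match no delimiter get weight 1.0.
response_loss_scale_map = {
    'Action:': [2.0, 2.0],
    'Action Input:': [2.0, 2.0],
}

parts, weights = calculate_loss_scale('Is there a fire in the image?',
                                      response, response_loss_scale_map)
# `parts` and `weights` have the same length and can be fed to the trainer
# to upweight the Action/Action Input segments of the agent response.
```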
2274 modelscope/preprocessors/templates/template.py Normal file
File diff suppressed because it is too large
107 modelscope/preprocessors/templates/tools_prompt.py Normal file
@@ -0,0 +1,107 @@
from typing import List, Dict, Union, Optional


def format_react_en(tool_names, tool_descs):
    REACT_PROMPT = """Answer the following questions as best as you can. You have access to the following tools:

{tool_list}

Use the following format:

Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
Final Answer: the final answer to the original input question

Begin!
"""
    return REACT_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))


def format_react_zh(tool_names, tool_descs):
    REACT_ZH_PROMPT = """尽你所能回答以下问题。你拥有如下工具:

{tool_list}

使用以下格式回答:

Thought: 思考你应该做什么
Action: 工具的名称,必须是[{tool_names}]之一
Action Input: 工具的输入
Observation: 工具返回的结果
... (Thought/Action/Action Input/Observation的过程可以重复零次或多次)
Final Answer: 对输入问题的最终答案

开始!
"""
    return REACT_ZH_PROMPT.format(tool_list='\n\n'.join(tool_descs), tool_names=','.join(tool_names))


def format_glm4(tool_names, tool_descs):
    GLM4_PROMPT = '''你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。

# 可用工具

{tool_list}'''
    tool_list = ''
    for name, tool in zip(tool_names, tool_descs):
        tool_list += f'## {name}\n\n{tool}\n\n'
    return GLM4_PROMPT.format(tool_list=tool_list)


def format_toolbench(tool_names, tool_descs):
    TOOLBENCH_PROMPT = '''You can use many tools(functions) to do the following task.
First I will give you the task description, and your task start.
At each step, you need to give your thought to analyze the status now and what to do next, \
with a function call to actually excute your step. Your output should follow this format:
Thought:
Action:
Action Input:

After the call, you will get the call result, and you are now in a new state.
Then you will analyze your status now, then decide what to do next...
After many (Thought-call) pairs, you finally perform the task, then you can give your finial answer.
Remember:
1.the state change is irreversible, you can't go back to one of the former state, if you want to restart the task, \
say \"I give up and restart\".
2.All the thought is short, at most in 5 sentence.
3.You can do more then one trys, so if your plan is to continusly try some conditions, \
you can do one of the conditions per try.
Let's Begin!
Task description: You should use functions to help handle the real time user querys. Remember:
1.ALWAYS call \"Finish\" function at the end of the task. And the final answer should contain enough information \
to show to the user,If you can't handle the task, \
or you find that function calls always fail(the function is not valid now), \
use function Finish->give_up_and_restart.
2.Do not use origin tool names, use only subfunctions' names.
Specifically, you have access to the following APIs: {tool_list}'''
    return TOOLBENCH_PROMPT.format(tool_list='\n\n'.join(tool_descs))


tools_prompt = {
    'react_en': format_react_en,
    'react_zh': format_react_zh,
    'glm4': format_glm4,
    'toolbench': format_toolbench,
}

def get_tools_prompt(TOOLS: List[Dict[str, Union[str, dict]]], prompt_format: str = 'react_en') -> Optional[str]:
    tool_descs = []
    tool_names = []
    for info in TOOLS:  # info: Dict[str, Union[str, dict]]
        try:
            if 'function' in info:
                info = info['function']
            tool_names.append(info['name'])
            tool_descs.append(str(info))  # info: dict
        except KeyError:
            print('invalid tools format, please check '
                  'https://github.com/modelscope/swift/blob/main/docs/source_en/LLM/Agent-deployment-best-practice.md')
            return None
    prompt_format = tools_prompt.get(prompt_format) or format_toolbench
    return prompt_format(tool_names, tool_descs)
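Not part of the diff: a sketch of `get_tools_prompt` with an OpenAI-style tools list; the `get_weather` tool definition is made up for illustration.

```python
# Hypothetical example (not part of this commit).
from modelscope.preprocessors.templates.tools_prompt import get_tools_prompt

tools = [{
    'type': 'function',
    'function': {
        'name': 'get_weather',  # made-up tool, for illustration only
        'description': 'Query the current weather for a city.',
        'parameters': {
            'type': 'object',
            'properties': {'city': {'type': 'string'}},
            'required': ['city'],
        },
    },
}]

# 'react_en' renders the English ReAct system prompt; 'react_zh', 'glm4' and
# 'toolbench' select the other formats registered in `tools_prompt` above.
system_prompt = get_tools_prompt(tools, prompt_format='react_en')
print(system_prompt)
```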
542 modelscope/preprocessors/templates/utils.py Normal file
@@ -0,0 +1,542 @@
import base64
import hashlib
import math
import os
import re
from collections.abc import Mapping
from copy import deepcopy
from io import BytesIO
from typing import Any, Callable, List, TypeVar, Union, Tuple, Set, Dict, Type, Optional, Sequence

import numpy as np
import requests
import torch
from packaging import version


History = List[Union[Tuple[str, str], List[str]]]
Prompt = List[Union[str, List[int], List[str]]]
StopWords = Prompt
Context = Union[str, List[int]]
Messages = List[Dict[str, Union[str, List[Dict]]]]


# >>> internvl
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def split_str_parts_by(text: str, delimiters: List[str]):
    """Split the text field into parts.

    Args:
        text: A text to be split.
        delimiters: The delimiters.

    Returns:
        The split text in list of dicts.
    """
    assert isinstance(text, str), f'text: {text}'
    all_start_chars = [d[0] for d in delimiters]
    all_length = [len(d) for d in delimiters]

    text_list = []
    last_words = ''

    while len(text) > 0:
        for char_idx, char in enumerate(text):
            match_index = [idx for idx, start_char in enumerate(all_start_chars) if start_char == char]
            is_delimiter = False
            for index in match_index:
                if text[char_idx:char_idx + all_length[index]] == delimiters[index]:
                    if text_list:
                        text_list[-1]['content'] = last_words
                    elif last_words:
                        text_list.append({'key': '', 'content': last_words})
                    last_words = ''
                    text_list.append({'key': delimiters[index]})
                    text = text[char_idx + all_length[index]:]
                    is_delimiter = True
                    break
            if not is_delimiter:
                last_words += char
            else:
                break
        if last_words == text:
            text = ''

    if len(text_list):
        text_list[-1]['content'] = last_words
    else:
        text_list.append({'key': '', 'content': last_words})
    return text_list


def split_parts_by_regex(text_list: list, regex_delimiters: Dict[str, List[float]]) -> None:
    import re
    compiled_patterns = [(re.compile(pattern), scale) for pattern, scale in regex_delimiters.items()]
    for i in range(len(text_list) - 1, -1, -1):
        item = text_list[i]
        if item.get('key') == '':
            res_text = item['content']
            last_idx = 0
            segments = []

            for pattern, scale in compiled_patterns:
                matches = list(re.finditer(pattern, res_text))
                for match in matches:
                    if match.start() > last_idx:
                        segments.append({'key': '', 'content': res_text[last_idx:match.start()]})
                    segments.append({'key': scale[0], 'content': match.group(0)})
                    last_idx = match.end()

            if last_idx < len(res_text):
                segments.insert(0, {'key': '', 'content': res_text[last_idx:]})

            if segments:
                text_list[i:i + 1] = segments


def _decode_prompt(prompt: str, tmp_dir: str = 'tmp') -> str:
    pattern = r'<(?:img|audio|video)>(.+?)</(?:img|audio|video)>'
    match_iter = re.finditer(pattern, prompt)
    new_content = ''
    idx = 0
    for m in match_iter:
        span = m.span(1)
        img_base64 = m.group(1)
        img_path = _from_base64(img_base64, tmp_dir)
        new_content += prompt[idx:span[0]] + img_path
        idx = span[1]
    new_content += prompt[idx:]
    return new_content


def _to_base64(img_path: Union[str, 'PIL.Image.Image', bytes]) -> str:
    if isinstance(img_path, str) and not os.path.isfile(img_path):
        # base64
        return img_path
    if isinstance(img_path, str):
        # local_path
        with open(img_path, 'rb') as f:
            _bytes = f.read()
    elif not isinstance(img_path, bytes):  # PIL.Image.Image
        bytes_io = BytesIO()
        img_path.save(bytes_io, format='png')
        _bytes = bytes_io.getvalue()
    else:
        _bytes = img_path
    img_base64: str = base64.b64encode(_bytes).decode('utf-8')
    return img_base64


def _from_base64(img_base64: Union[str, 'PIL.Image.Image'], tmp_dir: str = 'tmp') -> str:
    from PIL import Image
    if not isinstance(img_base64, str):  # PIL.Image.Image
        img_base64 = _to_base64(img_base64)
    if os.path.isfile(img_base64) or img_base64.startswith('http'):
        return img_base64
    sha256_hash = hashlib.sha256(img_base64.encode('utf-8')).hexdigest()
    img_path = os.path.join(tmp_dir, f'{sha256_hash}.png')
    image = Image.open(BytesIO(base64.b64decode(img_base64)))
    if not os.path.exists(img_path):
        image.save(img_path)
    return img_path


def decode_base64(*,
                  messages: Optional[Messages] = None,
                  prompt: Optional[str] = None,
                  images: Optional[List[str]] = None,
                  tmp_dir: str = 'tmp') -> Dict[str, Any]:
    # base64 -> local_path
    os.makedirs(tmp_dir, exist_ok=True)
    res = {}
    if messages is not None:
        res_messages = []
        for m in messages:
            m_new = deepcopy(m)
            m_new['content'] = _decode_prompt(m_new['content'], tmp_dir)
            res_messages.append(m_new)
        res['messages'] = res_messages
    if prompt is not None:
        prompt = _decode_prompt(prompt, tmp_dir)
        res['prompt'] = prompt
    if images is not None:
        res_images = []
        for image in images:
            image = _from_base64(image, tmp_dir)
            res_images.append(image)
        res['images'] = res_images
    return res


def to_device(inputs: Any, device: torch.device) -> Any:
    """Move inputs to a device"""
    if callable(getattr(inputs, 'to', None)):
        return inputs.to(device=device)

    if isinstance(inputs, Mapping):
        res = {}
        for k, v in inputs.items():
            res[k] = to_device(v, device)
    elif isinstance(inputs, Sequence) and not isinstance(inputs, str):
        res = []
        for b in inputs:
            res.append(to_device(b, device))
    else:
        res = inputs
    return res


def upper_bound(lo: int, hi: int, cond: Callable[[int], bool]) -> int:
    # The upper bound satisfying the condition "cond".
    while lo < hi:
        mid = (lo + hi + 1) >> 1  # lo + (hi-lo+1)>>1
        if cond(mid):
            lo = mid
        else:
            hi = mid - 1
    return lo


def fetch_one(element: Union[Tuple, List, Set, Dict, Any], type: Type = None) -> Any:
    if isinstance(element, (tuple, set, list)):
        for ele in element:
            out = fetch_one(ele)
            if out and (type is None or isinstance(out, type)):
                return out
    elif isinstance(element, dict):
        return fetch_one(list(element.values()))
    else:
        return element


def _build_transform(input_size):
    import torchvision.transforms as T
    from torchvision.transforms.functional import InterpolationMode
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def _find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def _dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1)
                        if i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = _find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = ((i % (target_width // image_size)) * image_size,
               (i // (target_width // image_size)) * image_size,
               ((i % (target_width // image_size)) + 1) * image_size,
               ((i // (target_width // image_size)) + 1) * image_size)
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


# <<< internvl


def rescale_image(img: 'PIL.Image.Image', rescale_image: int = -1) -> 'PIL.Image.Image':
    import torchvision.transforms as T
    width = img.width
    height = img.height
    if rescale_image <= 0 or width * height <= rescale_image:
        return img

    ratio = width / height
    height_scaled = math.pow(rescale_image / ratio, 0.5)
    width_scaled = height_scaled * ratio
    return T.Resize((int(width_scaled), int(height_scaled)))(img)


_T = TypeVar('_T')


def load_file(path: Union[str, _T]) -> Union[BytesIO, _T]:
    res = path
    if isinstance(path, str):
        path = path.strip()
        if path.startswith('http'):
            request_kwargs = {}
            timeout = float(os.getenv('TIMEOUT', '60'))
            if timeout > 0:
                request_kwargs['timeout'] = timeout
            content = requests.get(path, **request_kwargs).content
            res = BytesIO(content)
        elif os.path.exists(path):
            with open(path, 'rb') as f:
                res = BytesIO(f.read())
        else:  # base64_str
            import binascii
            try:
                data = base64.b64decode(path)
                res = BytesIO(data)
            except (ValueError, binascii.Error) as error:
                if len(path) < 200:
                    raise ValueError(f'invalid image: "{path}"')
                else:
                    raise ValueError(f'invalid image: {error}')
    return res


def load_file_decorator(func):

    def new_func(path, *args, **kwargs):
        path = load_file(path)
        res = func(path, *args, **kwargs)
        return res

    return new_func


@load_file_decorator
def load_image(image: Union['PIL.Image.Image', BytesIO]) -> 'PIL.Image.Image':
    from PIL import Image
    if isinstance(image, BytesIO):
        image = Image.open(image)
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return image


def load_batch(path_list: List[Union[str, None, Any, BytesIO]],
               load_func: Callable[[Any], _T] = load_image) -> List[_T]:
    res = []
    assert isinstance(path_list, (list, tuple)), f'path_list: {path_list}'
    for path in path_list:
        if path is None:  # ignore None
            continue
        res.append(load_func(path))
    return res


def _get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
    if bound:
        start, end = bound[0], bound[1]
    else:
        start, end = -100000, 100000
    start_idx = max(first_idx, round(start * fps))
    end_idx = min(round(end * fps), max_frame)
    seg_size = float(end_idx - start_idx) / num_segments
    frame_indices = np.array(
        [int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) for idx in range(num_segments)])
    return frame_indices


def transform_image(image, input_size=448, max_num=12):
    transform = _build_transform(input_size=input_size)
    images = _dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


@load_file_decorator
def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
    from decord import VideoReader, cpu
    from PIL import Image
    vr = VideoReader(video_io, ctx=cpu(0), num_threads=1)
    max_frame = len(vr) - 1
    fps = float(vr.get_avg_fps())

    images = []
    frame_indices = _get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments)
    for frame_index in frame_indices:
        images.append(Image.fromarray(vr[frame_index].asnumpy()).convert('RGB'))
    return images


def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
    from PIL import Image, ImageDraw
    from swift.llm.template.template import Template
    image = Image.open(img_dir)

    objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
    Template.normalize_bbox(objects, [image], 'real')
    bbox = objects[0]['bbox']
    draw = ImageDraw.Draw(image)
    draw.rectangle(bbox, outline='red', width=2)
    image.save(output_file)


@load_file_decorator
def load_video_cogvlm2(video_io: BytesIO) -> np.ndarray:
    from decord import cpu, VideoReader, bridge
    bridge.set_bridge('torch')
    clip_end_sec = 60
    clip_start_sec = 0
    num_frames = 24
    decord_vr = VideoReader(video_io, ctx=cpu(0))
    duration = len(decord_vr)  # duration in terms of frames
    start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
    end_frame = min(duration, int(clip_end_sec * decord_vr.get_avg_fps())) if \
        clip_end_sec is not None else duration
    frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list)
    video_data = video_data.permute(3, 0, 1, 2)
    return video_data


@load_file_decorator
def load_video_llava(video_io: BytesIO) -> np.ndarray:
    import av
    container = av.open(video_io)
    total_frames = container.streams.video[0].frames
    indices = np.arange(0, total_frames, total_frames / 8).astype(int)
    frames = []
    container.seek(0)
    start_index = indices[0]
    end_index = indices[-1]
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            break
        if i >= start_index and i in indices:
            frames.append(frame)
    return np.stack([x.to_ndarray(format='rgb24') for x in frames])


@load_file_decorator
def load_video_minicpmv_mplug_owl3(video_io: BytesIO, max_num_frames):
    from PIL import Image
    from decord import VideoReader, cpu  # pip install decord

    def uniform_sample(_l, _n):
        gap = len(_l) / _n
        idxs = [int(i * gap + gap / 2) for i in range(_n)]
        return [_l[i] for i in idxs]

    vr = VideoReader(video_io, ctx=cpu(0))
    sample_fps = round(vr.get_avg_fps() / 1)  # FPS
    frame_idx = [i for i in range(0, len(vr), sample_fps)]

    if len(frame_idx) > max_num_frames:
        frame_idx = uniform_sample(frame_idx, max_num_frames)
    frames = vr.get_batch(frame_idx).asnumpy()
    frames = [Image.fromarray(v.astype('uint8')) for v in frames]
    return frames


@load_file_decorator
def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
    import librosa
    return librosa.load(audio_io, sr=sampling_rate)[0]


def load_video_qwen2(video_path: str):
    from swift.llm.template.template import get_env_args
    import torchvision
    from torchvision import io, transforms
    from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
                                               VIDEO_MIN_PIXELS, VIDEO_MAX_PIXELS, VIDEO_TOTAL_PIXELS, smart_resize,
                                               ceil_by_factor, floor_by_factor)
    from torchvision.transforms import InterpolationMode

    if version.parse(torchvision.__version__) >= version.parse('0.19'):
        video_path = load_file(video_path)
    video, _, info = io.read_video(
        video_path,
        pts_unit='sec',
        output_format='TCHW',
    )
    nframes = get_env_args('nframes', int, None)
    fps = get_env_args('fps', int, None)
    size_factor = get_env_args('size_factor', int, FRAME_FACTOR)
    assert not (fps and nframes), 'Only accept either `fps` or `nframes`'
    if nframes is not None:
        nframes = round_by_factor(nframes, size_factor)
    else:
        fps = FPS
        nframes = video.size(0) / info['video_fps'] * fps
        nframes = round_by_factor(nframes, size_factor)
        min_frames = get_env_args('min_frames', int, FPS_MIN_FRAMES)
        max_frames = get_env_args('max_frames', int, FPS_MAX_FRAMES)
        if nframes < min_frames:
            nframes = ceil_by_factor(min_frames, size_factor)
        if nframes > max_frames:
            nframes = floor_by_factor(max_frames, size_factor)

    if not (size_factor <= nframes and nframes <= video.size(0)):
        raise ValueError(f'nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.')

    idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
    height, width = video.shape[2:]
    video = video[idx]

    min_pixels = get_env_args('min_pixels', int, VIDEO_MIN_PIXELS)
    total_pixels = get_env_args('total_pixels', int, VIDEO_TOTAL_PIXELS)
    max_pixels = get_env_args('max_pixels', int, None)
    if max_pixels is None:
        max_pixels = VIDEO_MAX_PIXELS
        max_pixels = max(min(max_pixels, total_pixels / nframes * size_factor), min_pixels * 1.05)
    # resize
    resized_height = get_env_args('resized_height', int, None)
    resized_width = get_env_args('resized_width', int, None)
    if resized_height and resized_width:
        resized_height, resized_width = smart_resize(
            resized_height,
            resized_width,
            factor=size_factor,
        )
    else:
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )

    video = transforms.functional.resize(
        video,
        [resized_height, resized_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    ).float()
    return video


if __name__ == '__main__':
    # A test main to draw bbox
    draw_plot('man.jpg', [354, 462, 580, 738], 'norm_1000', 'man_bbox.jpg')
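Not part of the diff: a small sketch of the loading helpers above, assuming Pillow is installed; the file names are placeholders.

```python
# Hypothetical usage sketch (not part of this commit).
import torch
from modelscope.preprocessors.templates.utils import load_image, load_batch, to_device

# load_image is wrapped by load_file_decorator, so it accepts a local path,
# an http(s) URL, or a base64 string and always returns an RGB PIL image.
img = load_image('man.jpg')  # placeholder path, as in the test main above

# load_batch applies the same loader to a list, silently skipping None entries.
images = load_batch(['man.jpg', None], load_func=load_image)

# to_device recursively moves tensors nested inside dicts/lists onto a device.
batch = {'pixel_values': torch.zeros(1, 3, 448, 448)}
batch = to_device(batch, torch.device('cpu'))
```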
0 tests/tools/__init__.py Normal file
106 tests/tools/test_to_ollama.py Normal file
@@ -0,0 +1,106 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest

from modelscope.preprocessors.templates import TemplateType
from modelscope.preprocessors.templates.loader import TemplateLoader
from modelscope.utils.test_utils import test_level


class TestToOllama(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_load_template(self):
        template = TemplateLoader.load_by_model_id(
            'LLM-Research/Meta-Llama-3-8B-Instruct')
        self.assertTrue(template.template_type == TemplateType.llama3)

        template = TemplateLoader.load_by_model_id(
            'swift/Meta-Llama-3-70B-Instruct-AWQ')
        self.assertTrue(template.template_type == TemplateType.llama3)

        template = TemplateLoader.load_by_model_id(
            'deepseek-ai/DeepSeek-V2-Lite-Chat')
        self.assertTrue(template.template_type == TemplateType.deepseek2)

        template = TemplateLoader.load_by_model_id('deepseek-ai/DeepSeek-V2.5')
        self.assertTrue(template.template_type == TemplateType.deepseek2_5)

        template = TemplateLoader.load_by_model_id(
            'deepseek-ai/deepseek-coder-1.3b-instruct')
        self.assertTrue(template.template_type == TemplateType.deepseek_coder)

        template = TemplateLoader.load_by_model_id(
            'OpenBuddy/openbuddy-deepseek-67b-v15.2')
        self.assertTrue(template is None)

        template = TemplateLoader.load_by_model_id(
            'deepseek-ai/deepseek-llm-67b-chat')
        self.assertTrue(template.template_type == TemplateType.deepseek)

        template = TemplateLoader.load_by_model_id(
            'deepseek-ai/DeepSeek-Coder-V2-Instruct')
        self.assertTrue(template.template_type == TemplateType.deepseek2)

        template = TemplateLoader.load_by_model_id('01ai/Yi-1.5-9B-Chat')
        self.assertTrue(template.template_type == TemplateType.chatml)

        template = TemplateLoader.load_by_model_id('01ai/Yi-Coder-9B-Chat')
        self.assertTrue(template.template_type == TemplateType.yi_coder)

        template = TemplateLoader.load_by_model_id(
            'LLM-Research/gemma-2-27b-it')
        self.assertTrue(template.template_type == TemplateType.gemma)

        template = TemplateLoader.load_by_model_id('AI-ModelScope/gemma-2b')
        self.assertTrue(template is None)

        template = TemplateLoader.load_by_model_id(
            'AI-ModelScope/gemma-2b-instruct')
        self.assertTrue(template is None)

        template = TemplateLoader.load_by_model_id(
            'AI-ModelScope/gemma2-2b-instruct')
        self.assertTrue(template.template_type == TemplateType.gemma)

        template = TemplateLoader.load_by_model_id(
            'AI-ModelScope/paligemma-3b-mix-224')
        self.assertTrue(template is None)

        template = TemplateLoader.load_by_model_id(
            'LLM-Research/Phi-3-vision-128k-instruct')
        self.assertTrue(template is None)

        template = TemplateLoader.load_by_model_id(
            'LLM-Research/Phi-3-128k-instruct')
        self.assertTrue(template.template_type == TemplateType.phi3)

        template = TemplateLoader.load_by_model_id(
            'LLM-Research/Phi-3-128k-instruct-GGUF')
        self.assertTrue(template.template_type == TemplateType.phi3)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_load_ollama(self):
        ollama = TemplateLoader.to_ollama(
            'LLM-Research/Meta-Llama-3.1-8B-Instruct-GGUF')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama(
            'QuantFactory/Gemma-2-Ataraxy-9B-Chat-GGUF')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama('Xorbits/Llama-2-7b-Chat-GGUF')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama(
            'AI-ModelScope/gemma2-2b-instruct-GGUF')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama(
            'LLM-Research/Phi-3-128k-instruct-GGUF')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama(template_name='phi3')
        self.assertTrue(ollama is not None)
        ollama = TemplateLoader.to_ollama(
            'QuantFactory/Mistral-Nemo-Japanese-Instruct-2408-GGUF')
        self.assertTrue(ollama is not None)


if __name__ == '__main__':
    unittest.main()