mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
fix template and ollama
This commit is contained in:
@@ -4,7 +4,7 @@ from typing import Any, Dict, List
|
||||
|
||||
import requests
|
||||
|
||||
from modelscope import AutoTokenizer, get_logger, snapshot_download
|
||||
from modelscope import AutoTokenizer, get_logger, snapshot_download, AutoConfig
|
||||
from . import TemplateType
|
||||
from .base import Template, get_template
|
||||
|
||||
@@ -262,6 +262,7 @@ class TemplateLoader:
|
||||
"""
|
||||
ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$']
|
||||
tokenizer = kwargs.get('tokenizer')
|
||||
config = kwargs.get('config')
|
||||
for _info in template_info:
|
||||
if re.fullmatch(_info.template_regex, model_id):
|
||||
if _info.template:
|
||||
@@ -273,10 +274,11 @@ class TemplateLoader:
|
||||
ignore_file_pattern=ignore_file_pattern)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_dir, trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
|
||||
except Exception:
|
||||
pass
|
||||
return TemplateLoader.load_by_template_name(
|
||||
_info.template, tokenizer=tokenizer, **kwargs)
|
||||
_info.template, tokenizer=tokenizer, config=config, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load_by_template_name(template_name: str, **kwargs) -> Template:
|
||||
@@ -292,7 +294,9 @@ class TemplateLoader:
|
||||
Returns:
|
||||
The template instance
|
||||
"""
|
||||
return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
|
||||
template = get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
|
||||
template.config = kwargs.get('config')
|
||||
return template
|
||||
|
||||
@staticmethod
|
||||
def replace_and_concat(template: Template, template_list: List,
|
||||
@@ -330,33 +334,41 @@ class TemplateLoader:
|
||||
Returns:
|
||||
The ModelFile content, returns `None` if no template found
|
||||
"""
|
||||
if not model_id and not template_name:
|
||||
raise ValueError(
|
||||
f'Please make sure you model_id: {model_id} '
|
||||
f'and template_name: {template_name} is supported.')
|
||||
logger.info('Exporting to ollama:')
|
||||
if model_id:
|
||||
for _info in template_info:
|
||||
if re.fullmatch(_info.template_regex, model_id):
|
||||
if _info.modelfile_link:
|
||||
if _info.modelfile_link and not kwargs.get('ignore_oss_model_file', False):
|
||||
return TemplateLoader._read_content_from_url(
|
||||
_info.modelfile_link)
|
||||
elif _info.template and not template_name:
|
||||
template_name = _info.template
|
||||
if template_name:
|
||||
template = TemplateLoader.load_by_template_name(
|
||||
template_name, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Please make sure you model_id: {model_id} '
|
||||
f'and template_name: {template_name} is supported.')
|
||||
template = TemplateLoader.load_by_model_id(
|
||||
model_id, **kwargs)
|
||||
|
||||
if template is None:
|
||||
return None
|
||||
|
||||
content = ''
|
||||
content += 'FROM {{gguf_file}}\n'
|
||||
content += (
|
||||
f'TEMPLATE """{{{{ if .System }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
|
||||
f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}'
|
||||
f'{{{{ end }}}}')
|
||||
content += 'FROM {gguf_file}\n'
|
||||
_prefix = TemplateLoader.replace_and_concat(template, template.prefix, "", "")
|
||||
if _prefix:
|
||||
content += (
|
||||
f'TEMPLATE """{{{{ if .System }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
|
||||
f'{{{{ else }}}}{_prefix}'
|
||||
f'{{{{ end }}}}')
|
||||
else:
|
||||
content += (
|
||||
f'TEMPLATE """{{{{ if .System }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
|
||||
f'{{{{ end }}}}')
|
||||
content += (
|
||||
f'{{{{ if .Prompt }}}}'
|
||||
f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
|
||||
@@ -364,7 +376,16 @@ class TemplateLoader:
|
||||
content += '{{ .Response }}'
|
||||
content += TemplateLoader.replace_and_concat(template, template.suffix,
|
||||
'', '') + '"""\n'
|
||||
content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n'
|
||||
all_eos_tokens = {TemplateLoader.replace_and_concat(template, template.suffix, "", "")}
|
||||
if getattr(template, 'tokenizer', None):
|
||||
eos_token = TemplateLoader.replace_and_concat(template, [["eos_token_id"]], "", "")
|
||||
all_eos_tokens.add(eos_token)
|
||||
if getattr(template, 'config', None) and getattr(template.config, 'eos_token_id'):
|
||||
eos_token_id = template.config.eos_token_id
|
||||
eos_token = TemplateLoader.replace_and_concat(template, [[eos_token_id]], "", "")
|
||||
all_eos_tokens.add(eos_token)
|
||||
for eos_token in all_eos_tokens:
|
||||
content += f'PARAMETER stop "{eos_token}"\n'
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -105,6 +105,9 @@ class TestToOllama(unittest.TestCase):
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'AI-ModelScope/llava-llama-3-8b-v1_1-gguf')
|
||||
self.assertTrue(ollama is not None)
|
||||
ollama = TemplateLoader.to_ollama(
|
||||
'01ai/Yi-1.5-9B-Chat', ignore_oss_model_file=True)
|
||||
self.assertTrue(ollama is not None)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user