fix template and ollama

This commit is contained in:
yuze.zyz
2024-10-10 18:26:31 +08:00
parent ba7a783f23
commit 2c0a455023
2 changed files with 40 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Dict, List
import requests
from modelscope import AutoTokenizer, get_logger, snapshot_download
from modelscope import AutoTokenizer, get_logger, snapshot_download, AutoConfig
from . import TemplateType
from .base import Template, get_template
@@ -262,6 +262,7 @@ class TemplateLoader:
"""
ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$']
tokenizer = kwargs.get('tokenizer')
config = kwargs.get('config')
for _info in template_info:
if re.fullmatch(_info.template_regex, model_id):
if _info.template:
@@ -273,10 +274,11 @@ class TemplateLoader:
ignore_file_pattern=ignore_file_pattern)
tokenizer = AutoTokenizer.from_pretrained(
model_dir, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
except Exception:
pass
return TemplateLoader.load_by_template_name(
_info.template, tokenizer=tokenizer, **kwargs)
_info.template, tokenizer=tokenizer, config=config, **kwargs)
@staticmethod
def load_by_template_name(template_name: str, **kwargs) -> Template:
@@ -292,7 +294,9 @@ class TemplateLoader:
Returns:
The template instance
"""
return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
template = get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs)
template.config = kwargs.get('config')
return template
@staticmethod
def replace_and_concat(template: Template, template_list: List,
@@ -330,33 +334,41 @@ class TemplateLoader:
Returns:
The ModelFile content, returns `None` if no template found
"""
if not model_id and not template_name:
raise ValueError(
f'Please make sure you model_id: {model_id} '
f'and template_name: {template_name} is supported.')
logger.info('Exporting to ollama:')
if model_id:
for _info in template_info:
if re.fullmatch(_info.template_regex, model_id):
if _info.modelfile_link:
if _info.modelfile_link and not kwargs.get('ignore_oss_model_file', False):
return TemplateLoader._read_content_from_url(
_info.modelfile_link)
elif _info.template and not template_name:
template_name = _info.template
if template_name:
template = TemplateLoader.load_by_template_name(
template_name, **kwargs)
else:
raise ValueError(
f'Please make sure you model_id: {model_id} '
f'and template_name: {template_name} is supported.')
template = TemplateLoader.load_by_model_id(
model_id, **kwargs)
if template is None:
return None
content = ''
content += 'FROM {{gguf_file}}\n'
content += (
f'TEMPLATE """{{{{ if .System }}}}'
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}'
f'{{{{ end }}}}')
content += 'FROM {gguf_file}\n'
_prefix = TemplateLoader.replace_and_concat(template, template.prefix, "", "")
if _prefix:
content += (
f'TEMPLATE """{{{{ if .System }}}}'
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
f'{{{{ else }}}}{_prefix}'
f'{{{{ end }}}}')
else:
content += (
f'TEMPLATE """{{{{ if .System }}}}'
f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}'
f'{{{{ end }}}}')
content += (
f'{{{{ if .Prompt }}}}'
f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}'
@@ -364,7 +376,16 @@ class TemplateLoader:
content += '{{ .Response }}'
content += TemplateLoader.replace_and_concat(template, template.suffix,
'', '') + '"""\n'
content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n'
all_eos_tokens = {TemplateLoader.replace_and_concat(template, template.suffix, "", "")}
if getattr(template, 'tokenizer', None):
eos_token = TemplateLoader.replace_and_concat(template, [["eos_token_id"]], "", "")
all_eos_tokens.add(eos_token)
if getattr(template, 'config', None) and getattr(template.config, 'eos_token_id'):
eos_token_id = template.config.eos_token_id
eos_token = TemplateLoader.replace_and_concat(template, [[eos_token_id]], "", "")
all_eos_tokens.add(eos_token)
for eos_token in all_eos_tokens:
content += f'PARAMETER stop "{eos_token}"\n'
return content
@staticmethod

View File

@@ -105,6 +105,9 @@ class TestToOllama(unittest.TestCase):
ollama = TemplateLoader.to_ollama(
'AI-ModelScope/llava-llama-3-8b-v1_1-gguf')
self.assertTrue(ollama is not None)
ollama = TemplateLoader.to_ollama(
'01ai/Yi-1.5-9B-Chat', ignore_oss_model_file=True)
self.assertTrue(ollama is not None)
if __name__ == '__main__':