From 2c0a455023bb506a2ef52ca270d8286366779064 Mon Sep 17 00:00:00 2001 From: "yuze.zyz" Date: Thu, 10 Oct 2024 18:26:31 +0800 Subject: [PATCH] fix template and ollama --- modelscope/preprocessors/templates/loader.py | 53 ++++++++++++++------ tests/tools/test_to_ollama.py | 3 ++ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/modelscope/preprocessors/templates/loader.py b/modelscope/preprocessors/templates/loader.py index e286802b..4f83a9a9 100644 --- a/modelscope/preprocessors/templates/loader.py +++ b/modelscope/preprocessors/templates/loader.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List import requests -from modelscope import AutoTokenizer, get_logger, snapshot_download +from modelscope import AutoTokenizer, get_logger, snapshot_download, AutoConfig from . import TemplateType from .base import Template, get_template @@ -262,6 +262,7 @@ class TemplateLoader: """ ignore_file_pattern = [r'.+\.bin$', r'.+\.safetensors$', r'.+\.gguf$'] tokenizer = kwargs.get('tokenizer') + config = kwargs.get('config') for _info in template_info: if re.fullmatch(_info.template_regex, model_id): if _info.template: @@ -273,10 +274,11 @@ class TemplateLoader: ignore_file_pattern=ignore_file_pattern) tokenizer = AutoTokenizer.from_pretrained( model_dir, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) except Exception: pass return TemplateLoader.load_by_template_name( - _info.template, tokenizer=tokenizer, **kwargs) + _info.template, tokenizer=tokenizer, config=config, **kwargs) @staticmethod def load_by_template_name(template_name: str, **kwargs) -> Template: @@ -292,7 +294,9 @@ class TemplateLoader: Returns: The template instance """ - return get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs) + template = get_template(template_name, tokenizer=kwargs.pop('tokenizer', None), **kwargs) + template.config = kwargs.get('config') + return template @staticmethod def replace_and_concat(template: Template, template_list: List, @@ -330,33 +334,41 @@ class TemplateLoader: Returns: The ModelFile content, returns `None` if no template found """ + if not model_id and not template_name: + raise ValueError( + f'Please make sure you model_id: {model_id} ' + f'and template_name: {template_name} is supported.') logger.info('Exporting to ollama:') if model_id: for _info in template_info: if re.fullmatch(_info.template_regex, model_id): - if _info.modelfile_link: + if _info.modelfile_link and not kwargs.get('ignore_oss_model_file', False): return TemplateLoader._read_content_from_url( _info.modelfile_link) - elif _info.template and not template_name: - template_name = _info.template if template_name: template = TemplateLoader.load_by_template_name( template_name, **kwargs) else: - raise ValueError( - f'Please make sure you model_id: {model_id} ' - f'and template_name: {template_name} is supported.') + template = TemplateLoader.load_by_model_id( + model_id, **kwargs) if template is None: return None content = '' - content += 'FROM {{gguf_file}}\n' - content += ( - f'TEMPLATE """{{{{ if .System }}}}' - f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}' - f'{{{{ else }}}}{TemplateLoader.replace_and_concat(template, template.prefix, "", "")}' - f'{{{{ end }}}}') + content += 'FROM {gguf_file}\n' + _prefix = TemplateLoader.replace_and_concat(template, template.prefix, "", "") + if _prefix: + content += ( + f'TEMPLATE """{{{{ if .System }}}}' + f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}' + f'{{{{ else }}}}{_prefix}' + f'{{{{ end }}}}') + else: + content += ( + f'TEMPLATE """{{{{ if .System }}}}' + f'{TemplateLoader.replace_and_concat(template, template.system_prefix or [], "{{SYSTEM}}", "{{ .System }}")}' + f'{{{{ end }}}}') content += ( f'{{{{ if .Prompt }}}}' f'{TemplateLoader.replace_and_concat(template, template.prompt, "{{QUERY}}", "{{ .Prompt }}")}' @@ -364,7 +376,16 @@ class TemplateLoader: content += '{{ .Response }}' content += TemplateLoader.replace_and_concat(template, template.suffix, '', '') + '"""\n' - content += f'PARAMETER stop "{TemplateLoader.replace_and_concat(template, template.suffix, "", "")}"\n' + all_eos_tokens = {TemplateLoader.replace_and_concat(template, template.suffix, "", "")} + if getattr(template, 'tokenizer', None): + eos_token = TemplateLoader.replace_and_concat(template, [["eos_token_id"]], "", "") + all_eos_tokens.add(eos_token) + if getattr(template, 'config', None) and getattr(template.config, 'eos_token_id'): + eos_token_id = template.config.eos_token_id + eos_token = TemplateLoader.replace_and_concat(template, [[eos_token_id]], "", "") + all_eos_tokens.add(eos_token) + for eos_token in all_eos_tokens: + content += f'PARAMETER stop "{eos_token}"\n' return content @staticmethod diff --git a/tests/tools/test_to_ollama.py b/tests/tools/test_to_ollama.py index ba92c1ea..ad7a3e87 100644 --- a/tests/tools/test_to_ollama.py +++ b/tests/tools/test_to_ollama.py @@ -105,6 +105,9 @@ class TestToOllama(unittest.TestCase): ollama = TemplateLoader.to_ollama( 'AI-ModelScope/llava-llama-3-8b-v1_1-gguf') self.assertTrue(ollama is not None) + ollama = TemplateLoader.to_ollama( + '01ai/Yi-1.5-9B-Chat', ignore_oss_model_file=True) + self.assertTrue(ollama is not None) if __name__ == '__main__':