diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e6e9b77..a8565f16 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,3 +1,5 @@
+exclude: 'modelscope/preprocessors/templates/'
+
repos:
- repo: https://github.com/pycqa/flake8.git
rev: 4.0.0
diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml
index a68a5b78..869d8fd6 100644
--- a/.pre-commit-config_local.yaml
+++ b/.pre-commit-config_local.yaml
@@ -1,3 +1,5 @@
+exclude: 'modelscope/preprocessors/templates/'
+
repos:
- repo: /home/admin/pre-commit/flake8
rev: 4.0.0
diff --git a/modelscope/preprocessors/templates/__init__.py b/modelscope/preprocessors/templates/__init__.py
new file mode 100644
index 00000000..5ac1780d
--- /dev/null
+++ b/modelscope/preprocessors/templates/__init__.py
@@ -0,0 +1,2 @@
+from .base import Template, get_template
+from .template import TemplateType
diff --git a/modelscope/preprocessors/templates/base.py b/modelscope/preprocessors/templates/base.py
new file mode 100644
index 00000000..4504a4bc
--- /dev/null
+++ b/modelscope/preprocessors/templates/base.py
@@ -0,0 +1,1041 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import json
+import re
+from copy import deepcopy
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from modelscope import get_logger
+from torch.nn import Module
+from torch.nn.utils.rnn import pad_sequence
+from transformers import PreTrainedTokenizerBase, StoppingCriteria
+from .loss_scale import loss_scale_map
+from .tools_prompt import get_tools_prompt
+from .utils import load_batch, load_image, rescale_image, fetch_one, to_device, decode_base64
+from .utils import History, Prompt, StopWords, Context, Messages
+
+# Module-level logger obtained from modelscope's `get_logger`.
+logger = get_logger()
+
+# System prompt string constant (its consumers are not visible in this part of the file).
+DEFAULT_SYSTEM = 'You are a helpful assistant.'
+
+# Registry keyed by template type name. `get_template` looks a type up here and deep-copies
+# the object stored under the value's 'template' key; it starts out empty.
+TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
+
+
+def get_template(
+    template_type: str,
+    tokenizer: PreTrainedTokenizerBase,
+    default_system: Optional[str] = None,
+    max_length: Optional[int] = None,
+    truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
+    **kwargs,
+) -> 'Template':
+    """Return an initialized copy of the template registered under `template_type`.
+
+    Args:
+        template_type: Key into `TEMPLATE_MAPPING`.
+        tokenizer: Forwarded to `init_template`.
+        default_system: Forwarded to `init_template`.
+        max_length: Forwarded to `init_template`.
+        truncation_strategy: Forwarded to `init_template`.
+        **kwargs: Extra keyword arguments forwarded to `init_template`.
+
+    Returns:
+        A deep copy of the registered template after `init_template` has been called on it.
+
+    Raises:
+        KeyError: If `template_type` is not a key of `TEMPLATE_MAPPING`.
+    """
+    registered = TEMPLATE_MAPPING[template_type]['template']
+    # Initialize a deep copy so the object stored in the registry is not the one being set up.
+    instance = deepcopy(registered)
+    instance.init_template(tokenizer, default_system, max_length, truncation_strategy, **kwargs)
+    return instance
+
+
+def _findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
+    """Find every start index at which `sub_token_list` occurs in `token_list`.
+
+    Args:
+        token_list: The token ids to search in.
+        sub_token_list: A single token id, or a list of token ids that must match contiguously.
+
+    Returns:
+        The start indices of all matches in ascending order. Overlapping matches are all
+        reported. An empty `sub_token_list` yields an empty result.
+    """
+    if isinstance(sub_token_list, int):
+        sub_token_list = [sub_token_list]
+    if not sub_token_list:
+        # Nothing to search for; indexing element 0 below would raise IndexError.
+        return []
+    first = sub_token_list[0]
+    width = len(sub_token_list)
+    res = []
+    idx = -1
+    try:
+        while True:
+            # Jump to the next occurrence of the first token, then verify the rest of the window.
+            idx = token_list.index(first, idx + 1)
+            if width == 1 or sub_token_list == token_list[idx:idx + width]:
+                res.append(idx)
+    except ValueError:
+        # `list.index` raises ValueError when no further occurrence exists: the scan is done.
+        pass
+    return res
+
+
+def replace_img_tag(messages: Messages,
+                    replace_token: str,
+                    pattern=r'<img>(.+?)</img>') -> Tuple[Messages, List[str]]:
+    """Replace image tags in message contents and collect what the tags captured.
+
+    Messages whose `content` is None, or whose `role` is 'tool', 'system' or 'assistant',
+    are copied through unchanged. For every other message, each match of `pattern` in
+    `content` is replaced with `replace_token` and the captured text is collected.
+
+    Args:
+        messages: Dicts that each carry a 'role' and a 'content' key.
+        replace_token: Replacement passed to `re.sub` for every match of `pattern`.
+        pattern: Regular expression whose capture group holds the image path.
+
+    Returns:
+        A tuple `(new_messages, images_path)`: shallow copies of the messages with the tags
+        replaced, and the captured strings in order of appearance. The input list and its
+        dicts are not modified.
+    """
+    # NOTE(review): the default `pattern` literal was garbled in this patch (its tag text was
+    # stripped); it is reconstructed here as '<img>(.+?)</img>' -- confirm against upstream.
+    images_path = []
+    new_messages = []
+    for m in messages:
+        m = m.copy()
+        if m['content'] is not None and m['role'] not in ('tool', 'system', 'assistant'):
+            images_path += re.findall(pattern, m['content'])
+            m['content'] = re.sub(pattern, replace_token, m['content'])
+        new_messages.append(m)
+    # Return the rewritten copies; returning `messages` would discard the replacement.
+    return new_messages, images_path
+
+
+class StopWordsCriteria(StoppingCriteria):
+    """Stopping criterion that reports a stop once one of the given stop words shows up.
+
+    Used to add extra stop words from the template (such as suffixes and chat separators)
+    so that generation does not run on unstoppably. Only the first sequence of the batch
+    (`input_ids[0]`) is inspected.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: StopWords, **tokenizer_kwargs) -> None:
+        """Store the tokenizer, the stop words and the keyword arguments passed to `decode`."""
+        self.tokenizer = tokenizer
+        self.stop_words = stop_words
+        self.tokenizer_kwargs = tokenizer_kwargs
+        # -1 means "not set yet"; the first call fixes it to the last index present at that time.
+        self.start_idx = -1
+
+    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
+        """Return True if any stop word matches the current sequence, otherwise False.
+
+        A `str` stop word matches when it is a substring of the decoded recent tokens;
+        any other (list) stop word matches when it equals the tail of the token ids.
+        """
+        if self.start_idx == -1:
+            self.start_idx = len(input_ids[0]) - 1
+        # Only the last 20 tokens after `start_idx` are decoded: the end tokens are assumed
+        # not to exceed 20 tokens, which keeps decoding cheap for long sequences.
+        recent_ids = input_ids[0, self.start_idx:][-20:]
+        text = self.tokenizer.decode(recent_ids, **self.tokenizer_kwargs)
+        for stop_word in self.stop_words:
+            if isinstance(stop_word, str):
+                matched = stop_word in text
+            else:  # list
+                width = len(stop_word)
+                matched = width > 0 and input_ids[0].tolist()[-width:] == stop_word
+            if matched:
+                return True
+        return False
+
+
+class Template:
+ """A template class for all supported models.
+
+ Args:
+ prefix: Prefix tokens before the first turn's prompt
+ prompt: A list of elements whose types are str and list of integers. The input query part of every turn.
+ chat_sep: The chat separators between every turn.
+ suffix: The end tokens after the chat finished.
+ default_system: A default system instruction.
+ system_prefix: The prefix if the `system` is not empty.
+ auto_add_bos: By default, the bos_token is not added. The auto_add_bos option will determine
+ whether to add it based on `tokenizer.encode('')`.
+ tools_prompt: The tools prompt name
+ tool_prompt: The tool prompt, usually useful when there is a tool role
+ padding_side: The padding side
+ infer_media_type: The media type supported by the multi-modals
+ Examples:
+ system\nYou are a helpful assistant!\nWho are you?\nassistant:I am a robot\nWho are you?\nassistant:I am a robot # noqa
+ ----------system------------ ---query---- --response- -----chatsep----- ---query--- --response- ----suffix-----
+ ----------------------------system_prefix---------------------------- ---------------------------- prompt ------------------------------------- ---------------------------- prompt -------------------------------------
+
+ """
+
+ special_tokens = ['', '