diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e6e9b77..a8565f16 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +exclude: 'modelscope/preprocessors/templates/' + repos: - repo: https://github.com/pycqa/flake8.git rev: 4.0.0 diff --git a/.pre-commit-config_local.yaml b/.pre-commit-config_local.yaml index a68a5b78..869d8fd6 100644 --- a/.pre-commit-config_local.yaml +++ b/.pre-commit-config_local.yaml @@ -1,3 +1,5 @@ +exclude: 'modelscope/preprocessors/templates/' + repos: - repo: /home/admin/pre-commit/flake8 rev: 4.0.0 diff --git a/modelscope/preprocessors/templates/__init__.py b/modelscope/preprocessors/templates/__init__.py new file mode 100644 index 00000000..5ac1780d --- /dev/null +++ b/modelscope/preprocessors/templates/__init__.py @@ -0,0 +1,2 @@ +from .base import Template, get_template +from .template import TemplateType diff --git a/modelscope/preprocessors/templates/base.py b/modelscope/preprocessors/templates/base.py new file mode 100644 index 00000000..4504a4bc --- /dev/null +++ b/modelscope/preprocessors/templates/base.py @@ -0,0 +1,1041 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import json +import re +from copy import deepcopy +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from modelscope import get_logger +from torch.nn import Module +from torch.nn.utils.rnn import pad_sequence +from transformers import PreTrainedTokenizerBase, StoppingCriteria +from .loss_scale import loss_scale_map +from .tools_prompt import get_tools_prompt +from .utils import load_batch, load_image, rescale_image, fetch_one, to_device, decode_base64 +from .utils import History, Prompt, StopWords, Context, Messages + +logger = get_logger() + +DEFAULT_SYSTEM = 'You are a helpful assistant.' + +TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {} + + +def get_template( + template_type: str, + tokenizer: PreTrainedTokenizerBase, + default_system: Optional[str] = None, + max_length: Optional[int] = None, + truncation_strategy: Literal['delete', 'truncation_left'] = 'delete', + **kwargs, +) -> 'Template': + template_info = TEMPLATE_MAPPING[template_type] + template = deepcopy(template_info['template']) + template.init_template(tokenizer, default_system, max_length, truncation_strategy, **kwargs) + return template + + +def _findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]: + """Find the index of a token in the token_list.""" + if isinstance(sub_token_list, int): + sub_token_list = [sub_token_list] + res = [] + idx = -1 + try: + while True: + idx = token_list.index(sub_token_list[0], idx + 1) + if len(sub_token_list) == 1 or sub_token_list == token_list[idx:idx + len(sub_token_list)]: + res.append(idx) + except ValueError: + pass + return res + + +def replace_img_tag(messages: Messages, + replace_token: str, + pattern=r'(.+?)') -> Tuple[str, History, List[str]]: + images_path = [] + new_messages = [] + for i, m in enumerate(messages): + m = m.copy() + if m['content'] is None or m['role'] in ('tool', 'system', 'assistant'): + new_messages.append(m) + else: + images_path += re.findall(pattern, m['content']) + m['content'] = re.sub(pattern, replace_token, m['content']) + new_messages.append(m) + return messages, images_path + + +class StopWordsCriteria(StoppingCriteria): + """Adding extra stop words in template to prevent unstoppable generation + Like suffixes and chat seps in the template. + """ + def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: StopWords, **tokenizer_kwargs) -> None: + self.tokenizer = tokenizer + self.stop_words = stop_words + self.tokenizer_kwargs = tokenizer_kwargs + self.start_idx = -1 + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool: + if self.start_idx == -1: + self.start_idx = len(input_ids[0]) - 1 + tokenizer = self.tokenizer + stop_words = self.stop_words + # [-20:]: Assuming the end tokens do not exceed 20 tokens, + # to avoid input_ids being too long and affecting efficiency. + text = tokenizer.decode(input_ids[0, self.start_idx:][-20:], **self.tokenizer_kwargs) + for stop_word in stop_words: + if isinstance(stop_word, str): + if stop_word in text: + return True + else: # list + if len(stop_word) > 0 and input_ids[0].tolist()[-len(stop_word):] == stop_word: + return True + return False + + +class Template: + """A template class for all supported models. + + Args: + prefix: Prefix tokens before the first turn's prompt + prompt: A list of elements whose types are str and list of integers. The input query part of every turn. + chat_sep: The chat separators between every turn. + suffix: The end tokens after the chat finished. + default_system: A default system instruction. + system_prefix: The prefix if the `system` is not empty. + auto_add_bos: By default, the bos_token is not added. The auto_add_bos option will determine + whether to add it based on `tokenizer.encode('')`. + tools_prompt: The tools prompt name + tool_prompt: The tool prompt, usually useful when there is a tool role + padding_side: The padding side + infer_media_type: The media type supported by the multi-modals + Examples: + system\nYou are a helpful assistant!\nWho are you?\nassistant:I am a robot\nWho are you?\nassistant:I am a robot # noqa + ----------system------------ ---query---- --response- -----chatsep----- ---query--- --response- ----suffix----- + ----------------------------system_prefix---------------------------- ---------------------------- prompt ------------------------------------- ---------------------------- prompt ------------------------------------- + + """ + + special_tokens = ['', '