From dae0988f36db5f8d066ebc1bd02a24decd167d62 Mon Sep 17 00:00:00 2001 From: suluyana Date: Tue, 29 Oct 2024 14:21:34 +0800 Subject: [PATCH] feat: ovis_vl --- ...nguage_pipeline.py => ovis_vl_pipeline.py} | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) rename modelscope/pipelines/multi_modal/{vision_language_pipeline.py => ovis_vl_pipeline.py} (65%) diff --git a/modelscope/pipelines/multi_modal/vision_language_pipeline.py b/modelscope/pipelines/multi_modal/ovis_vl_pipeline.py similarity index 65% rename from modelscope/pipelines/multi_modal/vision_language_pipeline.py rename to modelscope/pipelines/multi_modal/ovis_vl_pipeline.py index 98a03d46..1428e971 100644 --- a/modelscope/pipelines/multi_modal/vision_language_pipeline.py +++ b/modelscope/pipelines/multi_modal/ovis_vl_pipeline.py @@ -1,19 +1,21 @@ import torch from typing import Any, Dict, Union +from PIL import Image +from modelscope import AutoModelForCausalLM from modelscope.metainfo import Pipelines, Preprocessors from modelscope.models.base import Model -from modelscope.outputs import OutputKeys, AwesomeTaskOutput +from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Pipeline from modelscope.pipelines.builder import PIPELINES from modelscope.preprocessors import Preprocessor -from modelscope.utils.constant import Fields, Tasks +from modelscope.utils.constant import Frameworks, Fields, Tasks +from modelscope.pipelines.multi_modal.visual_question_answering_pipeline import VisualQuestionAnsweringPipeline -# Pipeline按照任务名称+pipeline名字进行注册。configuration.json中只要添加pipeline.type字段即可使用,不需要改动代码 @PIPELINES.register_module( - Tasks.ovis_vision_chat, module_name=Pipelines.ovis_vision_chat) -class VisionChatPipeline(Pipeline): + Tasks.visual_question_answering, module_name="ovis-vl") +class VisionChatPipeline(VisualQuestionAnsweringPipeline): def __init__(self, model: Union[Model, str], @@ -22,13 +24,18 @@ class VisionChatPipeline(Pipeline): device: str = 'gpu', auto_collate=True, **kwargs): - super().__init__( - model=model, - preprocessor=preprocessor, - config_file=config_file, - device=device, - auto_collate=auto_collate, - **kwargs) + # super().__init__ + self.device_name = device + self.framework = Frameworks.torch + self._model_prepare = True + self._auto_collate = auto_collate + + # ovis + self.device = 'cuda' if device == 'gpu' else device + self.model = AutoModelForCausalLM.from_pretrained("AIDC-AI/Ovis1.6-Gemma2-9B", + torch_dtype=torch.bfloat16, + multimodal_max_length=8192, + trust_remote_code=True).to(self.device) self.text_tokenizer = self.model.get_text_tokenizer() self.visual_tokenizer = self.model.get_visual_tokenizer() @@ -36,16 +43,18 @@ class VisionChatPipeline(Pipeline): def preprocess(self, inputs: Dict[str, Any]): text = inputs['text'] image = inputs['image'] + image = Image.open(image) query = f'\n{text}' _, input_ids, pixel_values = self.model.preprocess_inputs(query, [image]) attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id) + input_ids = input_ids.unsqueeze(0).to(device=self.model.device) attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device) pixel_values = [pixel_values.to(dtype=self.visual_tokenizer.dtype, device=self.visual_tokenizer.device)] return {'input_ids':input_ids, 'pixel_values': pixel_values, 'attention_mask': attention_mask} def forward(self, inputs: Dict[str, Any], - **forward_params) -> Union[Dict[str, Any], AwesomeTaskOutput]: + **forward_params) -> Dict[str, Any]: input_ids = inputs['input_ids'] pixel_values = inputs['pixel_values'] attention_mask = inputs['attention_mask'] @@ -69,11 +78,9 @@ class VisionChatPipeline(Pipeline): use_cache=True ) output_ids = self.model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0] - return output_ids + return {'output_ids': output_ids} - def postprocess(self, - inputs: Union[Dict[str, Any], - AwesomeTaskOutput]) -> Dict[str, Any]: - # do some post-processes + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + output_ids = inputs['output_ids'] output = self.text_tokenizer.decode(output_ids, skip_special_tokens=True) - return {OutputKeys.TEXT: output} \ No newline at end of file + return {OutputKeys.TEXT: output}