feat: ovis_vl

Author: suluyana
Date:   2024-10-29 14:21:34 +08:00
Parent: 339005d9da
Commit: dae0988f36


@@ -1,19 +1,21 @@
 import torch
 from typing import Any, Dict, Union
 from PIL import Image
 from modelscope import AutoModelForCausalLM
 from modelscope.metainfo import Pipelines, Preprocessors
 from modelscope.models.base import Model
-from modelscope.outputs import OutputKeys, AwesomeTaskOutput
+from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.preprocessors import Preprocessor
-from modelscope.utils.constant import Fields, Tasks
+from modelscope.utils.constant import Frameworks, Fields, Tasks
+from modelscope.pipelines.multi_modal.visual_question_answering_pipeline import VisualQuestionAnsweringPipeline


+# Pipelines are registered under task name + pipeline name. To use this one,
+# just set the pipeline.type field in configuration.json; no code change is needed.
 @PIPELINES.register_module(
-    Tasks.ovis_vision_chat, module_name=Pipelines.ovis_vision_chat)
-class VisionChatPipeline(Pipeline):
+    Tasks.visual_question_answering, module_name="ovis-vl")
+class VisionChatPipeline(VisualQuestionAnsweringPipeline):

     def __init__(self,
                  model: Union[Model, str],
@@ -22,13 +24,18 @@ class VisionChatPipeline(Pipeline):
                  device: str = 'gpu',
                  auto_collate=True,
                  **kwargs):
-        super().__init__(
-            model=model,
-            preprocessor=preprocessor,
-            config_file=config_file,
-            device=device,
-            auto_collate=auto_collate,
-            **kwargs)
+        # super().__init__ is skipped on purpose; set the base-class state by hand
+        # so the parent pipeline does not attempt to load the model itself.
+        self.device_name = device
+        self.framework = Frameworks.torch
+        self._model_prepare = True
+        self._auto_collate = auto_collate
+        # Load the Ovis model and its text/visual tokenizers.
+        self.device = 'cuda' if device == 'gpu' else device
+        self.model = AutoModelForCausalLM.from_pretrained(
+            "AIDC-AI/Ovis1.6-Gemma2-9B",
+            torch_dtype=torch.bfloat16,
+            multimodal_max_length=8192,
+            trust_remote_code=True).to(self.device)

         self.text_tokenizer = self.model.get_text_tokenizer()
         self.visual_tokenizer = self.model.get_visual_tokenizer()
@@ -36,16 +43,18 @@ class VisionChatPipeline(Pipeline):
     def preprocess(self, inputs: Dict[str, Any]):
         text = inputs['text']
         image = inputs['image']
+        image = Image.open(image)
         query = f'<image>\n{text}'
         _, input_ids, pixel_values = self.model.preprocess_inputs(query, [image])
         attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
         input_ids = input_ids.unsqueeze(0).to(device=self.model.device)
         attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
         pixel_values = [pixel_values.to(dtype=self.visual_tokenizer.dtype, device=self.visual_tokenizer.device)]
         return {'input_ids': input_ids, 'pixel_values': pixel_values, 'attention_mask': attention_mask}

     def forward(self, inputs: Dict[str, Any],
-                **forward_params) -> Union[Dict[str, Any], AwesomeTaskOutput]:
+                **forward_params) -> Dict[str, Any]:
         input_ids = inputs['input_ids']
         pixel_values = inputs['pixel_values']
         attention_mask = inputs['attention_mask']
@@ -69,11 +78,9 @@ class VisionChatPipeline(Pipeline):
             use_cache=True
         )
         output_ids = self.model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]
-        return output_ids
+        return {'output_ids': output_ids}

-    def postprocess(self,
-                    inputs: Union[Dict[str, Any],
-                                  AwesomeTaskOutput]) -> Dict[str, Any]:
-        # do some post-processes
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
         output_ids = inputs['output_ids']
         output = self.text_tokenizer.decode(output_ids, skip_special_tokens=True)
         return {OutputKeys.TEXT: output}
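
For reference, a minimal usage sketch of the pipeline this commit registers. It assumes the model repo's configuration.json sets pipeline.type to "ovis-vl" (as the registration comment above describes); the image path below is a placeholder, and note that __init__ loads the AIDC-AI/Ovis1.6-Gemma2-9B checkpoint regardless of the model argument passed in.

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Build the pipeline; the builder resolves "ovis-vl" from the
    # pipeline.type field in the model's configuration.json.
    pipe = pipeline(
        task=Tasks.visual_question_answering,
        model='AIDC-AI/Ovis1.6-Gemma2-9B')

    # preprocess() expects 'text' plus an image path it can open with PIL.
    result = pipe({'text': 'What is in this picture?', 'image': 'demo.jpg'})
    print(result[OutputKeys.TEXT])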