mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-25 04:30:48 +01:00
feat: ovis_vl
This commit is contained in:
@@ -1,19 +1,21 @@
|
||||
import torch
|
||||
from typing import Any, Dict, Union
|
||||
from PIL import Image
|
||||
|
||||
from modelscope import AutoModelForCausalLM
|
||||
from modelscope.metainfo import Pipelines, Preprocessors
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.outputs import OutputKeys, AwesomeTaskOutput
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import Fields, Tasks
|
||||
from modelscope.utils.constant import Frameworks, Fields, Tasks
|
||||
from modelscope.pipelines.multi_modal.visual_question_answering_pipeline import VisualQuestionAnsweringPipeline
|
||||
|
||||
|
||||
# Pipeline按照任务名称+pipeline名字进行注册。configuration.json中只要添加pipeline.type字段即可使用,不需要改动代码
|
||||
@PIPELINES.register_module(
|
||||
Tasks.ovis_vision_chat, module_name=Pipelines.ovis_vision_chat)
|
||||
class VisionChatPipeline(Pipeline):
|
||||
Tasks.visual_question_answering, module_name="ovis-vl")
|
||||
class VisionChatPipeline(VisualQuestionAnsweringPipeline):
|
||||
|
||||
def __init__(self,
|
||||
model: Union[Model, str],
|
||||
@@ -22,13 +24,18 @@ class VisionChatPipeline(Pipeline):
|
||||
device: str = 'gpu',
|
||||
auto_collate=True,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
model=model,
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
# super().__init__
|
||||
self.device_name = device
|
||||
self.framework = Frameworks.torch
|
||||
self._model_prepare = True
|
||||
self._auto_collate = auto_collate
|
||||
|
||||
# ovis
|
||||
self.device = 'cuda' if device == 'gpu' else device
|
||||
self.model = AutoModelForCausalLM.from_pretrained("AIDC-AI/Ovis1.6-Gemma2-9B",
|
||||
torch_dtype=torch.bfloat16,
|
||||
multimodal_max_length=8192,
|
||||
trust_remote_code=True).to(self.device)
|
||||
self.text_tokenizer = self.model.get_text_tokenizer()
|
||||
self.visual_tokenizer = self.model.get_visual_tokenizer()
|
||||
|
||||
@@ -36,16 +43,18 @@ class VisionChatPipeline(Pipeline):
|
||||
def preprocess(self, inputs: Dict[str, Any]):
|
||||
text = inputs['text']
|
||||
image = inputs['image']
|
||||
image = Image.open(image)
|
||||
query = f'<image>\n{text}'
|
||||
_, input_ids, pixel_values = self.model.preprocess_inputs(query, [image])
|
||||
attention_mask = torch.ne(input_ids, self.text_tokenizer.pad_token_id)
|
||||
input_ids = input_ids.unsqueeze(0).to(device=self.model.device)
|
||||
attention_mask = attention_mask.unsqueeze(0).to(device=self.model.device)
|
||||
pixel_values = [pixel_values.to(dtype=self.visual_tokenizer.dtype, device=self.visual_tokenizer.device)]
|
||||
|
||||
return {'input_ids':input_ids, 'pixel_values': pixel_values, 'attention_mask': attention_mask}
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Union[Dict[str, Any], AwesomeTaskOutput]:
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
input_ids = inputs['input_ids']
|
||||
pixel_values = inputs['pixel_values']
|
||||
attention_mask = inputs['attention_mask']
|
||||
@@ -69,11 +78,9 @@ class VisionChatPipeline(Pipeline):
|
||||
use_cache=True
|
||||
)
|
||||
output_ids = self.model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0]
|
||||
return output_ids
|
||||
return {'output_ids': output_ids}
|
||||
|
||||
def postprocess(self,
|
||||
inputs: Union[Dict[str, Any],
|
||||
AwesomeTaskOutput]) -> Dict[str, Any]:
|
||||
# do some post-processes
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
output_ids = inputs['output_ids']
|
||||
output = self.text_tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
return {OutputKeys.TEXT: output}
|
||||
return {OutputKeys.TEXT: output}
|
||||
Reference in New Issue
Block a user