fix_comment

This commit is contained in:
suluyana
2024-11-04 13:55:34 +08:00
parent 08b8dfdf03
commit ba4eb0b097

View File

@@ -32,11 +32,13 @@ class VisionChatPipeline(VisualQuestionAnsweringPipeline):
self._auto_collate = auto_collate
# ovis
torch_dtype = kwargs.get('torch_dtype', torch.float16)
multimodal_max_length = kwargs.get('multimodal_max_length', 8192)
self.device = 'cuda' if device == 'gpu' else device
self.model = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype=torch.bfloat16,
multimodal_max_length=8192,
torch_dtype=torch_dtype,
multimodal_max_length=multimodal_max_length,
trust_remote_code=True).to(self.device)
self.text_tokenizer = self.model.get_text_tokenizer()
self.visual_tokenizer = self.model.get_visual_tokenizer()