From 1eaf8a0ead5d20c69aeb318e9ab968190640ba59 Mon Sep 17 00:00:00 2001
From: DaozeZhang
Date: Wed, 16 Oct 2024 20:07:32 +0800
Subject: [PATCH] fix: set device_map=cuda if using gptq

---
 modelscope/pipelines/nlp/llm_pipeline.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py
index 9e0c6241..3deaeac3 100644
--- a/modelscope/pipelines/nlp/llm_pipeline.py
+++ b/modelscope/pipelines/nlp/llm_pipeline.py
@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
 import json
 import numpy as np
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig
 
 from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
                         snapshot_download)
@@ -170,9 +170,14 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
                  **kwargs):
         self.device_map = kwargs.pop('device_map', None)
         self.llm_framework = llm_framework
-        # TODO: qwen-int4 need 'cuda'/'auto' device_map.
-        if not self.device_map and 'qwen' in kwargs['model'].lower():
-            self.device_map = 'cuda'
+
+        if os.path.exists(kwargs['model']):
+            config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True)
+            q_config = config.__dict__.get('quantization_config', None)
+            if q_config:
+                if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count():
+                    self.device_map = 'cuda'
+
         self.torch_dtype = kwargs.pop('torch_dtype', None)
         self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)