diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py index 9e0c6241..3deaeac3 100644 --- a/modelscope/pipelines/nlp/llm_pipeline.py +++ b/modelscope/pipelines/nlp/llm_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union import json import numpy as np import torch -from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline, snapshot_download) @@ -170,9 +170,14 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin): **kwargs): self.device_map = kwargs.pop('device_map', None) self.llm_framework = llm_framework - # TODO: qwen-int4 need 'cuda'/'auto' device_map. - if not self.device_map and 'qwen' in kwargs['model'].lower(): - self.device_map = 'cuda' + + if os.path.exists(kwargs['model']): + config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True) + q_config = config.__dict__.get('quantization_config', None) + if q_config: + if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count(): + self.device_map = 'cuda' + self.torch_dtype = kwargs.pop('torch_dtype', None) self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)