fix: set device_map=cuda if using gptq
@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
 import json
 import numpy as np
 import torch
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig
 
 from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
                         snapshot_download)
@@ -170,9 +170,14 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
                  **kwargs):
         self.device_map = kwargs.pop('device_map', None)
         self.llm_framework = llm_framework
         # TODO: qwen-int4 need 'cuda'/'auto' device_map.
         if not self.device_map and 'qwen' in kwargs['model'].lower():
             self.device_map = 'cuda'
 
+        if os.path.exists(kwargs['model']):
+            config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True)
+            q_config = config.__dict__.get('quantization_config', None)
+            if q_config:
+                if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count():
+                    self.device_map = 'cuda'
+
         self.torch_dtype = kwargs.pop('torch_dtype', None)
         self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
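
What the change does: GPTQ-quantized checkpoints need CUDA kernels to dequantize their weights, so they cannot load with the default CPU device map. The constructor now inspects the checkpoint's quantization_config and forces device_map='cuda' when a GPTQ model is detected and a GPU is available. Below is a minimal standalone sketch of that detection logic under the same assumptions as the commit (a locally downloaded checkpoint directory); resolve_device_map and model_dir are illustrative names, not part of the modelscope API.

    import os

    import torch
    from transformers import AutoConfig


    def resolve_device_map(model_dir, device_map=None):
        # The check only applies when the model argument is a local
        # checkpoint directory, mirroring the os.path.exists guard above.
        if os.path.exists(model_dir):
            config = AutoConfig.from_pretrained(
                model_dir, trust_remote_code=True)
            # GPTQ checkpoints record their quantization settings in the
            # config's `quantization_config` entry (a dict when loaded
            # straight from config.json).
            q_config = config.__dict__.get('quantization_config', None)
            if q_config:
                # Defaulting quant_method to 'gptq' means older GPTQ configs
                # that omit the field are still treated as GPTQ. The override
                # only fires when at least one CUDA device is present.
                if q_config.get('quant_method', 'gptq') == 'gptq' \
                        and torch.cuda.device_count():
                    device_map = 'cuda'
        return device_map

For example, a qwen int4 checkpoint (the case named in the TODO above) whose config.json carries "quantization_config": {"quant_method": "gptq", ...} would resolve to 'cuda' on a GPU machine, while on a CPU-only machine the device map is left untouched.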