This commit is contained in:
suluyana
2024-10-17 10:32:14 +08:00
parent 1eaf8a0ead
commit 98832eaea0

View File

@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
import json
import numpy as np
import torch
-from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig
+from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
snapshot_download)
@@ -172,10 +172,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
self.llm_framework = llm_framework
if os.path.exists(kwargs['model']):
-config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True)
+config = AutoConfig.from_pretrained(
+    kwargs['model'], trust_remote_code=True)
q_config = config.__dict__.get('quantization_config', None)
if q_config:
-if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count():
+if q_config.get(
+        'quant_method',
+        'gptq') == 'gptq' and torch.cuda.device_count():
self.device_map = 'cuda'
self.torch_dtype = kwargs.pop('torch_dtype', None)