mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
fix lint
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig
|
||||
from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
|
||||
snapshot_download)
|
||||
@@ -172,10 +172,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
self.llm_framework = llm_framework
|
||||
|
||||
if os.path.exists(kwargs['model']):
|
||||
config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True)
|
||||
config = AutoConfig.from_pretrained(
|
||||
kwargs['model'], trust_remote_code=True)
|
||||
q_config = config.__dict__.get('quantization_config', None)
|
||||
if q_config:
|
||||
if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count():
|
||||
if q_config.get(
|
||||
'quant_method',
|
||||
'gptq') == 'gptq' and torch.cuda.device_count():
|
||||
self.device_map = 'cuda'
|
||||
|
||||
self.torch_dtype = kwargs.pop('torch_dtype', None)
|
||||
|
||||
Reference in New Issue
Block a user