From 98832eaea0c7dc6fd14dc21796193dc8f3c47758 Mon Sep 17 00:00:00 2001 From: suluyana Date: Thu, 17 Oct 2024 10:32:14 +0800 Subject: [PATCH] fix lint --- modelscope/pipelines/nlp/llm_pipeline.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelscope/pipelines/nlp/llm_pipeline.py b/modelscope/pipelines/nlp/llm_pipeline.py index 3deaeac3..cb801bd4 100644 --- a/modelscope/pipelines/nlp/llm_pipeline.py +++ b/modelscope/pipelines/nlp/llm_pipeline.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union import json import numpy as np import torch -from transformers import PreTrainedModel, PreTrainedTokenizer, AutoConfig +from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline, snapshot_download) @@ -172,10 +172,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin): self.llm_framework = llm_framework if os.path.exists(kwargs['model']): - config = AutoConfig.from_pretrained(kwargs['model'], trust_remote_code=True) + config = AutoConfig.from_pretrained( + kwargs['model'], trust_remote_code=True) q_config = config.__dict__.get('quantization_config', None) if q_config: - if q_config.get('quant_method', 'gptq') == 'gptq' and torch.cuda.device_count(): + if q_config.get( + 'quant_method', + 'gptq') == 'gptq' and torch.cuda.device_count(): self.device_map = 'cuda' self.torch_dtype = kwargs.pop('torch_dtype', None)