mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[bugfix] compat swift 4.0 (#1680)
This commit is contained in:
@@ -12,7 +12,7 @@ import numpy as np
|
||||
#
|
||||
import torch
|
||||
from matplotlib.figure import Figure
|
||||
from swift import LoRAConfig, Swift
|
||||
from swift.tuners import LoRAConfig, Swift
|
||||
from tensorboard.backend.event_processing.event_accumulator import \
|
||||
EventAccumulator
|
||||
from torch import Tensor
|
||||
|
||||
@@ -35,7 +35,7 @@ attention.deprecate = lambda *arg, **kwargs: None
|
||||
__tuner_MAP__ = {'lora': LoRATuner, 'control_lora': ControlLoRATuner}
|
||||
|
||||
if is_swift_available():
|
||||
from swift import AdapterConfig, LoRAConfig, PromptConfig, Swift
|
||||
from swift.tuners import AdapterConfig, LoRAConfig, PromptConfig, Swift
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
|
||||
@@ -281,7 +281,10 @@ def external_engine_for_llm_checker(model: Union[str, List[str], Model,
|
||||
|
||||
llm_framework = kwargs.get('llm_framework', '')
|
||||
if llm_framework == 'swift':
|
||||
from swift.llm import get_model_info_meta
|
||||
try:
|
||||
from swift.model import get_model_info_meta
|
||||
except ImportError:
|
||||
from swift.llm import get_model_info_meta
|
||||
# check if swift supports
|
||||
if os.path.exists(model):
|
||||
model_id = get_model_id_from_cache(model)
|
||||
|
||||
@@ -217,8 +217,11 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
tokenizer_class) if tokenizer is None else tokenizer
|
||||
|
||||
def _init_swift(self, model_id, device) -> None:
|
||||
from swift.llm import prepare_model_template
|
||||
from swift.llm import InferArguments, get_model_info_meta
|
||||
try:
|
||||
from swift.pipelines import prepare_model_template
|
||||
from swift.arguments import InferArguments
|
||||
except ImportError:
|
||||
from swift.llm import prepare_model_template, InferArguments
|
||||
|
||||
def format_messages(messages: Dict[str, List[Dict[str, str]]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
|
||||
@@ -390,7 +390,10 @@ def load_video_internvl(video_io: BytesIO, bound=None, num_segments=32):
|
||||
|
||||
def draw_plot(img_dir: str, bbox: List[int], bbox_type: str, output_file: str):
|
||||
from PIL import Image, ImageDraw
|
||||
from swift.llm.template.template import Template
|
||||
try:
|
||||
from swift.template import Template
|
||||
except ImportError:
|
||||
from swift.llm.template.template import Template
|
||||
image = Image.open(img_dir)
|
||||
|
||||
objects = [{'bbox': bbox, 'bbox_type': bbox_type, 'image': 0}]
|
||||
@@ -465,7 +468,10 @@ def load_audio_qwen(audio_io: BytesIO, sampling_rate: int):
|
||||
|
||||
|
||||
def load_video_qwen2(video_path: str):
|
||||
from swift.llm.template.template import get_env_args
|
||||
try:
|
||||
from swift.utils import get_env_args
|
||||
except ImportError:
|
||||
from swift.llm.template.template import get_env_args
|
||||
import torchvision
|
||||
from torchvision import io, transforms
|
||||
from qwen_vl_utils.vision_process import (round_by_factor, FPS, FRAME_FACTOR, FPS_MIN_FRAMES, FPS_MAX_FRAMES,
|
||||
|
||||
@@ -52,7 +52,7 @@ class SwiftCheckpointProcessor(CheckpointProcessor):
|
||||
raise ValueError(
|
||||
'Please install swift by `pip install ms-swift` to use SwiftHook.'
|
||||
)
|
||||
from swift import SwiftModel
|
||||
from swift.tuners import SwiftModel
|
||||
if isinstance(model, SwiftModel):
|
||||
_swift_output_dir = output_dir + SwiftCheckpointProcessor.SWIFT_SAVE_SUFFIX
|
||||
model.save_pretrained(
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
|
||||
exists('transformers<5.0'),
|
||||
'Skip test because transformers version is too high.')
|
||||
def test_vision_efficient_tuning_swift_lora_train(self):
|
||||
from swift import LoRAConfig
|
||||
from swift.tuners import LoRAConfig
|
||||
model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
@@ -86,7 +86,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
|
||||
exists('transformers<5.0'),
|
||||
'Skip test because transformers version is too high.')
|
||||
def test_vision_efficient_tuning_swift_adapter_train(self):
|
||||
from swift import AdapterConfig
|
||||
from swift.tuners import AdapterConfig
|
||||
model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
@@ -126,7 +126,7 @@ class TestVisionEfficientTuningSwiftTrainer(unittest.TestCase):
|
||||
exists('transformers<5.0'),
|
||||
'Skip test because transformers version is too high.')
|
||||
def test_vision_efficient_tuning_swift_prompt_train(self):
|
||||
from swift import PromptConfig
|
||||
from swift.tuners import PromptConfig
|
||||
model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'
|
||||
|
||||
def cfg_modify_fn(cfg):
|
||||
|
||||
Reference in New Issue
Block a user