mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-19 01:29:24 +01:00
add qwen 7b base and chat
添加QWen 7b base模型和chat模型及相关pipelines Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13482235 * add qwen 7b base and chat * fix logger * update examples, lint test * add unittest for qwen base and chat * rename qwen to qwen-7b * resolve imports and add a registry to text-generation * reset load model from pretrained * fix precheck * skip qwen test case now * remove strange file
This commit is contained in:
committed by
wenmeng.zwm
parent
1a6583eee2
commit
33bd74a7be
@@ -3,34 +3,24 @@ import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Counter, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Counter, Dict, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import Dataset as HfDataset
|
||||
from numpy import ndarray
|
||||
from tensorboard.backend.event_processing.event_accumulator import \
|
||||
EventAccumulator
|
||||
from torch import Tensor
|
||||
from torch import device as Device
|
||||
from torch import dtype as Dtype
|
||||
from torch.nn import Module
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torchmetrics import Accuracy, MeanMetric
|
||||
from tqdm import tqdm
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
from transformers import GenerationConfig, HfArgumentParser, TextStreamer
|
||||
|
||||
from modelscope import get_logger
|
||||
from modelscope.metrics.base import Metric
|
||||
from modelscope.metrics.builder import METRICS
|
||||
from modelscope.swift import LoRAConfig, Swift
|
||||
from modelscope.trainers import EpochBasedTrainer
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
COLOR, COLOR_S = '#FFE2D9', '#FF7043'
|
||||
@@ -318,3 +308,15 @@ def inference(input_ids: List[int],
|
||||
generation_config=generation_config)
|
||||
output_text = tokenizer.decode(generate_ids[0])
|
||||
return output_text
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def parse_args(class_type: Type[_T],
|
||||
argv: Optional[List[str]] = None) -> Tuple[_T, List[str]]:
|
||||
parser = HfArgumentParser([class_type])
|
||||
args, remaining_args = parser.parse_args_into_dataclasses(
|
||||
argv, return_remaining_strings=True)
|
||||
logger.info(f'args: {args}')
|
||||
return args, remaining_args
|
||||
|
||||
Reference in New Issue
Block a user