mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
fix punkt
This commit is contained in:
@@ -31,6 +31,7 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
||||
python -m spacy download en_core_web_sm
|
||||
pip install faiss-gpu
|
||||
pip install healpy
|
||||
pip install huggingface-hub==0.25.2
|
||||
# test with install
|
||||
pip install .
|
||||
else
|
||||
|
||||
@@ -269,7 +269,8 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
|
||||
assert model_id.lower() in SWIFT_MODEL_ID_MAPPING,\
|
||||
f'Invalid model id {model_id} or Swift framework does not support this model.'
|
||||
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
|
||||
args = InferArguments(
|
||||
model_type=SWIFT_MODEL_ID_MAPPING[model_id.lower()])
|
||||
model, template = prepare_model_template(
|
||||
args, device_map=self.device_map)
|
||||
self.model = add_stream_generate(model)
|
||||
|
||||
@@ -213,11 +213,14 @@ class FillMaskPoNetPreprocessor(FillMaskPreprocessorBase):
|
||||
osp.join(model_dir, ModelFile.CONFIGURATION))
|
||||
self.language = self.cfg.model.get('language', 'en')
|
||||
if self.language == 'en':
|
||||
from nltk.tokenize import sent_tokenize
|
||||
import nltk
|
||||
nltk.download('punkt_tab')
|
||||
# import_external_nltk_data(
|
||||
# osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt_tab')
|
||||
from nltk.tokenize import sent_tokenize
|
||||
from packaging import version
|
||||
if version.parse(nltk.__version__) >= version.parse('3.8.2'):
|
||||
nltk.download('punkt_tab')
|
||||
else:
|
||||
import_external_nltk_data(
|
||||
osp.join(model_dir, 'nltk_data'), 'tokenizers/punkt_tab')
|
||||
elif self.language in ['zh', 'cn']:
|
||||
|
||||
def sent_tokenize(para):
|
||||
|
||||
@@ -178,7 +178,8 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
if version.parse(transformers.__version__) >= version.parse('4.43.0'):
|
||||
greedy_search_name = 'stream_greedy_search'
|
||||
sample_name = '_sample'
|
||||
elif version.parse(transformers.__version__) >= version.parse('4.39.0'):
|
||||
elif version.parse(
|
||||
transformers.__version__) >= version.parse('4.39.0'):
|
||||
greedy_search_name = '_greedy_search'
|
||||
sample_name = '_sample'
|
||||
else:
|
||||
@@ -452,7 +453,8 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
break
|
||||
|
||||
# prepare model inputs
|
||||
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
|
||||
model_kwargs = self._get_initial_cache_position(
|
||||
input_ids, model_kwargs)
|
||||
model_inputs = self.prepare_inputs_for_generation(
|
||||
input_ids, **model_kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user