mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Merge branch 'modelscope_dev'
add argument compile to nlp pipelines
This commit is contained in:
@@ -40,7 +40,8 @@ class DialogIntentPredictionPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DialogIntentPredictionPreprocessor(
|
||||
self.model.model_dir, **kwargs)
|
||||
|
||||
@@ -42,7 +42,9 @@ class DialogStateTrackingPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
compile=kwargs.pop('compile', False),
|
||||
compile_options=kwargs.pop('compile_options', {}))
|
||||
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DialogStateTrackingPreprocessor(
|
||||
|
||||
@@ -46,7 +46,8 @@ class DocumentGroundedDialogGeneratePipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DocumentGroundedDialogGeneratePreprocessor(
|
||||
|
||||
@@ -64,7 +64,8 @@ class DocumentGroundedDialogRerankPipeline(Pipeline):
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate,
|
||||
seed=seed)
|
||||
seed=seed,
|
||||
**kwarg)
|
||||
self.model = model
|
||||
self.preprocessor = preprocessor
|
||||
self.device = device
|
||||
|
||||
@@ -55,7 +55,8 @@ class DocumentGroundedDialogRetrievalPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
if preprocessor is None:
|
||||
self.preprocessor = DocumentGroundedDialogRetrievalPreprocessor(
|
||||
|
||||
@@ -48,8 +48,14 @@ class DocumentSegmentationPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
kwargs = kwargs
|
||||
if 'compile' in kwargs.keys():
|
||||
kwargs.pop('compile')
|
||||
if 'compile_options' in kwargs.keys():
|
||||
kwargs.pop('compile_options')
|
||||
self.model_dir = self.model.model_dir
|
||||
self.model_cfg = self.model.model_cfg
|
||||
if preprocessor is None:
|
||||
|
||||
@@ -41,7 +41,14 @@ class ExtractiveSummarizationPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
kwargs = kwargs
|
||||
if 'compile' in kwargs.keys():
|
||||
kwargs.pop('compile')
|
||||
if 'compile_options' in kwargs.keys():
|
||||
kwargs.pop('compile_options')
|
||||
|
||||
self.model_dir = self.model.model_dir
|
||||
self.model_cfg = self.model.model_cfg
|
||||
|
||||
@@ -53,7 +53,8 @@ class FeatureExtractionPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -62,7 +62,8 @@ class FillMaskPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -55,7 +55,8 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -42,7 +42,8 @@ class SentenceEmbeddingPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -67,7 +67,8 @@ class SiameseUiePipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -51,7 +51,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -58,7 +58,8 @@ class TextGenerationPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -43,7 +43,8 @@ class TextRankingPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
assert isinstance(self.model, Model), \
|
||||
f'please check whether model config exists in {ModelFile.CONFIGURATION}'
|
||||
|
||||
@@ -42,7 +42,7 @@ class TranslationEvaluationPipeline(Pipeline):
|
||||
`"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
|
||||
source/reference/source+reference can be presented during evaluation.
|
||||
"""
|
||||
super().__init__(model=model, preprocessor=preprocessor)
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
self.eval_mode = eval_mode
|
||||
self.checking_eval_mode()
|
||||
|
||||
@@ -26,7 +26,8 @@ class UserSatisfactionEstimationPipeline(Pipeline):
|
||||
preprocessor: DialogueClassificationUsePreprocessor = None,
|
||||
config_file: str = None,
|
||||
device: str = 'gpu',
|
||||
auto_collate=True):
|
||||
auto_collate=True,
|
||||
**kwargs):
|
||||
"""The inference pipeline for the user satisfaction estimation task.
|
||||
|
||||
Args:
|
||||
@@ -49,7 +50,8 @@ class UserSatisfactionEstimationPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
|
||||
if hasattr(self.preprocessor, 'id2label'):
|
||||
self.id2label = self.preprocessor.id2label
|
||||
|
||||
@@ -66,7 +66,8 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
preprocessor=preprocessor,
|
||||
config_file=config_file,
|
||||
device=device,
|
||||
auto_collate=auto_collate)
|
||||
auto_collate=auto_collate,
|
||||
**kwargs)
|
||||
self.entailment_id = 0
|
||||
self.contradiction_id = 2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user