Merge pull request #297 from slin000111/master

add torch2 compile to nlp pipelines
This commit is contained in:
Xingjun.Wang
2023-05-12 11:03:15 +08:00
committed by GitHub
18 changed files with 53 additions and 19 deletions

View File

@@ -40,7 +40,8 @@ class DialogIntentPredictionPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
if preprocessor is None:
self.preprocessor = DialogIntentPredictionPreprocessor(
self.model.model_dir, **kwargs)

View File

@@ -42,7 +42,9 @@ class DialogStateTrackingPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
compile=kwargs.pop('compile', False),
compile_options=kwargs.pop('compile_options', {}))
if preprocessor is None:
self.preprocessor = DialogStateTrackingPreprocessor(

View File

@@ -46,7 +46,8 @@ class DocumentGroundedDialogGeneratePipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
if preprocessor is None:
self.preprocessor = DocumentGroundedDialogGeneratePreprocessor(

View File

@@ -64,7 +64,8 @@ class DocumentGroundedDialogRerankPipeline(Pipeline):
config_file=config_file,
device=device,
auto_collate=auto_collate,
seed=seed)
seed=seed,
**kwarg)
self.model = model
self.preprocessor = preprocessor
self.device = device

View File

@@ -55,7 +55,8 @@ class DocumentGroundedDialogRetrievalPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
if preprocessor is None:
self.preprocessor = DocumentGroundedDialogRetrievalPreprocessor(

View File

@@ -48,8 +48,14 @@ class DocumentSegmentationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
kwargs = kwargs
if 'compile' in kwargs.keys():
kwargs.pop('compile')
if 'compile_options' in kwargs.keys():
kwargs.pop('compile_options')
self.model_dir = self.model.model_dir
self.model_cfg = self.model.model_cfg
if preprocessor is None:

View File

@@ -41,7 +41,14 @@ class ExtractiveSummarizationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
kwargs = kwargs
if 'compile' in kwargs.keys():
kwargs.pop('compile')
if 'compile_options' in kwargs.keys():
kwargs.pop('compile_options')
self.model_dir = self.model.model_dir
self.model_cfg = self.model.model_cfg

View File

@@ -53,7 +53,8 @@ class FeatureExtractionPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -62,7 +62,8 @@ class FillMaskPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -55,7 +55,8 @@ class NamedEntityRecognitionPipeline(TokenClassificationPipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -42,7 +42,8 @@ class SentenceEmbeddingPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -67,7 +67,8 @@ class SiameseUiePipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -51,7 +51,8 @@ class TableQuestionAnsweringPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -58,7 +58,8 @@ class TextGenerationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -43,7 +43,8 @@ class TextRankingPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
assert isinstance(self.model, Model), \
f'please check whether model config exists in {ModelFile.CONFIGURATION}'

View File

@@ -42,7 +42,11 @@ class TranslationEvaluationPipeline(Pipeline):
`"EvaluationMode.SRC"`, `"EvaluationMode.REF"`. Aside from hypothesis, the
source/reference/source+reference can be presented during evaluation.
"""
super().__init__(model=model, preprocessor=preprocessor)
super().__init__(
model=model,
preprocessor=preprocessor,
compile=kwargs.pop('compile', False),
compile_options=kwargs.pop('compile_options', {}))
self.eval_mode = eval_mode
self.checking_eval_mode()

View File

@@ -26,7 +26,8 @@ class UserSatisfactionEstimationPipeline(Pipeline):
preprocessor: DialogueClassificationUsePreprocessor = None,
config_file: str = None,
device: str = 'gpu',
auto_collate=True):
auto_collate=True,
**kwargs):
"""The inference pipeline for the user satisfaction estimation task.
Args:
@@ -49,7 +50,8 @@ class UserSatisfactionEstimationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
if hasattr(self.preprocessor, 'id2label'):
self.id2label = self.preprocessor.id2label

View File

@@ -66,7 +66,8 @@ class ZeroShotClassificationPipeline(Pipeline):
preprocessor=preprocessor,
config_file=config_file,
device=device,
auto_collate=auto_collate)
auto_collate=auto_collate,
**kwargs)
self.entailment_id = 0
self.contradiction_id = 2