mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
add eval() to pipeline call
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Optional, Union, Any
|
||||
|
||||
import torch
|
||||
|
||||
from ...models import Model
|
||||
from ...models.nlp.masked_language_model import \
|
||||
@@ -35,6 +37,7 @@ class FillMaskPipeline(Pipeline):
|
||||
fill_mask_model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=None)
|
||||
fill_mask_model.eval()
|
||||
super().__init__(model=fill_mask_model, preprocessor=preprocessor, **kwargs)
|
||||
self.preprocessor = preprocessor
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
@@ -61,6 +64,11 @@ class FillMaskPipeline(Pipeline):
|
||||
}
|
||||
}
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
@@ -42,9 +42,15 @@ class NLIPipeline(Pipeline):
|
||||
sc_model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence)
|
||||
sc_model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(sc_model.id2label) > 0
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from ...metainfo import Pipelines
|
||||
from ...models.nlp import SbertForSentenceSimilarity
|
||||
from ...preprocessors import SequenceClassificationPreprocessor
|
||||
@@ -39,11 +39,17 @@ class SentenceSimilarityPipeline(Pipeline):
|
||||
sc_model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence)
|
||||
sc_model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
assert hasattr(self.model, 'id2label'), \
|
||||
'id2label map should be initalizaed in init function.'
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
@@ -43,9 +43,15 @@ class SentimentClassificationPipeline(Pipeline):
|
||||
sc_model.model_dir,
|
||||
first_sequence=first_sequence,
|
||||
second_sequence=second_sequence)
|
||||
sc_model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
assert len(sc_model.id2label) > 0
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from typing import Dict, Optional, Union, Any
|
||||
import torch
|
||||
from ...metainfo import Pipelines
|
||||
from ...models import Model
|
||||
from ...models.nlp import PalmForTextGeneration
|
||||
@@ -33,9 +33,15 @@ class TextGenerationPipeline(Pipeline):
|
||||
model.tokenizer,
|
||||
first_sequence='sentence',
|
||||
second_sequence=None)
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.tokenizer = model.tokenizer
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from ...metainfo import Pipelines
|
||||
from ...models import Model
|
||||
from ...models.nlp import SbertForTokenClassification
|
||||
@@ -30,12 +30,18 @@ class WordSegmentationPipeline(Pipeline):
|
||||
SbertForTokenClassification) else Model.from_pretrained(model)
|
||||
if preprocessor is None:
|
||||
preprocessor = TokenClassifcationPreprocessor(model.model_dir)
|
||||
model.eval()
|
||||
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
|
||||
self.tokenizer = preprocessor.tokenizer
|
||||
self.config = model.config
|
||||
assert len(self.config.id2label) > 0
|
||||
self.id2label = self.config.id2label
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any], **postprocess_params) -> Dict[str, str]:
|
||||
"""process the prediction results
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
from scipy.special import softmax
|
||||
@@ -44,6 +44,7 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
if preprocessor is None:
|
||||
preprocessor = ZeroShotClassificationPreprocessor(
|
||||
sc_model.model_dir)
|
||||
model.eval()
|
||||
super().__init__(model=sc_model, preprocessor=preprocessor, **kwargs)
|
||||
|
||||
def _sanitize_parameters(self, **kwargs):
|
||||
@@ -62,6 +63,11 @@ class ZeroShotClassificationPipeline(Pipeline):
|
||||
postprocess_params['multi_label'] = kwargs.pop('multi_label', False)
|
||||
return preprocess_params, {}, postprocess_params
|
||||
|
||||
def forward(self, inputs: Dict[str, Any],
|
||||
**forward_params) -> Dict[str, Any]:
|
||||
with torch.no_grad():
|
||||
return super().forward(inputs, **forward_params)
|
||||
|
||||
def postprocess(self,
|
||||
inputs: Dict[str, Any],
|
||||
candidate_labels,
|
||||
|
||||
Reference in New Issue
Block a user