mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
bug fixed in token classification postprocessor
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10763022 * [fix sequence labeling postprocess bug]
This commit is contained in:
committed by
yingda.chen
parent
90a5efa1c2
commit
ed9d2b5436
@@ -92,6 +92,8 @@ class NamedEntityRecognitionPipeline(Pipeline):
|
||||
offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
|
||||
|
||||
labels = [self.id2label[x] for x in predictions]
|
||||
if len(labels) > len(offset_mapping):
|
||||
labels = labels[1:-1]
|
||||
chunks = []
|
||||
chunk = {}
|
||||
for label, offsets in zip(labels, offset_mapping):
|
||||
@@ -104,6 +106,20 @@ class NamedEntityRecognitionPipeline(Pipeline):
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'I':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'E':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'IES':
|
||||
if chunk:
|
||||
chunk['end'] = offsets[1]
|
||||
@@ -118,15 +134,15 @@ class NamedEntityRecognitionPipeline(Pipeline):
|
||||
chunk['span'] = text[chunk['start']:chunk['end']]
|
||||
chunks.append(chunk)
|
||||
|
||||
# for cws output
|
||||
# for cws outputs
|
||||
if len(chunks) > 0 and chunks[0]['type'] == 'cws':
|
||||
spans = [
|
||||
chunk['span'] for chunk in chunks if chunk['span'].strip()
|
||||
]
|
||||
seg_result = ' '.join(spans)
|
||||
outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
|
||||
outputs = {OutputKeys.OUTPUT: seg_result}
|
||||
|
||||
# for ner outpus
|
||||
# for ner outputs
|
||||
else:
|
||||
outputs = {OutputKeys.OUTPUT: chunks}
|
||||
return outputs
|
||||
|
||||
@@ -95,6 +95,20 @@ class TokenClassificationPipeline(Pipeline):
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'I':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'E':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'IES':
|
||||
if chunk:
|
||||
chunk['end'] = offsets[1]
|
||||
|
||||
@@ -80,9 +80,12 @@ class WordSegmentationPipeline(Pipeline):
|
||||
Dict[str, str]: the prediction results
|
||||
"""
|
||||
text = inputs['text']
|
||||
logits = inputs[OutputKeys.LOGITS]
|
||||
predictions = torch.argmax(logits[0], dim=-1)
|
||||
logits = torch_nested_numpify(torch_nested_detach(logits))
|
||||
if not hasattr(inputs, 'predictions'):
|
||||
logits = inputs[OutputKeys.LOGITS]
|
||||
predictions = torch.argmax(logits[0], dim=-1)
|
||||
else:
|
||||
predictions = inputs[OutputKeys.PREDICTIONS].squeeze(
|
||||
0).cpu().numpy()
|
||||
predictions = torch_nested_numpify(torch_nested_detach(predictions))
|
||||
offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
|
||||
|
||||
@@ -101,6 +104,20 @@ class WordSegmentationPipeline(Pipeline):
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'I':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'E':
|
||||
if not chunk:
|
||||
chunk = {
|
||||
'type': label[2:],
|
||||
'start': offsets[0],
|
||||
'end': offsets[1]
|
||||
}
|
||||
if label[0] in 'IES':
|
||||
if chunk:
|
||||
chunk['end'] = offsets[1]
|
||||
@@ -123,7 +140,7 @@ class WordSegmentationPipeline(Pipeline):
|
||||
seg_result = ' '.join(spans)
|
||||
outputs = {OutputKeys.OUTPUT: seg_result}
|
||||
|
||||
# for ner output
|
||||
# for ner outputs
|
||||
else:
|
||||
outputs = {OutputKeys.OUTPUT: chunks}
|
||||
return outputs
|
||||
|
||||
Reference in New Issue
Block a user