Bug fix in the token classification postprocessor

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10763022

    * [fix sequence labeling postprocess bug]
Author: zhangzhicheng.zzc
Date: 2022-11-17 18:11:04 +08:00
Committed by: yingda.chen
Parent: 90a5efa1c2
Commit: ed9d2b5436
3 changed files with 54 additions and 7 deletions


@@ -92,6 +92,8 @@ class NamedEntityRecognitionPipeline(Pipeline):
         offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
         labels = [self.id2label[x] for x in predictions]
+        if len(labels) > len(offset_mapping):
+            labels = labels[1:-1]
         chunks = []
         chunk = {}
         for label, offsets in zip(labels, offset_mapping):
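
Reviewer note: the two added lines guard against the label sequence being longer than the offset mapping, which typically happens when the predictions still cover special tokens such as [CLS] and [SEP]. A minimal sketch of the idea (the label values and offsets below are made up for illustration):

    # Hypothetical example: 4 real tokens, but predictions also cover [CLS] ... [SEP].
    labels = ['O', 'B-ORG', 'E-ORG', 'O', 'O', 'O']     # len 6, includes specials
    offset_mapping = [(0, 2), (2, 4), (4, 5), (5, 6)]   # len 4, real tokens only

    # Same guard as in the diff: drop the first and last label so that
    # labels and offsets line up one-to-one.
    if len(labels) > len(offset_mapping):
        labels = labels[1:-1]

    assert len(labels) == len(offset_mapping)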
@@ -104,6 +106,20 @@ class NamedEntityRecognitionPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
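
Reviewer note: the new `I`/`E` branches open a fresh chunk when an entity begins mid-sequence with no preceding `B-`/`S-` label (for example when the first label of a span was truncated or mispredicted); previously such tokens were silently dropped because `chunk` was still empty. A standalone sketch of the decoding loop with this fallback (simplified, not the pipeline's exact code):

    def decode_bioes(text, labels, offsets):
        """Group BIOES labels into chunks; tolerate spans that start with I-/E-."""
        chunks, chunk = [], {}
        for label, (start, end) in zip(labels, offsets):
            if label[0] in 'BS':
                if chunk:
                    chunk['span'] = text[chunk['start']:chunk['end']]
                    chunks.append(chunk)
                chunk = {'type': label[2:], 'start': start, 'end': end}
            if label[0] in 'IE' and not chunk:
                # Fallback added by this commit: open a chunk even without a B-.
                chunk = {'type': label[2:], 'start': start, 'end': end}
            if label[0] in 'IES' and chunk:
                chunk['end'] = end
            if label[0] in 'ES' and chunk:
                chunk['span'] = text[chunk['start']:chunk['end']]
                chunks.append(chunk)
                chunk = {}
        if chunk:
            chunk['span'] = text[chunk['start']:chunk['end']]
            chunks.append(chunk)
        return chunks

    # 'Hangzhou' predicted as I-LOC / E-LOC with no B-LOC in front:
    print(decode_bioes('Hangzhou is nice',
                       ['I-LOC', 'E-LOC', 'O', 'O'],
                       [(0, 4), (4, 8), (9, 11), (12, 16)]))
    # -> [{'type': 'LOC', 'start': 0, 'end': 8, 'span': 'Hangzhou'}]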
@@ -118,15 +134,15 @@ class NamedEntityRecognitionPipeline(Pipeline):
             chunk['span'] = text[chunk['start']:chunk['end']]
             chunks.append(chunk)
-        # for cws output
+        # for cws outputs
         if len(chunks) > 0 and chunks[0]['type'] == 'cws':
             spans = [
                 chunk['span'] for chunk in chunks if chunk['span'].strip()
             ]
             seg_result = ' '.join(spans)
-            outputs = {OutputKeys.OUTPUT: seg_result, OutputKeys.LABELS: []}
+            outputs = {OutputKeys.OUTPUT: seg_result}
-        # for ner outpus
+        # for ner outputs
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
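
Reviewer note: for word segmentation (`cws`) chunks, the result is now just the space-joined spans and the empty `OutputKeys.LABELS` entry is dropped. A rough sketch of the two output shapes (using the plain string 'output' as a stand-in for `OutputKeys.OUTPUT`, and made-up chunks):

    chunks = [{'type': 'cws', 'start': 0, 'end': 2, 'span': '今天'},
              {'type': 'cws', 'start': 2, 'end': 4, 'span': '天气'}]

    if chunks and chunks[0]['type'] == 'cws':
        # word segmentation: a single space-separated string
        outputs = {'output': ' '.join(c['span'] for c in chunks if c['span'].strip())}
    else:
        # NER / generic token classification: the chunk dicts themselves
        outputs = {'output': chunks}

    print(outputs)   # {'output': '今天 天气'}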


@@ -95,6 +95,20 @@ class TokenClassificationPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
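
Reviewer note: the same fallback is mirrored here for the generic token classification pipeline. For context, `label[2:]` relies on the `X-TYPE` naming convention of the model's `id2label` mapping; a hypothetical mapping and the split it implies (not taken from any specific model config):

    # Hypothetical id2label for a BIOES NER head.
    id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'E-PER', 4: 'S-PER'}

    predictions = [1, 2, 3]                      # token-level argmax ids
    labels = [id2label[i] for i in predictions]  # ['B-PER', 'I-PER', 'E-PER']

    for label in labels:
        scheme_tag, entity_type = label[0], label[2:]   # 'B'/'I'/'E'/'S' and 'PER'
        print(scheme_tag, entity_type)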


@@ -80,9 +80,12 @@ class WordSegmentationPipeline(Pipeline):
             Dict[str, str]: the prediction results
         """
         text = inputs['text']
-        logits = inputs[OutputKeys.LOGITS]
-        predictions = torch.argmax(logits[0], dim=-1)
-        logits = torch_nested_numpify(torch_nested_detach(logits))
+        if not hasattr(inputs, 'predictions'):
+            logits = inputs[OutputKeys.LOGITS]
+            predictions = torch.argmax(logits[0], dim=-1)
+        else:
+            predictions = inputs[OutputKeys.PREDICTIONS].squeeze(
+                0).cpu().numpy()
         predictions = torch_nested_numpify(torch_nested_detach(predictions))
         offset_mapping = [x.cpu().tolist() for x in inputs['offset_mapping']]
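
Reviewer note: this hunk lets the word segmentation pipeline accept model outputs that already carry token-level predictions instead of raw logits. A minimal sketch of the two paths, assuming a torch tensor of logits shaped [batch, seq_len, num_labels] (the helper name and dict keys are illustrative, not the pipeline's own attributes):

    import torch

    def get_predictions(outputs):
        # If the model already produced predictions, use them directly;
        # otherwise reduce the logits with an argmax over the label dimension.
        if 'predictions' in outputs:
            return outputs['predictions'].squeeze(0).cpu().numpy()
        logits = outputs['logits']                    # [batch, seq_len, num_labels]
        return torch.argmax(logits[0], dim=-1).cpu().numpy()

    outs = {'logits': torch.randn(1, 5, 3)}
    print(get_predictions(outs))                      # 5 label ids, one per token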
@@ -101,6 +104,20 @@ class WordSegmentationPipeline(Pipeline):
                     'start': offsets[0],
                     'end': offsets[1]
                 }
+            if label[0] in 'I':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
+            if label[0] in 'E':
+                if not chunk:
+                    chunk = {
+                        'type': label[2:],
+                        'start': offsets[0],
+                        'end': offsets[1]
+                    }
             if label[0] in 'IES':
                 if chunk:
                     chunk['end'] = offsets[1]
@@ -123,7 +140,7 @@ class WordSegmentationPipeline(Pipeline):
             seg_result = ' '.join(spans)
             outputs = {OutputKeys.OUTPUT: seg_result}
-        # for ner output
+        # for ner outputs
         else:
             outputs = {OutputKeys.OUTPUT: chunks}
         return outputs
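
Reviewer note: end-to-end, the affected pipelines are the ones built for `Tasks.named_entity_recognition`, `Tasks.token_classification` and `Tasks.word_segmentation`. A hedged usage sketch (the model ids below are placeholders, not models referenced by this change; output shapes follow the diff above, assuming `OutputKeys.OUTPUT` maps to the 'output' key):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    # Placeholder model id; substitute any ModelScope word segmentation model.
    ws = pipeline(Tasks.word_segmentation, model='<your-cws-model-id>')
    print(ws('今天天气不错'))
    # expected shape after this fix: {'output': '今天 天气 不错'}

    ner = pipeline(Tasks.named_entity_recognition, model='<your-ner-model-id>')
    print(ner('阿里巴巴位于杭州'))
    # expected shape: {'output': [{'type': ..., 'start': ..., 'end': ..., 'span': ...}, ...]}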