mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
1. 修复 token classification preprocessor finetune 结果错误问题
2. 修复 word segmentation output 无用属性
3. 修复 nlp preprocessor 传 use_fast 错误
4. 修复 torch model exporter bug
5. 修复文档撰写过程中发现的 trainer 相关 bug
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10573269
32 lines
906 B
Python
32 lines
906 B
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from modelscope.outputs import TextClassificationModelOutput
|
|
from modelscope.utils.test_utils import test_level
|
|
|
|
|
|
class TestModelOutput(unittest.TestCase):
    """Exercise the access patterns of ``TextClassificationModelOutput``.

    The same field is read by string key, by integer index and by attribute.
    The test then checks that a field assigned after construction shows up
    when the output object is unpacked.
    """

    def setUp(self):
        pass

    def tearDown(self):
        super().tearDown()

    def _assert_tensor_equal(self, actual, expected):
        """Assert that ``actual`` and ``expected`` have the same shape and values.

        ``assertEqual`` on tensors falls back to ``not first == second``. That
        evaluates the truth value of an element-wise comparison tensor, which
        only works for single-element tensors and raises ``RuntimeError`` for
        larger ones. ``torch.equal`` compares shape and every element and
        returns a plain ``bool``.
        """
        self.assertTrue(
            torch.equal(actual, expected), f'{actual!r} != {expected!r}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_model_outputs(self):
        expected_logits = torch.Tensor([1])
        expected_loss = torch.Tensor([2])

        outputs = TextClassificationModelOutput(logits=torch.Tensor([1]))
        # The same field must be reachable by key, by position and by attribute.
        self._assert_tensor_equal(outputs['logits'], expected_logits)
        self._assert_tensor_equal(outputs[0], expected_logits)
        self._assert_tensor_equal(outputs.logits, expected_logits)

        # A field assigned after construction is expected to join the
        # sequence view, so the output now unpacks into exactly two values.
        outputs.loss = torch.Tensor([2])
        logits, loss = outputs
        self._assert_tensor_equal(logits, expected_logits)
        self.assertIsNotNone(loss)
        # Check the unpacked value really is the assigned loss, not merely
        # something non-None.
        self._assert_tensor_equal(loss, expected_loss)
|
|
if __name__ == '__main__':
    # Run this module's tests when the file is executed directly.
    unittest.main()
|