Mirror of
https://github.com/modelscope/modelscope.git
(last synced 2025-12-16 16:27:45 +01:00)
1. Add a device utility to verify, create, and place devices.
2. Update pipeline and trainer support for device placement.
3. Fix pipelines that use TensorFlow models not placing the model on the correct device.
Usage:
```python
pipe = pipeline('damo/xxx', device='cpu')
pipe = pipeline('damo/xxx', device='gpu')
pipe = pipeline('damo/xxx', device='gpu:0')
pipe = pipeline('damo/xxx', device='gpu:2')
pipe = pipeline('damo/xxx', device='cuda')
pipe = pipeline('damo/xxx', device='cuda:1')
```
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9800672
48 lines
1.6 KiB
Python
48 lines
1.6 KiB
Python
import os.path
|
|
import unittest
|
|
|
|
from modelscope.fileio import File
|
|
from modelscope.pipelines import pipeline
|
|
from modelscope.utils.constant import Tasks
|
|
from modelscope.utils.test_utils import test_level
|
|
|
|
# Audio clip fed to the keyword-spotting pipeline in every test. The tests join
# it with os.getcwd(), so they presumably must be run from the repository root.
TEST_SPEECH_FILE = 'data/test/audios/3ch_nihaomiya.wav'
|
|
|
|
|
|
class KWSFarfieldTest(unittest.TestCase):
    """Tests for the far-field keyword-spotting pipeline.

    Each test builds a ``Tasks.keyword_spotting`` pipeline from
    ``damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya`` and runs it on the
    clip at ``TEST_SPEECH_FILE``. It then checks that ``result['kws_list']``
    has five entries.
    """

    def setUp(self) -> None:
        self.model_id = 'damo/speech_dfsmn_kws_char_farfield_16k_nihaomiya'
        # Resolve the test clip once rather than rebuilding the same path in
        # every test; it is relative to the current working directory.
        self.audio_path = os.path.join(os.getcwd(), TEST_SPEECH_FILE)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_normal(self):
        """Run the pipeline on a dict input that names the audio file."""
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        inputs = {'input_file': self.audio_path}
        result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_output(self):
        """Run the pipeline with an ``output_file`` entry in the input dict.

        The output path points into a temporary directory. Previously the test
        left an ``output.wav`` behind in the working directory.
        """
        import tempfile

        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        with tempfile.TemporaryDirectory() as tmp_dir:
            inputs = {
                'input_file': self.audio_path,
                'output_file': os.path.join(tmp_dir, 'output.wav'),
            }
            result = kws(inputs)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_input_bytes(self):
        """Run the pipeline on the raw bytes of the audio file."""
        with open(self.audio_path, 'rb') as f:
            data = f.read()
        kws = pipeline(Tasks.keyword_spotting, model=self.model_id)
        result = kws(data)
        self.assertEqual(len(result['kws_list']), 5)
        print(result['kws_list'][-1])
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this module directly with the unittest test runner.
    unittest.main()
|