fix torch 2.x compatible issue

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13045011

* fix torch 2.x compatible issue

* fix torch 2.x compatible issue

* fix complex-valued input tensor matching the output from stft with return_complex=True.

* skip plugin test temporarily for modify torch version

* fix test_speech_signal_process.py compatible issue

* fix lint issue

* upgrade funasr to 0.6.5
This commit is contained in:
mulin.lyh
2023-06-27 14:40:51 +08:00
committed by wenmeng.zwm
parent a58be34384
commit eb0f0216c6
5 changed files with 15 additions and 12 deletions

View File

@@ -54,7 +54,7 @@ def multiclass_nms(multi_bboxes,
if score_factors is not None:
scores = scores * score_factors[:, None]
labels = torch.arange(num_classes, dtype=torch.long)
labels = torch.arange(num_classes, dtype=torch.long, device=scores.device)
labels = labels.view(1, -1).expand_as(scores)
bboxes = bboxes.reshape(-1, 4)

View File

@@ -70,7 +70,8 @@ class ANSDFSMNPipeline(Pipeline):
HOP_LENGTH,
STFT_WIN_LEN,
center=False,
window=window)
window=window,
return_complex=False)
def istft(x, slen):
return librosa.istft(

View File

@@ -82,18 +82,19 @@ class LinearAECPipeline(Pipeline):
window = torch.hamming_window(winlen, periodic=False)
def stft(x):
return torch.stft(
x,
n_fft,
hop_length,
winlen,
center=False,
window=window.to(x.device),
return_complex=False)
return torch.view_as_real(
torch.stft(
x,
n_fft,
hop_length,
winlen,
center=False,
window=window.to(x.device),
return_complex=True))
def istft(x, slen):
return torch.istft(
x,
torch.view_as_complex(x),
n_fft,
hop_length,
winlen,

View File

@@ -1 +1 @@
funasr>=0.6.0
funasr>=0.6.5

View File

@@ -4,6 +4,7 @@ import unittest
from modelscope.utils.plugins import PluginsManager
@unittest.skipUnless(False, reason='For it modify torch version')
class PluginsCMDTest(unittest.TestCase):
def setUp(self):