mirror of
https://github.com/coqui-ai/TTS.git
synced 2025-12-24 20:29:30 +01:00
* refactor(punctuation): remove orphan code for handling lone punctuation
The case of lone punctuation is already handled at the top of restore(). The
removed if statement could therefore never be reached; had it been reached, it
would have raised an AttributeError, because the _punc_index named tuple doesn't
have the attribute `mark`.
* refactor(punctuation): remove unused argument
* fix(punctuation): correctly handle initial punctuation
Stripping and restoring initial punctuation didn't work correctly because the
string-splitting caused an additional empty string to be inserted in the text
list (because `".A".split(".")` => `["", "A"]`). Now, an initial empty string is
skipped and relevant test cases are added.
Fixes #3333
39 lines
1.6 KiB
Python
39 lines
1.6 KiB
Python
import unittest
|
|
|
|
from TTS.tts.utils.text.punctuation import _DEF_PUNCS, Punctuation
|
|
|
|
|
|
class PunctuationTest(unittest.TestCase):
    """Unit tests for the ``Punctuation`` text helper.

    The fixtures in ``self.test_texts`` pair a raw input string with the text
    that is expected once its punctuation has been stripped. The data-driven
    tests wrap every fixture in ``subTest`` so that a failure names the input
    that triggered it and the remaining fixtures are still exercised.
    """

    def setUp(self):
        """Create a default ``Punctuation`` instance and the shared ``(raw, stripped)`` fixtures."""
        self.punctuation = Punctuation()
        self.test_texts = [
            # Punctuation inside and/or at the end of the text.
            ("This, is my text ... to be striped !! from text?", "This is my text to be striped from text"),
            ("This, is my text ... to be striped !! from text", "This is my text to be striped from text"),
            ("This, is my text ... to be striped from text?", "This is my text to be striped from text"),
            ("This, is my text to be striped from text", "This is my text to be striped from text"),
            # Lone punctuation: nothing is left after stripping.
            (".", ""),
            (" . ", ""),
            # Text that starts with punctuation.
            ("!!! Attention !!!", "Attention"),
            ("!!! Attention !!! This is just a ... test.", "Attention This is just a test"),
            ("!!! Attention! This is just a ... test.", "Attention This is just a test"),
        ]

    def test_get_set_puncs(self):
        """The ``puncs`` attribute reads back whatever was last assigned to it."""
        # Assign a custom set first, then put the defaults back (same order as before).
        for puncs in ("-=", _DEF_PUNCS):
            with self.subTest(puncs=puncs):
                self.punctuation.puncs = puncs
                self.assertEqual(self.punctuation.puncs, puncs)

    def test_strip_punc(self):
        """``strip`` returns the expected punctuation-free text for every fixture."""
        for text, expected in self.test_texts:
            with self.subTest(text=text):
                self.assertEqual(self.punctuation.strip(text), expected)

    def test_strip_restore(self):
        """``strip_to_restore`` followed by ``restore`` reproduces the original input."""
        for text, expected in self.test_texts:
            with self.subTest(text=text):
                stripped_parts, puncs_map = self.punctuation.strip_to_restore(text)
                restored = self.punctuation.restore(stripped_parts, puncs_map)
                # The stripped pieces, joined with single spaces, equal the plain stripped text.
                self.assertEqual(" ".join(stripped_parts), expected)
                # The first element of the restored result is compared against the raw input.
                self.assertEqual(restored[0], text)
|