Funasr1.0 (#733)

* funasr1.0 model.generate

* funasr1.0 update

* funasr1.0
This commit is contained in:
zhifu gao
2024-01-23 19:28:58 +08:00
committed by GitHub
parent c3bb9e71cf
commit ade394d68c
2 changed files with 13 additions and 11 deletions

View File

@@ -179,16 +179,17 @@ class SegmentationClusteringPipeline(Pipeline):
if not hasattr(self, 'vad_pipeline'):
self.vad_pipeline = pipeline(
task=Tasks.voice_activity_detection,
model=self.config['vad_model'])
vad_time = self.vad_pipeline(audio, audio_fs=self.fs)
model=self.config['vad_model'],
model_revision='v2.0.2')
vad_time = self.vad_pipeline(
audio, fs=self.fs, is_final=True)[0]['value']
vad_segments = []
if isinstance(vad_time['text'], str):
vad_time_list = ast.literal_eval(vad_time['text'])
elif isinstance(vad_time['text'], list):
vad_time_list = vad_time['text']
if isinstance(vad_time, str):
vad_time_list = ast.literal_eval(vad_time)
elif isinstance(vad_time, list):
vad_time_list = vad_time
else:
raise ValueError('Incorrect vad result. Get %s' %
(type(vad_time['text'])))
raise ValueError('Incorrect vad result. Get %s' % (type(vad_time)))
for t in vad_time_list:
st = int(t[0]) / 1000
ed = int(t[1]) / 1000

View File

@@ -44,7 +44,7 @@ class Pipeline(ABC):
"""Pipeline base.
"""
def initiate_single_model(self, model):
def initiate_single_model(self, model, **kwargs):
if isinstance(model, str):
logger.info(f'initiate model from {model}')
if isinstance(model, str) and is_official_hub_path(model):
@@ -55,7 +55,8 @@ class Pipeline(ABC):
device=self.device_name,
model_prefetched=True,
invoked_by=Invoke.PIPELINE,
device_map=self.device_map) if is_model(model) else model
device_map=self.device_map,
**kwargs) if is_model(model) else model
else:
return model
@@ -96,7 +97,7 @@ class Pipeline(ABC):
self.device_name = device
if not isinstance(model, List):
self.model = self.initiate_single_model(model)
self.model = self.initiate_single_model(model, **kwargs)
self.models = [self.model]
else:
self.model = None