Skip to content

Commit

Permalink
clearer implementation of remove_non_speech with respect to audio wit…
Browse files Browse the repository at this point in the history
…hout speech
  • Loading branch information
Jeronymous committed Mar 11, 2024
1 parent bdee5d3 commit 9f903c7
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions whisper_timestamped/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
__author__ = "Jérôme Louradour"
__credits__ = ["Jérôme Louradour"]
__license__ = "GPLv3"
__version__ = "1.15.1"
__version__ = "1.15.2"

# Set some environment variables
import os
Expand Down Expand Up @@ -277,9 +277,9 @@ def transcribe_timestamped(
compression_ratio_threshold=compression_ratio_threshold,
)

if vad:
if vad is not None:
audio = get_audio_tensor(audio)
audio, vad_segments, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment)
audio, vad_segments, convert_timestamps = remove_non_speech(audio, method=vad, sample_rate=SAMPLE_RATE, plot=plot_word_alignment, avoid_empty_speech=True)
else:
vad_segments = None

Expand Down Expand Up @@ -1856,8 +1856,8 @@ def check_vad_method(method, with_version=False):
"""
if method in [True, "True", "true"]:
return check_vad_method("silero") # default method
elif method in [False, "False", "false"]:
return False
elif method in [None, False, "False", "false", "None", "none"]:
return None
elif not isinstance(method, str) and hasattr(method, '__iter__'):
# list of explicit timestamps
checked_pairs = []
Expand Down Expand Up @@ -2063,6 +2063,7 @@ def remove_non_speech(audio,
dilatation=0.5,
sample_rate=SAMPLE_RATE,
method="silero",
avoid_empty_speech=False,
plot=False,
):
"""
Expand All @@ -2083,6 +2084,8 @@ def remove_non_speech(audio,
how much (in sec) to enlarge each speech segment detected by the VAD
method: str
method to use to remove non-speech segments
avoid_empty_speech: bool
if True, avoid returning an empty speech segment (re)
plot: bool or str
if True, plot the result.
If a string, save the plot to the given file
Expand All @@ -2100,7 +2103,10 @@ def remove_non_speech(audio,

segments = [(seg["start"], seg["end"]) for seg in segments]
if len(segments) == 0:
segments = [(0, audio.shape[-1])]
if avoid_empty_speech:
segments = [(0, audio.shape[-1])]
else:
return torch.Tensor([]), [], lambda t, t2 = None: do_convert_timestamps(segments, t, t2)

audio_speech = torch.cat([audio[..., s:e] for s,e in segments], dim=-1)

Expand All @@ -2121,7 +2127,7 @@ def remove_non_speech(audio,
if not use_sample:
segments = [(float(s)/sample_rate, float(e)/sample_rate) for s,e in segments]

return audio_speech, segments, lambda t, t2 = None: do_convert_timestamps(segments, t, t2)
return audio_speech, segments, lambda t, t2 = None: t if t2 is None else [t, t2]

def do_convert_timestamps(segments, t, t2 = None):
"""
Expand Down

0 comments on commit 9f903c7

Please sign in to comment.