# asr.py
import re

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly
from silero_vad import load_silero_vad, VADIterator
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer

from utils import s2tw_converter

SAMPLING_RATE = 16000  # Silero VAD and Moonshine both expect 16 kHz audio
CHUNK_SIZE = 512       # samples per VAD chunk (32 ms at 16 kHz)

tokenizer = load_tokenizer()


def clean_transcript(text):
    """Remove decoding artifacts and repeated CJK characters from a transcript."""
    # Drop Unicode replacement characters (U+FFFD) left by decoding errors.
    text = re.sub(r'\uFFFD', '', text)
    # Collapse a CJK character repeated three or more times into a single one.
    text = re.sub(r'([\u4e00-\u9fa5])\1{2,}', r'\1', text)
    # Remove stray spaces between CJK characters.
    text = re.sub(r'([\u4e00-\u9fa5]) ([ \u4e00-\u9fa5])', r'\1\2', text)
    return text
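
# Illustrative example (not in the original source): clean_transcript("好好好 嗎")
# returns "好嗎": the tripled character collapses to one and the stray space
# between CJK characters is removed.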


def transcribe_file(audio_path, vad_threshold, model_name):
    """Transcribe an audio file, yielding utterances as they are detected."""
    vad_model = load_silero_vad(onnx=True)
    vad_iterator = VADIterator(model=vad_model, sampling_rate=SAMPLING_RATE, threshold=vad_threshold)
    model = MoonshineOnnxModel(model_name=f"moonshine/{model_name}")

    wav, orig_sr = sf.read(audio_path)
    # Resample to 16 kHz if needed, using a rational polyphase filter.
    if orig_sr != SAMPLING_RATE:
        gcd = np.gcd(int(orig_sr), SAMPLING_RATE)
        up = SAMPLING_RATE // gcd
        down = int(orig_sr) // gcd
        wav = resample_poly(wav, up, down)
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(axis=1)
    utterances = []  # All utterances so far, as (start, end, text) tuples
    speech_buffer = np.array([], dtype=np.float32)
    segment_start = 0.0  # Start time (s) of the current segment
    i = 0
    while i < len(wav):
        chunk = wav[i:i + CHUNK_SIZE]
        # Pad the final chunk so the VAD always sees a full CHUNK_SIZE window.
        if len(chunk) < CHUNK_SIZE:
            chunk = np.pad(chunk, (0, CHUNK_SIZE - len(chunk)), mode='constant')
        i += CHUNK_SIZE
        speech_dict = vad_iterator(chunk)
        speech_buffer = np.concatenate([speech_buffer, chunk])
        if speech_dict and "end" in speech_dict:
            # End of a speech segment: transcribe everything buffered since the
            # previous segment ended (may include leading silence).
            segment_end = i / SAMPLING_RATE
            tokens = model.generate(speech_buffer[np.newaxis, :].astype(np.float32))
            text = tokenizer.decode_batch(tokens)[0].strip()
            if text:
                cleaned_text = clean_transcript(s2tw_converter.convert(text))
                utterances.append((segment_start, segment_end, cleaned_text))
                # Yield the current utterance plus all accumulated utterances.
                yield utterances[-1], utterances.copy()
            # Reset for the next segment.
            speech_buffer = np.array([], dtype=np.float32)
            segment_start = i / SAMPLING_RATE
            vad_iterator.reset_states()

    # Transcribe any trailing audio longer than half a second.
    if len(speech_buffer) > SAMPLING_RATE * 0.5:
        segment_end = len(wav) / SAMPLING_RATE
        tokens = model.generate(speech_buffer[np.newaxis, :].astype(np.float32))
        text = tokenizer.decode_batch(tokens)[0].strip()
        if text:
            cleaned_text = clean_transcript(s2tw_converter.convert(text))
            utterances.append((segment_start, segment_end, cleaned_text))
            yield utterances[-1], utterances.copy()
    # Final yield with the complete utterance list.
    if utterances:
        yield None, utterances
    else:
        yield None, [(-1, -1, "No speech detected")]
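

# Minimal usage sketch (illustrative, not part of the original module). The audio
# path, VAD threshold, and model name below are placeholder assumptions; the model
# name is whatever the f"moonshine/{model_name}" lookup above accepts (e.g. "base").
if __name__ == "__main__":
    for latest, all_so_far in transcribe_file("example.wav", vad_threshold=0.5, model_name="base"):
        if latest is not None:
            start, end, text = latest
            print(f"[{start:7.2f}s - {end:7.2f}s] {text}")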