import gradio as gr
import torch
import tempfile
import os
import time
import numpy as np
import warnings
# Suppress expected warnings from transformers/whisper
warnings.filterwarnings("ignore", message=".*deprecated.*")
warnings.filterwarnings("ignore", message=".*Whisper did not predict.*")
warnings.filterwarnings("ignore", message=".*pipelines sequentially.*")
# Optional imports - fail gracefully if the heavy dependencies are missing
try:
from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
except ImportError:
pipeline = None
ffmpeg_read = None
try:
import yt_dlp as youtube_dl
except ImportError:
youtube_dl = None
# Model configuration
MODEL_ID = "openai/whisper-large-v3-turbo"
BATCH_SIZE = 8
CHUNK_LENGTH_S = 30 # Split long audio into 30-second chunks
SAMPLE_RATE = 16000 # Whisper expects 16kHz audio
YT_LENGTH_LIMIT_S = 3600 # Limit YouTube videos to 1 hour
# Detect if running on Hugging Face Spaces (YouTube won't work there due to network restrictions)
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
# Lazy load state for the Whisper model
_WHISPER_STATE = {"initialized": False, "pipe": None, "device": "cpu"}
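# After _init_whisper() succeeds, this holds roughly (illustrative):
#   {"initialized": True, "pipe": <transformers ASR pipeline>, "device": "cuda:0"}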
# Supported languages for the dropdown
LANGUAGES = [
("Auto-detect", "auto"),
("English", "english"),
("Spanish", "spanish"),
("French", "french"),
("German", "german"),
("Italian", "italian"),
("Portuguese", "portuguese"),
("Dutch", "dutch"),
("Russian", "russian"),
("Chinese", "chinese"),
("Japanese", "japanese"),
("Korean", "korean"),
("Arabic", "arabic"),
("Hindi", "hindi"),
]
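# Whisper's generate() accepts full lowercase language names (as listed above) or
# ISO codes, passed through generate_kwargs, e.g. (illustrative):
#   pipe(audio, generate_kwargs={"task": "transcribe", "language": "french"})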
def _init_whisper() -> None:
"""Initialize the Whisper model lazily on first use."""
if _WHISPER_STATE["initialized"]:
return
if pipeline is None:
raise gr.Error(
"Transformers library not properly installed. "
"Please run: pip install transformers>=4.45.0"
)
# Detect device
device = 0 if torch.cuda.is_available() else "cpu"
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Initializing Whisper model on device: {device_name}")
try:
# Create the pipeline WITHOUT chunk_length_s - we'll chunk manually for streaming
pipe = pipeline(
task="automatic-speech-recognition",
model=MODEL_ID,
device=device,
)
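        # Optional: on CUDA you could also pass torch_dtype=torch.float16 to the
        # pipeline call above to roughly halve GPU memory use (assumption: fp16
        # accuracy is acceptable; the default float32 path is kept here).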
_WHISPER_STATE.update({
"initialized": True,
"pipe": pipe,
"device": device_name,
})
print("Whisper model initialized successfully.")
except Exception as e:
raise gr.Error(f"Failed to initialize Whisper model: {str(e)[:200]}")
def get_device_info() -> str:
"""Get the current device being used for inference."""
if _WHISPER_STATE["initialized"]:
return _WHISPER_STATE["device"]
return "cuda:0" if torch.cuda.is_available() else "cpu"
def _load_audio(audio_path: str) -> np.ndarray:
"""Load audio file and convert to numpy array at 16kHz."""
if ffmpeg_read is None:
raise gr.Error("transformers not properly installed.")
with open(audio_path, "rb") as f:
audio_bytes = f.read()
# ffmpeg_read returns audio as float32 numpy array at the specified sample rate
audio = ffmpeg_read(audio_bytes, SAMPLE_RATE)
return audio
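# Illustrative: a 10-second stereo MP3 comes back as a mono float32 array of
# 160,000 samples (10 s x 16,000 Hz). Note that ffmpeg_read shells out to the
# ffmpeg binary, which must be available on PATH.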
def _chunk_audio(audio: np.ndarray, chunk_length_s: int = CHUNK_LENGTH_S) -> list[dict]:
"""Split audio array into chunks for streaming processing."""
chunk_length_samples = chunk_length_s * SAMPLE_RATE
total_samples = len(audio)
chunks = []
for start in range(0, total_samples, chunk_length_samples):
end = min(start + chunk_length_samples, total_samples)
chunk_audio = audio[start:end]
# Calculate time offset for this chunk
start_time = start / SAMPLE_RATE
chunks.append({
"array": chunk_audio,
"sampling_rate": SAMPLE_RATE,
"start_time": start_time,
})
return chunks
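# Worked example: a 75 s clip at 16 kHz (1,200,000 samples) yields three chunks
# of 480,000, 480,000, and 240,000 samples with start_time 0.0, 30.0, and 60.0.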
def transcribe_audio_streaming(
audio_path: str,
task: str,
language: str,
return_timestamps: bool,
):
"""
Transcribe audio with streaming output - yields results chunk by chunk.
Args:
audio_path: Path to the audio file
task: 'transcribe' or 'translate'
language: Language code or 'auto'
return_timestamps: Whether to include timestamps
Yields:
Accumulated transcription text after each chunk
"""
if not audio_path:
raise gr.Error("Please provide an audio file to transcribe.")
# Initialize model on first use
_init_whisper()
pipe = _WHISPER_STATE["pipe"]
# Build generate kwargs
generate_kwargs = {"task": task}
if language != "auto":
generate_kwargs["language"] = language
try:
# Load and chunk the audio
audio = _load_audio(audio_path)
chunks = _chunk_audio(audio)
# If only one chunk, no need for streaming
if len(chunks) == 1:
result = pipe(
{"array": audio, "sampling_rate": SAMPLE_RATE},
batch_size=BATCH_SIZE,
generate_kwargs=generate_kwargs,
return_timestamps=return_timestamps,
)
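            # With return_timestamps=True the pipeline returns roughly (illustrative):
            #   {"text": " Hello world ...",
            #    "chunks": [{"timestamp": (0.0, 4.2), "text": " Hello world"}, ...]}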
if return_timestamps and "chunks" in result and result["chunks"]:
lines = []
for chunk in result["chunks"]:
start = chunk.get("timestamp", (0, 0))[0] or 0
end = chunk.get("timestamp", (0, 0))[1] or 0
text = chunk.get("text", "").strip()
lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
yield "\n".join(lines)
else:
yield result.get("text", "")
return
# Process chunks and stream results
accumulated_text = ""
accumulated_lines = []
for i, chunk_data in enumerate(chunks):
chunk_start_time = chunk_data["start_time"]
# Process this chunk
result = pipe(
{"array": chunk_data["array"], "sampling_rate": chunk_data["sampling_rate"]},
batch_size=BATCH_SIZE,
generate_kwargs=generate_kwargs,
return_timestamps=return_timestamps,
)
if return_timestamps and "chunks" in result and result["chunks"]:
for ts_chunk in result["chunks"]:
                    # Adjust timestamps to account for the chunk offset; a truncated
                    # final segment may carry (start, None), so fall back to its start.
                    ts = ts_chunk.get("timestamp", (0, 0))
                    start = (ts[0] or 0) + chunk_start_time
                    end = (ts[1] or ts[0] or 0) + chunk_start_time
text = ts_chunk.get("text", "").strip()
accumulated_lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
yield "\n".join(accumulated_lines)
else:
chunk_text = result.get("text", "").strip()
if chunk_text:
if accumulated_text:
accumulated_text += " " + chunk_text
else:
accumulated_text = chunk_text
yield accumulated_text
except gr.Error:
raise
except Exception as e:
raise gr.Error(f"Transcription failed: {str(e)[:200]}")
def _get_yt_html_embed(yt_url: str) -> str:
"""Generate YouTube embed HTML for display."""
    # Handle both youtu.be short links and standard watch URLs.
    if "youtu.be/" in yt_url:
        video_id = yt_url.split("youtu.be/")[-1].split("?")[0]
    else:
        video_id = yt_url.split("?v=")[-1].split("&")[0]
return (
f'<center><iframe width="500" height="320" '
f'src="https://www.youtube.com/embed/{video_id}"></iframe></center>'
)
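# Example: "https://www.youtube.com/watch?v=abc123&t=10" embeds
# "https://www.youtube.com/embed/abc123".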
def _download_yt_audio(yt_url: str, filepath: str) -> None:
"""Download audio from a YouTube URL."""
if youtube_dl is None:
raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")
info_loader = youtube_dl.YoutubeDL()
try:
info = info_loader.extract_info(yt_url, download=False)
except youtube_dl.utils.DownloadError as err:
# Check if this is a network/DNS error (common on HF Spaces)
err_str = str(err)
if "Failed to resolve" in err_str or "No address associated" in err_str:
raise gr.Error(
"YouTube download failed due to network restrictions. "
"This feature requires running the app locally. "
"On Hugging Face Spaces, outbound connections to YouTube are blocked."
)
raise gr.Error(str(err))
# Parse duration
file_length = info.get("duration_string", "0")
file_h_m_s = file_length.split(":")
file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
if len(file_h_m_s) == 1:
file_h_m_s.insert(0, 0)
if len(file_h_m_s) == 2:
file_h_m_s.insert(0, 0)
file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
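    # e.g. "1:02:03" -> [1, 2, 3] -> 1*3600 + 2*60 + 3 = 3723 seconds.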
if file_length_s > YT_LENGTH_LIMIT_S:
yt_limit_hms = time.strftime("%H:%M:%S", time.gmtime(YT_LENGTH_LIMIT_S))
file_hms = time.strftime("%H:%M:%S", time.gmtime(file_length_s))
raise gr.Error(f"Maximum YouTube length is {yt_limit_hms}, got {file_hms}.")
ydl_opts = {
"outtmpl": filepath,
"format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
}
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
try:
ydl.download([yt_url])
        except youtube_dl.utils.DownloadError as err:
            # yt-dlp wraps extractor failures in DownloadError at download time.
            raise gr.Error(str(err))
def transcribe_youtube_streaming(
yt_url: str,
task: str,
language: str,
return_timestamps: bool,
):
"""
Transcribe a YouTube video with streaming output.
Yields tuples of (html_embed, accumulated_text).
"""
if not yt_url:
raise gr.Error("Please provide a YouTube URL.")
if youtube_dl is None:
raise gr.Error("yt-dlp not installed. Please run: pip install yt-dlp")
if ffmpeg_read is None:
raise gr.Error("transformers not properly installed.")
html_embed = _get_yt_html_embed(yt_url)
# Initialize model
_init_whisper()
pipe = _WHISPER_STATE["pipe"]
# Download video to temp directory
with tempfile.TemporaryDirectory() as tmpdir:
filepath = os.path.join(tmpdir, "video.mp4")
# Yield initial state while downloading
yield html_embed, "Downloading video..."
_download_yt_audio(yt_url, filepath)
yield html_embed, "Processing audio..."
# Load audio
with open(filepath, "rb") as f:
audio_bytes = f.read()
audio = ffmpeg_read(audio_bytes, SAMPLE_RATE)
# Build generate kwargs
generate_kwargs = {"task": task}
if language != "auto":
generate_kwargs["language"] = language
# Chunk and process
chunks = _chunk_audio(audio)
try:
if len(chunks) == 1:
# Single chunk - no streaming benefit
result = pipe(
{"array": audio, "sampling_rate": SAMPLE_RATE},
batch_size=BATCH_SIZE,
generate_kwargs=generate_kwargs,
return_timestamps=return_timestamps,
)
if return_timestamps and "chunks" in result and result["chunks"]:
lines = []
for chunk in result["chunks"]:
start = chunk.get("timestamp", (0, 0))[0] or 0
end = chunk.get("timestamp", (0, 0))[1] or 0
text = chunk.get("text", "").strip()
lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
yield html_embed, "\n".join(lines)
else:
yield html_embed, result.get("text", "")
return
# Multi-chunk streaming
accumulated_text = ""
accumulated_lines = []
for i, chunk_data in enumerate(chunks):
chunk_start_time = chunk_data["start_time"]
result = pipe(
{"array": chunk_data["array"], "sampling_rate": chunk_data["sampling_rate"]},
batch_size=BATCH_SIZE,
generate_kwargs=generate_kwargs,
return_timestamps=return_timestamps,
)
if return_timestamps and "chunks" in result and result["chunks"]:
for ts_chunk in result["chunks"]:
                        ts = ts_chunk.get("timestamp", (0, 0))
                        start = (ts[0] or 0) + chunk_start_time
                        end = (ts[1] or ts[0] or 0) + chunk_start_time
text = ts_chunk.get("text", "").strip()
accumulated_lines.append(f"[{start:.2f}s - {end:.2f}s] {text}")
yield html_embed, "\n".join(accumulated_lines)
else:
chunk_text = result.get("text", "").strip()
if chunk_text:
if accumulated_text:
accumulated_text += " " + chunk_text
else:
accumulated_text = chunk_text
yield html_embed, accumulated_text
except gr.Error:
raise
except Exception as e:
raise gr.Error(f"YouTube transcription failed: {str(e)[:200]}")
# Build the Gradio interface
with gr.Blocks(title="Whisper-ASR", theme="Nymbo/Nymbo_Theme") as demo:
# Header
gr.HTML(
f"""
<h1 style='text-align: center;'>Whisper-ASR</h1>
<p style='text-align: center;'>
Powered by <code>openai/whisper-large-v3-turbo</code> on
<strong>{get_device_info().upper()}</strong>
</p>
"""
)
with gr.Tabs():
# Tab 1: Audio File / Microphone
with gr.TabItem("Audio File"):
with gr.Row():
with gr.Column():
audio_input = gr.Audio(
label="Audio Input",
sources=["microphone", "upload"],
type="filepath",
)
with gr.Row():
task_radio = gr.Radio(
choices=["transcribe", "translate"],
value="transcribe",
label="Task",
info="Translate converts any language to English",
)
language_dropdown = gr.Dropdown(
choices=LANGUAGES,
value="auto",
label="Language",
info="Source language (auto-detect recommended)",
)
timestamps_checkbox = gr.Checkbox(
label="Return Timestamps",
value=False,
)
transcribe_btn = gr.Button("Transcribe", variant="primary")
with gr.Column():
audio_output = gr.Textbox(
label="Transcription",
placeholder="Transcribed text will appear here...",
lines=12,
)
transcribe_btn.click(
fn=transcribe_audio_streaming,
inputs=[audio_input, task_radio, language_dropdown, timestamps_checkbox],
outputs=audio_output,
api_name="transcribe",
)
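            # Because api_name="transcribe" is set, the endpoint can also be called
            # programmatically, e.g. (illustrative; assumes gradio_client is installed
            # and the app is running locally):
            #   from gradio_client import Client, handle_file
            #   client = Client("http://127.0.0.1:7860/")
            #   text = client.predict(handle_file("sample.wav"), "transcribe", "auto",
            #                         False, api_name="/transcribe")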
# Tab 2: YouTube (only shown when running locally)
if not IS_HF_SPACE:
with gr.TabItem("YouTube"):
with gr.Row():
with gr.Column():
yt_url_input = gr.Textbox(
label="YouTube URL",
placeholder="Paste a YouTube video URL here...",
lines=1,
)
with gr.Row():
yt_task_radio = gr.Radio(
choices=["transcribe", "translate"],
value="transcribe",
label="Task",
info="Translate converts any language to English",
)
yt_language_dropdown = gr.Dropdown(
choices=LANGUAGES,
value="auto",
label="Language",
)
yt_timestamps_checkbox = gr.Checkbox(
label="Return Timestamps",
value=False,
)
yt_transcribe_btn = gr.Button("Transcribe YouTube", variant="primary")
with gr.Column():
yt_embed = gr.HTML(label="Video")
yt_output = gr.Textbox(
label="Transcription",
placeholder="Transcribed text will appear here...",
lines=10,
)
yt_transcribe_btn.click(
fn=transcribe_youtube_streaming,
inputs=[yt_url_input, yt_task_radio, yt_language_dropdown, yt_timestamps_checkbox],
outputs=[yt_embed, yt_output],
api_name="transcribe_youtube",
)
if __name__ == "__main__":
    # The theme is set on gr.Blocks above; launch() does not accept a theme kwarg.
    demo.queue().launch()