staraks commited on
Commit
ae60bd6
·
verified ·
1 Parent(s): 66a1d7f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +411 -0
app.py CHANGED
@@ -1,3 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  HTML_UI = """
2
  <!DOCTYPE html>
3
  <html lang="en">
@@ -591,3 +988,17 @@ HTML_UI = """
591
  </body>
592
  </html>
593
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import tempfile
4
+ from typing import List, Literal, Optional
5
+
6
+ import torch
7
+ import pyzipper
8
+ import soundfile as sf # noqa: F401 (ensure audio backend is available)
9
+
10
+ from docx import Document
11
+ from fastapi import FastAPI, File, UploadFile, Form, HTTPException
12
+ from fastapi.responses import (
13
+ FileResponse,
14
+ JSONResponse,
15
+ PlainTextResponse,
16
+ HTMLResponse,
17
+ )
18
+ from pydantic import BaseModel
19
+ from transformers import pipeline
20
+ import spaces
21
+
22
+ # ===================== CONFIG =====================
23
+
24
+ MODEL_NAME = "openai/whisper-large-v3"
25
+
26
+ AUDIO_EXTENSIONS = (
27
+ ".wav",
28
+ ".mp3",
29
+ ".m4a",
30
+ ".flac",
31
+ ".ogg",
32
+ ".opus",
33
+ ".webm",
34
+ )
35
+
36
+ # Use GPU if available on the Space
37
+ device = 0 if torch.cuda.is_available() else "cpu"
38
+
39
+ # Lazy-loaded pipeline (created on first request)
40
+ asr_pipe = None
41
+
42
+
43
+ def get_pipeline():
44
+ global asr_pipe
45
+ if asr_pipe is None:
46
+ asr_pipe = pipeline(
47
+ task="automatic-speech-recognition",
48
+ model=MODEL_NAME,
49
+ chunk_length_s=30,
50
+ device=device,
51
+ )
52
+ return asr_pipe
53
+
54
+
55
+ # ===================== Pydantic models =====================
56
+
57
+ class FileTranscript(BaseModel):
58
+ filename: str
59
+ text: str
60
+
61
+
62
+ class TranscriptionResponse(BaseModel):
63
+ mode: Literal["general", "medical_en"]
64
+ combined_transcript: str
65
+ items: List[FileTranscript]
66
+
67
+
68
+ # ===================== Helper functions =====================
69
+
70
+ def build_generate_kwargs(mode: str):
71
+ """
72
+ mode: 'general' | 'medical_en'
73
+ Always transcribe with auto language detection,
74
+ but in medical_en we bias towards English medical dictation.
75
+ """
76
+ generate_kwargs = {
77
+ "task": "transcribe", # keep same language as audio
78
+ }
79
+
80
+ if mode == "medical_en":
81
+ # Strong bias towards English medical terminology
82
+ generate_kwargs["language"] = "en"
83
+ generate_kwargs["initial_prompt"] = (
84
+ "This is a medical dictation. Use accurate English medical terminology, "
85
+ "including anatomy, diseases, investigations, lab values, imaging, and drugs. "
86
+ "Keep the style clinical and professional."
87
+ )
88
+
89
+ return generate_kwargs
90
+
91
+
92
+ def filter_audio_files(paths: List[str]) -> List[str]:
93
+ out: List[str] = []
94
+ for p in paths:
95
+ _, ext = os.path.splitext(p)
96
+ if ext.lower() in AUDIO_EXTENSIONS:
97
+ out.append(p)
98
+ return out
99
+
100
+
101
+ def transcribe_file(path: str, mode: str) -> str:
102
+ pipe = get_pipeline()
103
+ generate_kwargs = build_generate_kwargs(mode)
104
+
105
+ result = pipe(
106
+ path,
107
+ batch_size=8,
108
+ generate_kwargs=generate_kwargs,
109
+ return_timestamps=False,
110
+ )
111
+
112
+ if isinstance(result, dict):
113
+ return (result.get("text") or "").strip()
114
+ if isinstance(result, list) and result:
115
+ return (result[0].get("text") or "").strip()
116
+ return ""
117
+
118
+
119
+ def format_combined(results: List[FileTranscript]) -> str:
120
+ parts: List[str] = []
121
+ for idx, item in enumerate(results, start=1):
122
+ parts.append(f"### File {idx}: {item.filename}")
123
+ parts.append("")
124
+ parts.append(item.text if item.text else "[No transcript]")
125
+ parts.append("")
126
+ return "\n".join(parts).strip()
127
+
128
+
129
+ def build_docx(results: List[FileTranscript], title: str) -> str:
130
+ doc = Document()
131
+ doc.add_heading(title, level=1)
132
+
133
+ for idx, item in enumerate(results, start=1):
134
+ doc.add_heading(f"File {idx}: {item.filename}", level=2)
135
+ doc.add_paragraph(item.text if item.text else "[No transcript]")
136
+ doc.add_paragraph()
137
+
138
+ tmpdir = tempfile.mkdtemp(prefix="docx_")
139
+ out_path = os.path.join(tmpdir, "transcripts.docx")
140
+ doc.save(out_path)
141
+ return out_path
142
+
143
+
144
+ def save_uploads_to_temp(files: List[UploadFile]) -> List[str]:
145
+ tmpdir = tempfile.mkdtemp(prefix="uploads_")
146
+ local_paths: List[str] = []
147
+ for uf in files:
148
+ filename = os.path.basename(uf.filename or "audio")
149
+ local_path = os.path.join(tmpdir, filename)
150
+ with open(local_path, "wb") as out_f:
151
+ shutil.copyfileobj(uf.file, out_f)
152
+ local_paths.append(local_path)
153
+ return local_paths
154
+
155
+
156
+ def extract_zip_to_temp(zip_file: UploadFile, password: Optional[str]) -> List[str]:
157
+ tmpdir = tempfile.mkdtemp(prefix="zip_")
158
+ zip_path = os.path.join(tmpdir, os.path.basename(zip_file.filename or "archive.zip"))
159
+
160
+ # Save uploaded ZIP
161
+ with open(zip_path, "wb") as out_f:
162
+ shutil.copyfileobj(zip_file.file, out_f)
163
+
164
+ outdir = tempfile.mkdtemp(prefix="zip_files_")
165
+
166
+ try:
167
+ with pyzipper.AESZipFile(zip_path, "r") as zf:
168
+ if password:
169
+ zf.setpassword(password.encode("utf-8"))
170
+
171
+ for info in zf.infolist():
172
+ if info.is_dir():
173
+ continue
174
+ name = os.path.basename(info.filename)
175
+ if not name:
176
+ continue
177
+ out_path = os.path.join(outdir, name)
178
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
179
+ with zf.open(info) as src, open(out_path, "wb") as dst:
180
+ shutil.copyfileobj(src, dst)
181
+
182
+ except (pyzipper.BadZipFile, RuntimeError, KeyError) as e:
183
+ shutil.rmtree(outdir, ignore_errors=True)
184
+ raise HTTPException(
185
+ status_code=400,
186
+ detail=f"Failed to open ZIP file. Check password / integrity. {e}",
187
+ )
188
+
189
+ files = [os.path.join(outdir, f) for f in os.listdir(outdir)]
190
+ return files
191
+
192
+
193
+ # ===================== FastAPI app =====================
194
+
195
+ app = FastAPI(
196
+ title="Whisper Large V3 – Medical Batch Transcription API",
197
+ description="""
198
+ HTTP API for Whisper Large V3 with:
199
+
200
+ - Multi-file audio upload
201
+ - Password-protected ZIP upload
202
+ - Medical-biased transcription mode
203
+ - Combined transcript
204
+ - Optional merged Word (.docx) download
205
+
206
+ Use `/docs` for Swagger UI and `/ui` for the web interface.
207
+ """,
208
+ version="1.0.0",
209
+ )
210
+
211
+
212
+ @app.get("/", response_class=PlainTextResponse)
213
+ def root():
214
+ return (
215
+ "Whisper Large V3 – Medical Batch Transcription API\n"
216
+ "Open /docs for API documentation or /ui for the web interface.\n"
217
+ )
218
+
219
+
220
+ @app.get("/health", response_class=PlainTextResponse)
221
+ def health():
222
+ return "OK"
223
+
224
+
225
+ @app.get("/self-test")
226
+ def self_test():
227
+ """
228
+ Basic self-check:
229
+ - can we create/load the pipeline?
230
+ - what device are we using?
231
+ """
232
+ try:
233
+ pipe = get_pipeline()
234
+ model_name = getattr(pipe.model, "name_or_path", MODEL_NAME)
235
+ dev = "cuda" if device == 0 else str(device)
236
+ return JSONResponse(
237
+ {
238
+ "status": "ok",
239
+ "message": "Pipeline loaded successfully.",
240
+ "model": model_name,
241
+ "device": dev,
242
+ }
243
+ )
244
+ except Exception as e:
245
+ return JSONResponse(
246
+ {
247
+ "status": "error",
248
+ "message": f"Pipeline failed to load: {e}",
249
+ },
250
+ status_code=500,
251
+ )
252
+
253
+
254
+ # ---------- 1. Multi-file transcription (JSON) ----------
255
+
256
+ @app.post("/api/transcribe/files", response_model=TranscriptionResponse)
257
+ @spaces.GPU
258
+ def transcribe_files(
259
+ files: List[UploadFile] = File(..., description="One or more audio files"),
260
+ mode: Literal["general", "medical_en"] = Form("medical_en"),
261
+ ):
262
+ if not files:
263
+ raise HTTPException(status_code=400, detail="No files uploaded.")
264
+
265
+ local_paths = save_uploads_to_temp(files)
266
+ audio_paths = filter_audio_files(local_paths)
267
+
268
+ if not audio_paths:
269
+ raise HTTPException(
270
+ status_code=400,
271
+ detail=f"No valid audio files found. Supported extensions: {', '.join(AUDIO_EXTENSIONS)}",
272
+ )
273
+
274
+ items: List[FileTranscript] = []
275
+ for path in audio_paths:
276
+ fname = os.path.basename(path)
277
+ text = transcribe_file(path, mode)
278
+ items.append(FileTranscript(filename=fname, text=text))
279
+
280
+ combined = format_combined(items)
281
+
282
+ return TranscriptionResponse(
283
+ mode=mode,
284
+ combined_transcript=combined,
285
+ items=items,
286
+ )
287
+
288
+
289
+ # ---------- 2. Multi-file transcription (DOCX download) ----------
290
+
291
+ @app.post("/api/transcribe/files/docx")
292
+ @spaces.GPU
293
+ def transcribe_files_docx(
294
+ files: List[UploadFile] = File(..., description="One or more audio files"),
295
+ mode: Literal["general", "medical_en"] = Form("medical_en"),
296
+ ):
297
+ if not files:
298
+ raise HTTPException(status_code=400, detail="No files uploaded.")
299
+
300
+ local_paths = save_uploads_to_temp(files)
301
+ audio_paths = filter_audio_files(local_paths)
302
+
303
+ if not audio_paths:
304
+ raise HTTPException(
305
+ status_code=400,
306
+ detail=f"No valid audio files found. Supported extensions: {', '.join(AUDIO_EXTENSIONS)}",
307
+ )
308
+
309
+ items: List[FileTranscript] = []
310
+ for path in audio_paths:
311
+ fname = os.path.basename(path)
312
+ text = transcribe_file(path, mode)
313
+ items.append(FileTranscript(filename=fname, text=text))
314
+
315
+ docx_path = build_docx(items, "Multi-file transcription")
316
+
317
+ return FileResponse(
318
+ docx_path,
319
+ media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
320
+ filename="transcripts_files.docx",
321
+ )
322
+
323
+
324
+ # ---------- 3. ZIP transcription (JSON) ----------
325
+
326
+ @app.post("/api/transcribe/zip", response_model=TranscriptionResponse)
327
+ @spaces.GPU
328
+ def transcribe_zip(
329
+ file: UploadFile = File(..., description="ZIP file containing audio files"),
330
+ password: str = Form("", description="ZIP password (leave blank if none)"),
331
+ mode: Literal["general", "medical_en"] = Form("medical_en"),
332
+ ):
333
+ if file is None:
334
+ raise HTTPException(status_code=400, detail="No ZIP uploaded.")
335
+
336
+ extracted_paths = extract_zip_to_temp(file, password or None)
337
+ audio_paths = filter_audio_files(extracted_paths)
338
+
339
+ if not audio_paths:
340
+ raise HTTPException(
341
+ status_code=400,
342
+ detail=f"No valid audio files found inside ZIP. Supported extensions: {', '.join(AUDIO_EXTENSIONS)}",
343
+ )
344
+
345
+ items: List[FileTranscript] = []
346
+ for path in audio_paths:
347
+ fname = os.path.basename(path)
348
+ text = transcribe_file(path, mode)
349
+ items.append(FileTranscript(filename=fname, text=text))
350
+
351
+ combined = format_combined(items)
352
+
353
+ return TranscriptionResponse(
354
+ mode=mode,
355
+ combined_transcript=combined,
356
+ items=items,
357
+ )
358
+
359
+
360
+ # ---------- 4. ZIP transcription (DOCX download) ----------
361
+
362
+ @app.post("/api/transcribe/zip/docx")
363
+ @spaces.GPU
364
+ def transcribe_zip_docx(
365
+ file: UploadFile = File(..., description="ZIP file containing audio files"),
366
+ password: str = Form("", description="ZIP password (leave blank if none)"),
367
+ mode: Literal["general", "medical_en"] = Form("medical_en"),
368
+ ):
369
+ if file is None:
370
+ raise HTTPException(status_code=400, detail="No ZIP uploaded.")
371
+
372
+ extracted_paths = extract_zip_to_temp(file, password or None)
373
+ audio_paths = filter_audio_files(extracted_paths)
374
+
375
+ if not audio_paths:
376
+ raise HTTPException(
377
+ status_code=400,
378
+ detail=f"No valid audio files found inside ZIP. Supported extensions: {', '.join(AUDIO_EXTENSIONS)}",
379
+ )
380
+
381
+ items: List[FileTranscript] = []
382
+ for path in audio_paths:
383
+ fname = os.path.basename(path)
384
+ text = transcribe_file(path, mode)
385
+ items.append(FileTranscript(filename=fname, text=text))
386
+
387
+ docx_path = build_docx(items, "ZIP transcription")
388
+
389
+ return FileResponse(
390
+ docx_path,
391
+ media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
392
+ filename="transcripts_zip.docx",
393
+ )
394
+
395
+
396
+ # ===================== Simple HTML UI =====================
397
+
398
  HTML_UI = """
399
  <!DOCTYPE html>
400
  <html lang="en">
 
988
  </body>
989
  </html>
990
  """
991
+
992
+
993
+ @app.get("/ui", response_class=HTMLResponse)
994
+ def get_ui():
995
+ return HTML_UI
996
+
997
+
998
+ # ===================== Run (local dev / HF Spaces) =====================
999
+
1000
+ if __name__ == "__main__":
1001
+ import uvicorn
1002
+
1003
+ port = int(os.getenv("PORT", "7860"))
1004
+ uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)