Quartz4065 commited on
Commit
54358d8
·
verified ·
1 Parent(s): bc616a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -59
app.py CHANGED
@@ -1,79 +1,98 @@
 
1
  import os
2
- import tempfile
3
- from typing import List, Optional
 
4
 
5
- from fastapi import FastAPI, File, Form, UploadFile
6
- from fastapi.middleware.cors import CORSMiddleware
7
  from pydantic import BaseModel
 
 
 
 
 
8
  from faster_whisper import WhisperModel
 
9
 
10
- APP_PORT = int(os.environ.get("PORT", "7860"))
 
11
 
12
- _models = {}
13
- def get_model(name: str):
14
- if name not in _models:
15
- _models[name] = WhisperModel(
16
- name, compute_type="int8", cpu_threads=os.cpu_count() or 2
 
 
 
 
 
 
 
 
 
 
 
 
17
  )
18
- return _models[name]
 
 
19
 
20
- class Segment(BaseModel):
21
- start: float
22
- end: float
23
- text: str
 
 
24
 
25
  class TranscribeOut(BaseModel):
26
  text: str
27
- segments: List[Segment]
28
  duration_sec: Optional[float] = None
29
- words: Optional[int] = None
30
  wpm: Optional[float] = None
31
- model: str
32
 
33
- app = FastAPI(title="Nuvia Free Transcriber")
34
- app.add_middleware(
35
- CORSMiddleware,
36
- allow_origins=["*"], allow_credentials=True,
37
- allow_methods=["*"], allow_headers=["*"],
38
- )
39
 
40
- @app.get("/health")
41
  def health():
42
- return {"ok": True}
43
 
44
  @app.post("/transcribe", response_model=TranscribeOut)
45
- def transcribe(file: UploadFile = File(...), model: str = Form("base.en")):
46
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
47
- tmp.write(file.file.read())
48
- tmp_path = tmp.name
49
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
- m = get_model(model)
52
- segments, info = m.transcribe(tmp_path, vad_filter=True)
53
-
54
- segs = []
55
- total_words = 0
56
- for s in segments:
57
- txt = s.text.strip()
58
- segs.append(Segment(start=float(s.start), end=float(s.end), text=txt))
59
- total_words += len(txt.split())
60
-
61
- dur = float(info.duration) if getattr(info, "duration", None) else None
62
- wpm = None
63
- if dur and dur > 0:
64
- wpm = round(total_words / (dur / 60.0), 2)
65
-
66
- full_text = " ".join([s.text for s in segs]).strip()
67
- return TranscribeOut(
68
- text=full_text,
69
- segments=segs,
70
- duration_sec=dur,
71
- words=total_words,
72
- wpm=wpm,
73
- model=model
74
- )
75
- finally:
76
- try:
77
- os.remove(tmp_path)
78
- except Exception:
79
- pass
 
1
+ import io
2
  import os
3
+ import math
4
+ import subprocess
5
+ from typing import Optional
6
 
7
+ from fastapi import FastAPI, File, UploadFile
8
+ from fastapi.responses import JSONResponse
9
  from pydantic import BaseModel
10
+
11
+ # Optional CORS (safe default in Spaces)
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+
14
+ # Transcription (CPU)
15
  from faster_whisper import WhisperModel
16
+ import soundfile as sf
17
 
18
+ # ---------- App ----------
19
+ app = FastAPI(title="Nuvia Free Transcriber", version="1.0.0")
20
 
21
+ app.add_middleware(
22
+ CORSMiddleware,
23
+ allow_origins=["*"], allow_credentials=True,
24
+ allow_methods=["*"], allow_headers=["*"],
25
+ )
26
+
27
+ # ---------- Model load (CPU, small for free tier) ----------
28
+ # You can switch to "base.en" if needed; "tiny.en" is faster.
29
+ MODEL_NAME = os.environ.get("WHISPER_MODEL", "tiny.en")
30
+ model = WhisperModel(MODEL_NAME, device="cpu", compute_type="int8")
31
+
32
+ # ---------- Helpers ----------
33
+ def ffprobe_duration(path: str) -> Optional[float]:
34
+ try:
35
+ out = subprocess.check_output(
36
+ ["ffprobe", "-v", "error", "-show_entries", "format=duration",
37
+ "-of", "default=noprint_wrappers=1:nokey=1", path]
38
  )
39
+ return float(out.decode("utf-8").strip())
40
+ except Exception:
41
+ return None
42
 
43
+ def word_count(text: str) -> int:
44
+ return len([w for w in text.split() if w.strip()])
45
+
46
+ # ---------- Schemas ----------
47
+ class HealthOut(BaseModel):
48
+ ok: bool
49
 
50
  class TranscribeOut(BaseModel):
51
  text: str
 
52
  duration_sec: Optional[float] = None
 
53
  wpm: Optional[float] = None
 
54
 
55
+ # ---------- Routes ----------
56
+ @app.get("/", response_model=HealthOut)
57
+ def root():
58
+ """Root route so probes and GPT 'test connection' don’t 404."""
59
+ return HealthOut(ok=True)
 
60
 
61
+ @app.get("/health", response_model=HealthOut)
62
  def health():
63
+ return HealthOut(ok=True)
64
 
65
  @app.post("/transcribe", response_model=TranscribeOut)
66
+ async def transcribe(file: UploadFile = File(...)):
67
+ # Read uploaded bytes
68
+ raw = await file.read()
 
69
 
70
+ # Save to temp wav (Spaces use ephemeral FS; this is fine)
71
+ tmp_in = "/tmp/infile"
72
+ # Keep original extension if present
73
+ ext = os.path.splitext(file.filename or "")[1].lower() or ".bin"
74
+ tmp_in = tmp_in + ext
75
+ with open(tmp_in, "wb") as f:
76
+ f.write(raw)
77
+
78
+ # Ensure we have a WAV for robust decode
79
+ tmp_wav = "/tmp/in.wav"
80
  try:
81
+ subprocess.check_call(["ffmpeg", "-y", "-i", tmp_in, "-ar", "16000", "-ac", "1", tmp_wav], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
82
+ except subprocess.CalledProcessError:
83
+ return JSONResponse(status_code=400, content={"error": "ffmpeg failed to decode the audio"})
84
+
85
+ # Duration via ffprobe (more accurate than guessing)
86
+ duration = ffprobe_duration(tmp_wav)
87
+
88
+ # Transcribe
89
+ segments, info = model.transcribe(tmp_wav, language="en")
90
+ text = "".join([seg.text for seg in segments]).strip()
91
+
92
+ # WPM (best-effort)
93
+ wpm = None
94
+ if duration and duration > 0:
95
+ wc = word_count(text)
96
+ wpm = round((wc / (duration / 60.0)), 1)
97
+
98
+ return TranscribeOut(text=text, duration_sec=duration, wpm=wpm)