Commit f0502f7
Parent(s): 6facecb
Updated

Files changed:
- main.py (+33 -35)
- requirements.txt (+2 -1)
main.py CHANGED

@@ -4,10 +4,9 @@ import tempfile
 from fastapi import FastAPI, UploadFile, File, Header, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
-from spitch import Spitch
+from spitch import Spitch
 from langchain.prompts import PromptTemplate
-from
-from langchain_community.llms import HuggingFaceHub
+from langchain_huggingface import HuggingFaceEndpoint  # updated import
 from langdetect import detect, DetectorFactory
 
 DetectorFactory.seed = 0
@@ -15,23 +14,27 @@ DetectorFactory.seed = 0
 # --------- BASIC CONFIG ----------
 SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
 HF_MODEL = os.getenv("HF_MODEL", "google/flan-t5-base")
-FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
-PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "")
+FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
+PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "")
 
 if not SPITCH_API_KEY:
     raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
 
-# Init Spitch
+# Init Spitch
 os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
 spitch_client = Spitch()
 
-#
-llm =
+# Use new HuggingFaceEndpoint instead of deprecated HuggingFaceHub
+llm = HuggingFaceEndpoint(
+    repo_id=HF_MODEL,
+    temperature=0.2,
+    max_length=512
+)
 
 # FastAPI app
 app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
 
-# CORS
+# CORS
 app.add_middleware(
     CORSMiddleware,
     allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
@@ -60,9 +63,14 @@ autodoc_template = """You are DevAssist DocBot.
 Code: {code}
 Documentation:"""
 
-
-
-
+# Use RunnableSequence instead of LLMChain
+chat_prompt = PromptTemplate(input_variables=["question"], template=chat_template)
+stt_prompt = PromptTemplate(input_variables=["speech"], template=stt_chat_template)
+autodoc_prompt = PromptTemplate(input_variables=["code"], template=autodoc_template)
+
+chat_chain = chat_prompt | llm
+stt_chain = stt_prompt | llm
+autodoc_chain = autodoc_prompt | llm
 
 # --------- REQUEST MODELS ----------
 class ChatRequest(BaseModel):
@@ -74,7 +82,7 @@ class AutoDocRequest(BaseModel):
 # --------- AUTH ----------
 def check_auth(authorization: str | None):
     if not PROJECT_API_KEY:
-        return
+        return
     if not authorization or not authorization.startswith("Bearer "):
         raise HTTPException(status_code=401, detail="Missing bearer token")
     token = authorization.split(" ", 1)[1]
@@ -89,68 +97,58 @@ def root():
 @app.post("/chat")
 def chat(req: ChatRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    answer = chat_chain.
-    return {"reply": answer.strip()}
+    answer = chat_chain.invoke({"question": req.question})
+    return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
 
-# Speech endpoint: full pipeline speech -> transcription -> translation (if needed) -> LLM
 @app.post("/stt")
 async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
-    """
-    POST /stt with form-data file=@audio.mp3
-    Optional query/form field lang_hint: two-letter code (e.g. 'yo' for Yoruba) if frontend knows spoken language
-    Returns: transcription, detected_language, translation (to en), reply
-    """
     check_auth(authorization)
 
-    # save uploaded file to temp file
     suffix = os.path.splitext(file.filename)[1] or ".wav"
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
         content = await file.read()
         tf.write(content)
         tmp_path = tf.name
 
-    # 1) Transcribe using Spitch SDK (docs show client.speech.transcribe)
-    # If lang_hint provided, pass it; else attempt without language param and fallback
     try:
         if lang_hint:
             resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
         else:
-            # attempt transcribe without explicit language (SDK may auto-detect)
             resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
-    except Exception
-        # fallback: try English transcription as last resort
+    except Exception:
         resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
 
     transcription = getattr(resp, "text", "") or resp.get("text", "") if isinstance(resp, dict) else ""
 
-    # 2) Detect language of transcription if not provided
    try:
         detected_lang = detect(transcription) if transcription.strip() else "en"
     except Exception:
         detected_lang = "en"
 
-    # 3) If detected_lang != 'en', translate to English so LLM reasons in English
     translation = transcription
     if detected_lang != "en":
         try:
             translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
             translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation
         except Exception:
-            # if translation fails, fallback to transcription
             translation = transcription
 
-
-    reply = stt_chain.run(speech=translation)
+    reply = stt_chain.invoke({"speech": translation})
 
     return {
         "transcription": transcription,
         "detected_language": detected_lang,
         "translation": translation,
-        "reply": reply.strip()
+        "reply": reply.strip() if isinstance(reply, str) else str(reply)
     }
 
 @app.post("/autodoc")
 def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    docs = autodoc_chain.
-    return {"documentation": docs.strip()}
+    docs = autodoc_chain.invoke({"code": req.code})
+    return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
+
+# Hugging Face requires port 7860, not 8000
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
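The substance of this commit is moving off the deprecated HuggingFaceHub LLM and the old chain wiring onto langchain_huggingface.HuggingFaceEndpoint composed with prompts through LangChain's | operator. The sketch below shows that composition in isolation, mirroring the constructor arguments in the diff; the inline template text, the HUGGINGFACEHUB_API_TOKEN requirement, and the example question are illustrative assumptions, not taken from the file.

import os
from langchain.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

# Same constructor arguments as the diff above; a Hugging Face token is
# assumed to be available in the environment (HUGGINGFACEHUB_API_TOKEN).
llm = HuggingFaceEndpoint(
    repo_id=os.getenv("HF_MODEL", "google/flan-t5-base"),
    temperature=0.2,
    max_length=512,
)

# Stand-in template; main.py's real chat_template is not shown in this diff.
chat_prompt = PromptTemplate(
    input_variables=["question"],
    template="You are DevAssist. Answer briefly.\nQuestion: {question}\nAnswer:",
)

# prompt | llm builds a RunnableSequence, the replacement for LLMChain.
chat_chain = chat_prompt | llm

# .invoke() takes a dict keyed by the prompt's input variables and, for an
# LLM like HuggingFaceEndpoint, returns a plain string.
print(chat_chain.invoke({"question": "What does this backend do?"}))

Because the endpoint returns a plain string here, the isinstance(answer, str) guards in the new handlers are mostly defensive.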
requirements.txt CHANGED

@@ -7,4 +7,5 @@ langchain-community
 langdetect
 httpx
 huggingface_hub
-python-multipart
+python-multipart
+langchain-huggingface>=0.0.8
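With the new __main__ block binding to port 7860 and httpx plus python-multipart already listed here, the updated endpoints can be smoke-tested from a small client script. A rough sketch, assuming a local run of main.py, a placeholder bearer token, and a sample.mp3 on disk (all illustrative, not part of the commit):

import httpx

BASE_URL = "http://localhost:7860"                 # assumed local run on the new port
HEADERS = {"Authorization": "Bearer dev-token"}    # only enforced when PROJECT_API_KEY is set

# /chat: JSON body matching ChatRequest; the handler reads req.question.
chat = httpx.post(f"{BASE_URL}/chat",
                  json={"question": "Summarize what DevAssist does."},
                  headers=HEADERS, timeout=60)
print(chat.json()["reply"])

# /stt: multipart upload; the handler transcribes, translates, then calls the LLM.
with open("sample.mp3", "rb") as f:
    stt = httpx.post(f"{BASE_URL}/stt",
                     files={"file": ("sample.mp3", f, "audio/mpeg")},
                     headers=HEADERS, timeout=120)
print(stt.json()["reply"])

# /autodoc: JSON body matching AutoDocRequest; the handler reads req.code.
doc = httpx.post(f"{BASE_URL}/autodoc",
                 json={"code": "def add(a, b):\n    return a + b"},
                 headers=HEADERS, timeout=60)
print(doc.json()["documentation"])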