alaselababatunde committed on
Commit
f0502f7
·
1 Parent(s): 6facecb
Files changed (2) hide show
  1. main.py +33 -35
  2. requirements.txt +2 -1
main.py CHANGED
@@ -4,10 +4,9 @@ import tempfile
4
  from fastapi import FastAPI, UploadFile, File, Header, HTTPException
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
- from spitch import Spitch # Spitch Python SDK (docs use this pattern)
8
  from langchain.prompts import PromptTemplate
9
- from langchain.chains import LLMChain
10
- from langchain_community.llms import HuggingFaceHub
11
  from langdetect import detect, DetectorFactory
12
 
13
  DetectorFactory.seed = 0
@@ -15,23 +14,27 @@ DetectorFactory.seed = 0
15
  # --------- BASIC CONFIG ----------
16
  SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
17
  HF_MODEL = os.getenv("HF_MODEL", "google/flan-t5-base")
18
- FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*") # set to Vercel domain in production
19
- PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "") # simple bearer key for frontend -> backend auth
20
 
21
  if not SPITCH_API_KEY:
22
  raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
23
 
24
- # Init Spitch (SDK reads env var; docs show this pattern)
25
  os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
26
  spitch_client = Spitch()
27
 
28
- # Init LLM
29
- llm = HuggingFaceHub(repo_id=HF_MODEL, model_kwargs={"temperature": 0.2, "max_length": 512})
 
 
 
 
30
 
31
  # FastAPI app
32
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
33
 
34
- # CORS (allow only your Vercel domain in production)
35
  app.add_middleware(
36
  CORSMiddleware,
37
  allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
@@ -60,9 +63,14 @@ autodoc_template = """You are DevAssist DocBot.
60
  Code: {code}
61
  Documentation:"""
62
 
63
- chat_chain = LLMChain(llm=llm, prompt=PromptTemplate(input_variables=["question"], template=chat_template))
64
- stt_chain = LLMChain(llm=llm, prompt=PromptTemplate(input_variables=["speech"], template=stt_chat_template))
65
- autodoc_chain = LLMChain(llm=llm, prompt=PromptTemplate(input_variables=["code"], template=autodoc_template))
 
 
 
 
 
66
 
67
  # --------- REQUEST MODELS ----------
68
  class ChatRequest(BaseModel):
@@ -74,7 +82,7 @@ class AutoDocRequest(BaseModel):
74
  # --------- AUTH ----------
75
  def check_auth(authorization: str | None):
76
  if not PROJECT_API_KEY:
77
- return # you can disable for local dev (but set in production)
78
  if not authorization or not authorization.startswith("Bearer "):
79
  raise HTTPException(status_code=401, detail="Missing bearer token")
80
  token = authorization.split(" ", 1)[1]
@@ -89,68 +97,58 @@ def root():
89
  @app.post("/chat")
90
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
91
  check_auth(authorization)
92
- answer = chat_chain.run(question=req.question)
93
- return {"reply": answer.strip()}
94
 
95
- # Speech endpoint: full pipeline speech -> transcription -> translation (if needed) -> LLM
96
  @app.post("/stt")
97
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
98
- """
99
- POST /stt with form-data file=@audio.mp3
100
- Optional query/form field lang_hint: two-letter code (e.g. 'yo' for Yoruba) if frontend knows spoken language
101
- Returns: transcription, detected_language, translation (to en), reply
102
- """
103
  check_auth(authorization)
104
 
105
- # save uploaded file to temp file
106
  suffix = os.path.splitext(file.filename)[1] or ".wav"
107
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
108
  content = await file.read()
109
  tf.write(content)
110
  tmp_path = tf.name
111
 
112
- # 1) Transcribe using Spitch SDK (docs show client.speech.transcribe)
113
- # If lang_hint provided, pass it; else attempt without language param and fallback
114
  try:
115
  if lang_hint:
116
  resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
117
  else:
118
- # attempt transcribe without explicit language (SDK may auto-detect)
119
  resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
120
- except Exception as e:
121
- # fallback: try English transcription as last resort
122
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
123
 
124
  transcription = getattr(resp, "text", "") or resp.get("text", "") if isinstance(resp, dict) else ""
125
 
126
- # 2) Detect language of transcription if not provided
127
  try:
128
  detected_lang = detect(transcription) if transcription.strip() else "en"
129
  except Exception:
130
  detected_lang = "en"
131
 
132
- # 3) If detected_lang != 'en', translate to English so LLM reasons in English
133
  translation = transcription
134
  if detected_lang != "en":
135
  try:
136
  translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
137
  translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation
138
  except Exception:
139
- # if translation fails, fallback to transcription
140
  translation = transcription
141
 
142
- # 4) Pass translated text to LLM (LLM assumes English)
143
- reply = stt_chain.run(speech=translation)
144
 
145
  return {
146
  "transcription": transcription,
147
  "detected_language": detected_lang,
148
  "translation": translation,
149
- "reply": reply.strip()
150
  }
151
 
152
  @app.post("/autodoc")
153
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
154
  check_auth(authorization)
155
- docs = autodoc_chain.run(code=req.code)
156
- return {"documentation": docs.strip()}
 
 
 
 
 
 
4
  from fastapi import FastAPI, UploadFile, File, Header, HTTPException
5
  from fastapi.middleware.cors import CORSMiddleware
6
  from pydantic import BaseModel
7
+ from spitch import Spitch
8
  from langchain.prompts import PromptTemplate
9
+ from langchain_huggingface import HuggingFaceEndpoint # ✅ updated import
 
10
  from langdetect import detect, DetectorFactory
11
 
12
  DetectorFactory.seed = 0
 
14
# --------- BASIC CONFIG ----------
# All deployment knobs come from the environment; only the Spitch key is mandatory.
SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
HF_MODEL = os.getenv("HF_MODEL", "google/flan-t5-base")
FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
PROJECT_API_KEY = os.getenv("PROJECT_API_KEY", "")

# Fail fast at startup rather than on the first request.
if not SPITCH_API_KEY:
    raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")

# Init Spitch — the SDK picks its credentials up from the environment.
os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
spitch_client = Spitch()
26
 
27
# HuggingFaceEndpoint supersedes the deprecated HuggingFaceHub wrapper;
# same model, same sampling settings, passed as direct keyword arguments.
llm = HuggingFaceEndpoint(repo_id=HF_MODEL, temperature=0.2, max_length=512)
33
 
34
  # FastAPI app
35
  app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
36
 
37
+ # CORS
38
  app.add_middleware(
39
  CORSMiddleware,
40
  allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
 
63
  Code: {code}
64
  Documentation:"""
65
 
66
# LCEL runnables (prompt | llm) replace the deprecated LLMChain API.
chat_prompt = PromptTemplate(input_variables=["question"], template=chat_template)
stt_prompt = PromptTemplate(input_variables=["speech"], template=stt_chat_template)
autodoc_prompt = PromptTemplate(input_variables=["code"], template=autodoc_template)

chat_chain, stt_chain, autodoc_chain = (
    chat_prompt | llm,
    stt_prompt | llm,
    autodoc_prompt | llm,
)
74
 
75
  # --------- REQUEST MODELS ----------
76
  class ChatRequest(BaseModel):
 
82
  # --------- AUTH ----------
83
  def check_auth(authorization: str | None):
84
  if not PROJECT_API_KEY:
85
+ return
86
  if not authorization or not authorization.startswith("Bearer "):
87
  raise HTTPException(status_code=401, detail="Missing bearer token")
88
  token = authorization.split(" ", 1)[1]
 
97
  @app.post("/chat")
98
  def chat(req: ChatRequest, authorization: str | None = Header(None)):
99
  check_auth(authorization)
100
+ answer = chat_chain.invoke({"question": req.question})
101
+ return {"reply": answer.strip() if isinstance(answer, str) else str(answer)}
102
 
 
103
  @app.post("/stt")
104
  async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
 
 
 
 
 
105
  check_auth(authorization)
106
 
 
107
  suffix = os.path.splitext(file.filename)[1] or ".wav"
108
  with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
109
  content = await file.read()
110
  tf.write(content)
111
  tmp_path = tf.name
112
 
 
 
113
  try:
114
  if lang_hint:
115
  resp = spitch_client.speech.transcribe(language=lang_hint, content=open(tmp_path, "rb").read())
116
  else:
 
117
  resp = spitch_client.speech.transcribe(content=open(tmp_path, "rb").read())
118
+ except Exception:
 
119
  resp = spitch_client.speech.transcribe(language="en", content=open(tmp_path, "rb").read())
120
 
121
  transcription = getattr(resp, "text", "") or resp.get("text", "") if isinstance(resp, dict) else ""
122
 
 
123
  try:
124
  detected_lang = detect(transcription) if transcription.strip() else "en"
125
  except Exception:
126
  detected_lang = "en"
127
 
 
128
  translation = transcription
129
  if detected_lang != "en":
130
  try:
131
  translation_resp = spitch_client.text.translate(text=transcription, source=detected_lang, target="en")
132
  translation = getattr(translation_resp, "text", "") or translation_resp.get("text", "") if isinstance(translation_resp, dict) else translation
133
  except Exception:
 
134
  translation = transcription
135
 
136
+ reply = stt_chain.invoke({"speech": translation})
 
137
 
138
  return {
139
  "transcription": transcription,
140
  "detected_language": detected_lang,
141
  "translation": translation,
142
+ "reply": reply.strip() if isinstance(reply, str) else str(reply)
143
  }
144
 
145
  @app.post("/autodoc")
146
  def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
147
  check_auth(authorization)
148
+ docs = autodoc_chain.invoke({"code": req.code})
149
+ return {"documentation": docs.strip() if isinstance(docs, str) else str(docs)}
150
+
151
if __name__ == "__main__":
    # Hugging Face Spaces expects the server on port 7860, not uvicorn's
    # default 8000.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
requirements.txt CHANGED
@@ -7,4 +7,5 @@ langchain-community
7
  langdetect
8
  httpx
9
  huggingface_hub
10
- python-multipart
 
 
7
  langdetect
8
  httpx
9
  huggingface_hub
10
+ python-multipart
11
+ langchain-huggingface>=0.0.8