lydiasolomon committed
Commit 9c16cbf · verified · 1 Parent(s): 58a0d61

Update main.py

Files changed (1): main.py (+41 -32)
main.py CHANGED
@@ -5,8 +5,7 @@ from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from spitch import Spitch
-from langchain.prompts import PromptTemplate
-from langchain_huggingface import HuggingFaceEndpoint
+from transformers import pipeline
 from langdetect import detect, DetectorFactory
 from smebuilder_vector import retriever
 
@@ -24,9 +23,10 @@ if not SPITCH_API_KEY:
 os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
 spitch_client = Spitch()
 
-# HuggingFace LLM
-llm = HuggingFaceEndpoint(
-    repo_id=HF_MODEL,
+# ----------------- HUGGINGFACE PIPELINE -----------------
+llm_pipeline = pipeline(
+    task="text-generation",
+    model=HF_MODEL,
     temperature=0.7,
     top_p=0.9,
     do_sample=True,
@@ -35,7 +35,7 @@ llm = HuggingFaceEndpoint(
 )
 
 # ----------------- FASTAPI -----------------
-app = FastAPI(title="DevAssist AI Backend (FastAPI + LangChain)")
+app = FastAPI(title="DevAssist AI Backend (FastAPI + HuggingFace Pipeline)")
 
 app.add_middleware(
     CORSMiddleware,
@@ -45,7 +45,7 @@ app.add_middleware(
     allow_headers=["Authorization", "Content-Type"],
 )
 
-# ----------------- PROMPT TEMPLATES -----------------
+# ----------------- PROMPTS -----------------
 chat_template = """You are DevAssist, an AI coding assistant.
 
 Guidelines:
@@ -94,12 +94,6 @@ Context: {context}
 Output:
 """
 
-# ----------------- CHAINS -----------------
-chat_chain = PromptTemplate(input_variables=["question"], template=chat_template)
-stt_chain = PromptTemplate(input_variables=["speech"], template=stt_chat_template)
-autodoc_chain = PromptTemplate(input_variables=["code"], template=autodoc_template)
-sme_chain = PromptTemplate(input_variables=["user_prompt", "context"], template=sme_template)
-
 # ----------------- REQUEST MODELS -----------------
 class ChatRequest(BaseModel):
     question: str
@@ -117,20 +111,35 @@ def check_auth(authorization: str | None):
     if token != PROJECT_API_KEY:
         raise HTTPException(status_code=403, detail="Invalid token")
 
-# ----------------- HELPER FUNCTIONS -----------------
-def run_llm(prompt_text: str):
-    """
-    Directly run HuggingFaceEndpoint with string input.
-    Returns text or error dict.
-    """
+# ----------------- DEBUG LOGGING -----------------
+DEBUG_LOG_FILE = "llm_debug.log"
+
+def run_pipeline(prompt_text: str):
     try:
-        output = llm(prompt_text)
-        if not output.strip():
+        output_list = llm_pipeline(prompt_text, max_new_tokens=2048, do_sample=True)
+        text = output_list[0]['generated_text'].strip()
+
+        # Debug logging
+        with open(DEBUG_LOG_FILE, "a", encoding="utf-8") as f:
+            f.write("=== PROMPT START ===\n")
+            f.write(prompt_text + "\n")
+            f.write("--- MODEL OUTPUT ---\n")
+            f.write(text + "\n")
+            f.write("=== PROMPT END ===\n\n")
+
+        if not text:
             return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt_text}
-        return output.strip()
+        return text
     except Exception:
+        with open(DEBUG_LOG_FILE, "a", encoding="utf-8") as f:
+            f.write("=== PROMPT START ===\n")
+            f.write(prompt_text + "\n")
+            f.write("--- EXCEPTION ---\n")
+            f.write(traceback.format_exc() + "\n")
+            f.write("=== PROMPT END ===\n\n")
         return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc(), "prompt": prompt_text}
 
+# ----------------- AUDIO PROCESSING -----------------
 async def process_audio(file: UploadFile, lang_hint: str | None = None):
     suffix = os.path.splitext(file.filename)[1] or ".wav"
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
@@ -173,16 +182,16 @@ def root():
 @app.post("/chat")
 def chat(req: ChatRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    prompt_text = chat_chain.format(question=req.question)
-    result = run_llm(prompt_text)
+    prompt_text = chat_template.format(question=req.question)
+    result = run_pipeline(prompt_text)
     return result if isinstance(result, dict) else {"reply": result}
 
 @app.post("/stt")
 async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
     check_auth(authorization)
     transcription, detected_lang, translation = await process_audio(file, lang_hint)
-    prompt_text = stt_chain.format(speech=translation)
-    result = run_llm(prompt_text)
+    prompt_text = stt_chat_template.format(speech=translation)
+    result = run_pipeline(prompt_text)
     return {
         "transcription": transcription,
         "detected_language": detected_lang,
@@ -193,8 +202,8 @@ async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None,
 @app.post("/autodoc")
 def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    prompt_text = autodoc_chain.format(code=req.code)
-    result = run_llm(prompt_text)
+    prompt_text = autodoc_template.format(code=req.code)
+    result = run_pipeline(prompt_text)
     return result if isinstance(result, dict) else {"documentation": result}
 
 @app.post("/sme/generate")
@@ -204,8 +213,8 @@ async def sme_generate(payload: dict = Body(...), authorization: str | None = He
         user_prompt = payload.get("user_prompt", "")
         context_docs = retriever.get_relevant_documents(user_prompt)
         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
-        prompt_text = sme_chain.format(user_prompt=user_prompt, context=context)
-        result = run_llm(prompt_text)
+        prompt_text = sme_template.format(user_prompt=user_prompt, context=context)
+        result = run_pipeline(prompt_text)
         return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
     except Exception:
         return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
@@ -217,8 +226,8 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
     try:
         context_docs = retriever.get_relevant_documents(translation)
         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
-        prompt_text = sme_chain.format(user_prompt=translation, context=context)
-        result = run_llm(prompt_text)
+        prompt_text = sme_template.format(user_prompt=translation, context=context)
+        result = run_pipeline(prompt_text)
         return {
             "success": True,
             "transcription": transcription,