lydiasolomon committed
Commit 0a3060f · verified · 1 Parent(s): fdae2ad

Update main.py

Files changed (1):
  main.py +109 -150
main.py CHANGED
@@ -1,107 +1,44 @@
 import os
 import tempfile
+import logging
 import traceback
 from fastapi import FastAPI, UploadFile, File, Header, HTTPException, Body
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
 from pydantic import BaseModel
-from spitch import Spitch
 from transformers import pipeline
 from langdetect import detect, DetectorFactory
-from smebuilder_vector import retriever
-
-# ----------------- CONFIG -----------------
+from PIL import Image
+from smebuilder_vector import retriever  # Your vector retrieval module
+
+# ==============================
+# Logging Setup
+# ==============================
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger("DevAssist")
+
+# ==============================
+# App Init
+# ==============================
+app = FastAPI(title="DevAssist AI Backend")
+
+# ==============================
+# Config
+# ==============================
 DetectorFactory.seed = 0
-
-SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
-HF_MODEL = os.getenv("HF_MODEL", "deepseek-ai/deepseek-coder-1.3b-instruct")
-FRONTEND_ORIGIN = os.getenv("ALLOWED_ORIGIN", "*")
 PROJECT_API_KEY = os.getenv("PROJECT_API_KEY")
+SPITCH_API_KEY = os.getenv("SPITCH_API_KEY")
+HF_MODELS = {
+    "chat": "bigcode/starcoderbase",
+    "autodoc": "Salesforce/codegen-2B-mono",
+    "sme": "deepseek-ai/deepseek-coder-1.3b-instruct"
+}
 
 if not SPITCH_API_KEY:
     raise RuntimeError("Set SPITCH_API_KEY in environment before starting.")
 
-os.environ["SPITCH_API_KEY"] = SPITCH_API_KEY
-spitch_client = Spitch()
-
-# ----------------- HUGGINGFACE PIPELINE -----------------
-llm_pipeline = pipeline(
-    task="text-generation",
-    model=HF_MODEL,
-    temperature=0.7,
-    top_p=0.9,
-    do_sample=True,
-    repetition_penalty=1.1,
-    max_new_tokens=2048
-)
-
-# ----------------- FASTAPI -----------------
-app = FastAPI(title="DevAssist AI Backend (FastAPI + HuggingFace Pipeline)")
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=[FRONTEND_ORIGIN] if FRONTEND_ORIGIN != "*" else ["*"],
-    allow_credentials=True,
-    allow_methods=["GET", "POST", "OPTIONS"],
-    allow_headers=["Authorization", "Content-Type"],
-)
-
-# ----------------- PROMPTS -----------------
-chat_template = """You are DevAssist, an AI coding assistant.
-
-Guidelines:
-- Always format responses in Markdown.
-- Use section headers: Explanation:, Steps:, Fixed Code:
-- Use bullet points for steps.
-- Use fenced code blocks for code.
-- Be friendly yet professional.
-
-Question: {question}
-
-Answer:
-"""
-
-stt_chat_template = """You are DevAssist, an AI coding assistant.
-The input is transcribed speech. Interpret it as a developer question.
-Provide clear answers with code examples.
-If unclear, ask for clarification.
-
-Spoken Question: {speech}
-Answer:
-"""
-
-autodoc_template = """You are DevAssist DocBot.
-Read the code and produce professional documentation in markdown.
-
-Code: {code}
-Documentation:
-"""
-
-sme_template = """
-You are a senior full-stack engineer specializing in modern front-end development.
-Your job is to generate **production-ready code** for websites and apps.
-
-Guidelines:
-- Always return three separate files: index.html, styles.css, and script.js
-- HTML must be semantic, responsive, and mobile-first
-- CSS should use Flexbox/Grid with hover/transition effects
-- JavaScript must add interactivity (animations, toggles, button actions)
-- Include hero, feature grid, testimonials, and footer
-- Use realistic content (no lorem ipsum, no placeholders)
-
-Prompt: {user_prompt}
-Context: {context}
-
-Output:
-"""
-
-# ----------------- REQUEST MODELS -----------------
-class ChatRequest(BaseModel):
-    question: str
-
-class AutoDocRequest(BaseModel):
-    code: str
-
-# ----------------- AUTH -----------------
+# ==============================
+# Auth Check
+# ==============================
 def check_auth(authorization: str | None):
     if not PROJECT_API_KEY:
         return
@@ -111,41 +48,73 @@ def check_auth(authorization: str | None):
     if token != PROJECT_API_KEY:
         raise HTTPException(status_code=403, detail="Invalid token")
 
-# ----------------- DEBUG LOGGING -----------------
-DEBUG_LOG_FILE = "llm_debug.log"
+# ==============================
+# Global Exception Handler
+# ==============================
+@app.exception_handler(Exception)
+async def global_exception_handler(request, exc: Exception):
+    logger.error(f"Unhandled error: {exc}")
+    return JSONResponse(status_code=500, content={"error": str(exc)})
+
+# ==============================
+# Request Models
+# ==============================
+class ChatRequest(BaseModel):
+    question: str
 
-def run_pipeline(prompt_text: str):
-    try:
-        output_list = llm_pipeline(prompt_text, max_new_tokens=2048, do_sample=True)
-        text = output_list[0]['generated_text'].strip()
+class AutoDocRequest(BaseModel):
+    code: str
 
-        # Debug logging
-        with open(DEBUG_LOG_FILE, "a", encoding="utf-8") as f:
-            f.write("=== PROMPT START ===\n")
-            f.write(prompt_text + "\n")
-            f.write("--- MODEL OUTPUT ---\n")
-            f.write(text + "\n")
-            f.write("=== PROMPT END ===\n\n")
+class SMERequest(BaseModel):
+    user_prompt: str
 
+# ==============================
+# Pipeline Loader
+# ==============================
+def load_pipeline(task: str, model_name: str, fallback: str = None):
+    try:
+        return pipeline(task, model=model_name)
+    except Exception as e:
+        logger.warning(f"Failed to load {model_name}: {e}")
+        if fallback:
+            logger.info(f"Falling back to {fallback}")
+            return pipeline(task, model=fallback)
+        raise e
+
+# ==============================
+# Pipelines
+# ==============================
+chat_pipe = load_pipeline("text-generation", HF_MODELS["chat"], "gpt2")
+autodoc_pipe = load_pipeline("text-generation", HF_MODELS["autodoc"], "gpt2")
+sme_pipe = load_pipeline("text-generation", HF_MODELS["sme"], "gpt2")
+
+# ==============================
+# Helper Functions
+# ==============================
+def run_pipeline(pipe, prompt: str):
+    try:
+        output_list = pipe(prompt, max_new_tokens=1024, do_sample=True)
+        text = output_list[0].get("generated_text", "").strip() if isinstance(output_list, list) else str(output_list)
+
+        # Log prompt + output
+        logger.info(f"Prompt:\n{prompt}\n--- Output:\n{text}\n--- End")
         if not text:
-            return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt_text}
+            return {"success": False, "error": "⚠️ LLM returned empty output", "prompt": prompt}
         return text
-    except Exception:
-        with open(DEBUG_LOG_FILE, "a", encoding="utf-8") as f:
-            f.write("=== PROMPT START ===\n")
-            f.write(prompt_text + "\n")
-            f.write("--- EXCEPTION ---\n")
-            f.write(traceback.format_exc() + "\n")
-            f.write("=== PROMPT END ===\n\n")
-        return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc(), "prompt": prompt_text}
-
-# ----------------- AUDIO PROCESSING -----------------
+    except Exception as e:
+        logger.error(f"Pipeline error: {e}")
+        return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "prompt": prompt, "trace": traceback.format_exc()}
+
+# ==============================
+# Audio Processing Helper
+# ==============================
 async def process_audio(file: UploadFile, lang_hint: str | None = None):
+    import spitch
+    spitch_client = spitch.Spitch()
     suffix = os.path.splitext(file.filename)[1] or ".wav"
     with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tf:
        tf.write(await file.read())
        tmp_path = tf.name
-
     with open(tmp_path, "rb") as f:
         audio_bytes = f.read()
 
@@ -174,60 +143,48 @@ async def process_audio(file: UploadFile, lang_hint: str | None = None):
 
     return transcription, detected_lang, translation
 
-# ----------------- ENDPOINTS -----------------
+# ==============================
+# Endpoints
+# ==============================
 @app.get("/")
-def root():
+async def root_endpoint():
     return {"status": "✅ DevAssist AI Backend running"}
 
 @app.post("/chat")
-def chat(req: ChatRequest, authorization: str | None = Header(None)):
+async def chat_endpoint(req: ChatRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    prompt_text = chat_template.format(question=req.question)
-    result = run_pipeline(prompt_text)
+    prompt = f"You are a professional coding assistant. Answer clearly:\nQuestion: {req.question}\nAnswer:"
+    result = run_pipeline(chat_pipe, prompt)
     return result if isinstance(result, dict) else {"reply": result}
 
-@app.post("/stt")
-async def stt_audio(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
-    check_auth(authorization)
-    transcription, detected_lang, translation = await process_audio(file, lang_hint)
-    prompt_text = stt_chat_template.format(speech=translation)
-    result = run_pipeline(prompt_text)
-    return {
-        "transcription": transcription,
-        "detected_language": detected_lang,
-        "translation": translation,
-        "reply": result if isinstance(result, str) else result.get("reply", "")
-    }
-
 @app.post("/autodoc")
-def autodoc(req: AutoDocRequest, authorization: str | None = Header(None)):
+async def autodoc_endpoint(req: AutoDocRequest, authorization: str | None = Header(None)):
     check_auth(authorization)
-    prompt_text = autodoc_template.format(code=req.code)
-    result = run_pipeline(prompt_text)
+    prompt = f"Generate professional documentation for the following code in Markdown:\n{req.code}\nDocumentation:"
+    result = run_pipeline(autodoc_pipe, prompt)
     return result if isinstance(result, dict) else {"documentation": result}
 
 @app.post("/sme/generate")
-async def sme_generate(payload: dict = Body(...), authorization: str | None = Header(None)):
+async def sme_generate_endpoint(req: SMERequest, authorization: str | None = Header(None)):
     check_auth(authorization)
     try:
-        user_prompt = payload.get("user_prompt", "")
-        context_docs = retriever.get_relevant_documents(user_prompt)
+        context_docs = retriever.get_relevant_documents(req.user_prompt)
         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
-        prompt_text = sme_template.format(user_prompt=user_prompt, context=context)
-        result = run_pipeline(prompt_text)
+        prompt = f"Generate production-ready frontend code based on this prompt:\n{req.user_prompt}\nContext:\n{context}\nOutput:"
+        result = run_pipeline(sme_pipe, prompt)
         return {"success": True, "data": result if isinstance(result, str) else result.get("reply", "")}
-    except Exception:
-        return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
+    except Exception as e:
+        return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}
 
 @app.post("/sme/speech-generate")
-async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
+async def sme_speech_endpoint(file: UploadFile = File(...), lang_hint: str | None = None, authorization: str | None = Header(None)):
     check_auth(authorization)
     transcription, detected_lang, translation = await process_audio(file, lang_hint)
     try:
         context_docs = retriever.get_relevant_documents(translation)
         context = "\n".join([doc.page_content for doc in context_docs]) if context_docs else "No extra context"
-        prompt_text = sme_template.format(user_prompt=translation, context=context)
-        result = run_pipeline(prompt_text)
+        prompt = f"Generate production-ready frontend code based on this prompt:\n{translation}\nContext:\n{context}\nOutput:"
+        result = run_pipeline(sme_pipe, prompt)
         return {
             "success": True,
             "transcription": transcription,
@@ -235,10 +192,12 @@ async def sme_speech_generate(file: UploadFile = File(...), lang_hint: str | Non
             "translation": translation,
             "sme_site": result if isinstance(result, str) else result.get("reply", "")
         }
-    except Exception:
-        return {"success": False, "error": "⚠️ LLM error", "details": traceback.format_exc()}
+    except Exception as e:
+        return {"success": False, "error": f"⚠️ LLM error: {str(e)}", "trace": traceback.format_exc()}
 
-# ----------------- MAIN -----------------
+# ==============================
+# Run App
+# ==============================
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=False)
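
For quick verification of the updated endpoints, here is a minimal client sketch, not part of the commit. It assumes the server runs locally on port 7860 (matching the uvicorn.run call above), that PROJECT_API_KEY is set in both environments, and that the elided body of check_auth parses an "Authorization: Bearer <key>" header; the example prompts are hypothetical.

# Client-side smoke test (sketch only; see assumptions above).
import os
import requests

BASE = "http://localhost:7860"
# Assumed Bearer format -- the header-parsing lines of check_auth are elided in this diff.
HEADERS = {"Authorization": f"Bearer {os.environ['PROJECT_API_KEY']}"}

# /chat takes a ChatRequest body and returns {"reply": ...} on success,
# or the error dict built by run_pipeline on failure.
resp = requests.post(f"{BASE}/chat",
                     json={"question": "How do I reverse a list in Python?"},
                     headers=HEADERS)
print(resp.json())

# /sme/generate now takes the typed SMERequest body (user_prompt) instead of
# the raw dict payload it accepted before this commit; on success it returns
# {"success": True, "data": ...}.
resp = requests.post(f"{BASE}/sme/generate",
                     json={"user_prompt": "Landing page for a small bakery"},
                     headers=HEADERS)
print(resp.json().get("data"))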