Kai Izumoto committed
Commit b6e2eb2 Β· verified Β· 1 parent: aac47db

Create supercoder.py

Files changed (1)
supercoder.py +412 -0
supercoder.py ADDED
@@ -0,0 +1,412 @@
"""
SuperCoder - Unified Application
All-in-one file containing Gradio UI, API server, tunnel support, and AI logic.
"""
import os
import sys
import time
import uuid
import argparse
import subprocess
import traceback
import requests
import json
from pathlib import Path
from typing import Optional, List, Dict, Any, Generator, Tuple
from collections import defaultdict
from functools import partial
from multiprocessing import Process

import gradio as gr
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn

# Import config (only external dependency)
from config import *

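# Typical invocations, based on the CLI flags defined in main() below (sketches, not
# an exhaustive list):
#   python supercoder.py                          # Gradio UI only (default --mode gradio)
#   python supercoder.py --mode api               # REST API only, on port 8000
#   python supercoder.py --mode both --tunnel cloudflare
#   python supercoder.py --mode api --no-server   # reuse an already-running backend
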
# ============================================================================
# SERVER MANAGER - llama.cpp server lifecycle
# ============================================================================
_server_process = None
_server_info = {}

def check_server_health() -> bool:
    try:
        # Check the Ollama-style /api/tags endpoint and require at least one model
        response = requests.get(f"{LLAMA_SERVER_URL}/api/tags", timeout=2)
        return response.status_code == 200 and len(response.json().get("models", [])) > 0
    except Exception:
        return False

def start_llama_server() -> bool:
    global _server_process, _server_info

    if _server_process and check_server_health():
        return True

    print(f"\nπŸš€ Starting llama.cpp server on {LLAMA_SERVER_URL}")

    try:
        cmd = [
            LLAMA_SERVER_PATH, "-hf", LLAMA_MODEL,
            "-c", str(MODEL_CONTEXT_WINDOW),
            "-t", str(MODEL_THREADS),
            "-ngl", str(MODEL_GPU_LAYERS),
            "--host", LLAMA_SERVER_HOST, "--port", str(LLAMA_SERVER_PORT)
        ]

        _server_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        _server_info = {'pid': _server_process.pid, 'url': LLAMA_SERVER_URL}

        # Wait for ready (poll the health check every 0.5 s)
        for _ in range(SERVER_STARTUP_TIMEOUT * 2):
            if check_server_health():
                print(f"βœ… Server ready (PID: {_server_process.pid})")
                return True
            time.sleep(0.5)

        return False
    except Exception as e:
        print(f"❌ Server start failed: {e}")
        return False

def stop_llama_server():
    global _server_process
    if _server_process:
        _server_process.terminate()
        _server_process.wait()
        _server_process = None

def get_llm():
    # Truthy only when the backend answers the health check
    return True if check_server_health() else None

def get_model_info():
    return _server_info.copy()

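# For illustration only, the Popen call above builds a command shaped like the
# following; the placeholders and numbers are hypothetical, the real values come
# from config.py:
#   <LLAMA_SERVER_PATH> -hf <LLAMA_MODEL> -c 8192 -t 8 -ngl 99 --host 127.0.0.1 --port 8080
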
# ============================================================================
# SESSION MANAGER - Chat history
# ============================================================================
SESSION_STORE = {}
SESSION_METADATA = defaultdict(dict)

def get_session_id(request: gr.Request) -> str:
    return request.session_hash

def get_history(session_id: str, create_if_missing: bool = False) -> List[Dict]:
    if session_id not in SESSION_STORE and create_if_missing:
        SESSION_STORE[session_id] = []
    return SESSION_STORE.get(session_id, [])

def add_to_history(session_id: str, role: str, text: str):
    history = get_history(session_id, create_if_missing=True)
    history.append({"role": role, "text": text, "timestamp": time.time()})

def clear_history(session_id: str):
    if session_id in SESSION_STORE:
        SESSION_STORE[session_id] = []

def convert_history_to_gradio_messages(history: List[Dict]) -> List[Dict]:
    return [{"role": msg["role"], "content": msg["text"]} for msg in history]

def calculate_safe_max_tokens(history: List[Dict], requested: int, max_context: int) -> int:
    # Rough budget: estimate ~4 characters per token for the history
    history_chars = sum(len(msg["text"]) for msg in history)
    estimated_tokens = history_chars // 4
    available = max_context - estimated_tokens - SYSTEM_OVERHEAD_TOKENS
    return max(min(requested, available, SAFE_MAX_TOKENS_CAP), MIN_TOKENS)

def get_recent_history(session_id: str, max_messages: int = 10) -> List[Dict]:
    history = get_history(session_id)
    return history[-max_messages:] if len(history) > max_messages else history

def update_session_activity(session_id: str):
    SESSION_METADATA[session_id]['last_activity'] = time.time()

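# Worked example of the budget above (hypothetical numbers): with an 8192-token
# context, SYSTEM_OVERHEAD_TOKENS = 512, and 8000 characters of recent history
# (β‰ˆ 2000 tokens), 'available' is 8192 - 2000 - 512 = 5680, so a request for 4096
# tokens passes through while a request for 8000 is clamped to 5680 (assuming
# SAFE_MAX_TOKENS_CAP is at least that large), and never drops below MIN_TOKENS.
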
# ============================================================================
# GENERATION - AI response generation
# ============================================================================
def generate_response_stream(session_id: str, user_message: str, max_tokens: int,
                             temperature: float, stream: bool = True) -> Generator[str, None, None]:
    if not get_llm():
        yield "⚠️ Server not running"
        return

    update_session_activity(session_id)
    recent_history = get_recent_history(session_id, max_messages=6)
    safe_tokens = calculate_safe_max_tokens(recent_history, max_tokens, MODEL_CONTEXT_WINDOW)

    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for msg in recent_history:
        messages.append({"role": msg["role"], "content": msg["text"]})
    messages.append({"role": "user", "content": user_message})

    try:
        payload = {
            "messages": messages, "max_tokens": safe_tokens,
            "temperature": max(0.01, temperature),
            "top_p": DEFAULT_TOP_P, "stream": stream
        }

        if stream:
            # OpenAI-compatible streaming endpoint; chunks arrive as SSE "data: ..." lines
            response = requests.post(f"{LLAMA_SERVER_URL}/v1/chat/completions",
                                     json=payload, stream=True, timeout=300)
            full_response = ""
            for line in response.iter_lines():
                if line:
                    line_text = line.decode('utf-8')
                    if line_text.startswith('data: '):
                        line_text = line_text[6:]
                    if line_text.strip() == '[DONE]':
                        break
                    try:
                        chunk = json.loads(line_text)
                        content = chunk.get("choices", [{}])[0].get("delta", {}).get("content", "")
                        if content:
                            full_response += content
                            yield full_response.strip()
                    except Exception:
                        continue
        else:
            # Use Ollama API format instead of OpenAI format
            ollama_payload = {
                "model": LLAMA_MODEL,
                "messages": messages,
                "stream": False
            }
            response = requests.post(f"{LLAMA_SERVER_URL}/api/chat",
                                     json=ollama_payload, timeout=300)
            yield response.json()["message"]["content"].strip()

    except Exception as e:
        yield f"⚠️ Error: {str(e)}"

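# Shape of a single streamed line the loop above parses (illustrative):
#   data: {"choices": [{"delta": {"content": "def "}, "index": 0}]}
# followed eventually by the literal terminator line:
#   data: [DONE]
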
# ============================================================================
# GRADIO UI COMPONENTS
# ============================================================================
def create_gradio_interface(error_msg: Optional[str] = None):
    with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft(primary_hue=PRIMARY_HUE)) as demo:
        gr.Markdown(f"# πŸ€– {APP_TITLE}\n### {APP_DESCRIPTION}\n---")

        if error_msg:
            gr.Markdown(f"⚠️ {error_msg}")

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="πŸ’¬ Conversation", height=CHAT_HEIGHT,
                                     type="messages", show_copy_button=True)
                with gr.Row():
                    txt_input = gr.Textbox(placeholder="Ask me about code...",
                                           show_label=False, scale=5, lines=2)
                    send_btn = gr.Button("Send πŸš€", scale=1, variant="primary")

            with gr.Column(scale=1):
                gr.Markdown("### βš™οΈ Settings")
                temp_slider = gr.Slider(0.0, 1.0, value=DEFAULT_TEMPERATURE, step=0.05,
                                        label="🌑️ Temperature")
                tokens_slider = gr.Slider(MIN_TOKENS, SAFE_MAX_TOKENS_CAP,
                                          value=DEFAULT_MAX_TOKENS, step=128, label="πŸ“ Max Tokens")
                stream_checkbox = gr.Checkbox(label="⚑ Stream", value=True)
                clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="stop", size="sm")

        session_state = gr.State()

        # Event handlers
        def handle_message(session_id, msg, temp, tokens, stream, request: gr.Request):
            session_id = session_id or get_session_id(request)
            if not msg.strip():
                # handle_message is a generator, so yield (not return) the unchanged state
                yield session_id, convert_history_to_gradio_messages(get_history(session_id)), ""
                return

            add_to_history(session_id, "user", msg)
            yield session_id, convert_history_to_gradio_messages(get_history(session_id)), ""

            full_response = ""
            for partial in generate_response_stream(session_id, msg, tokens, temp, stream):
                full_response = partial
                temp_hist = get_history(session_id).copy()
                temp_hist.append({"role": "assistant", "text": full_response})
                yield session_id, convert_history_to_gradio_messages(temp_hist), ""

            add_to_history(session_id, "assistant", full_response)
            yield session_id, convert_history_to_gradio_messages(get_history(session_id)), ""

        def handle_clear(session_id, request: gr.Request):
            session_id = session_id or get_session_id(request)
            clear_history(session_id)
            return session_id, [], ""

        txt_input.submit(handle_message, [session_state, txt_input, temp_slider, tokens_slider, stream_checkbox],
                         [session_state, chatbot, txt_input])
        send_btn.click(handle_message, [session_state, txt_input, temp_slider, tokens_slider, stream_checkbox],
                       [session_state, chatbot, txt_input])
        clear_btn.click(handle_clear, [session_state], [session_state, chatbot, txt_input])

    return demo

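# With type="messages", the Chatbot above expects a list of {"role": ..., "content": ...}
# dicts, which is exactly the shape convert_history_to_gradio_messages() produces.
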
# ============================================================================
# FASTAPI SERVER
# ============================================================================
api_app = FastAPI(title="SuperCoder API")
api_app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

api_sessions = {}

class ChatMessage(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.1
    max_tokens: Optional[int] = 512

class ChatResponse(BaseModel):
    response: str
    session_id: str

@api_app.get("/health")
async def health():
    return {"status": "ok" if get_llm() else "model_not_loaded"}

@api_app.post("/api/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    if not get_llm():
        raise HTTPException(503, "Model not loaded")
    if not request.messages:
        raise HTTPException(400, "messages must not be empty")

    session_id = str(uuid.uuid4())
    api_sessions[session_id] = []

    user_message = request.messages[-1].content
    api_sessions[session_id].append({"role": "user", "text": user_message})

    full_response = ""
    for partial in generate_response_stream(session_id, user_message, request.max_tokens,
                                            request.temperature, False):
        full_response = partial

    api_sessions[session_id].append({"role": "assistant", "text": full_response})
    return ChatResponse(response=full_response, session_id=session_id)

def run_api_server():
    uvicorn.run(api_app, host="0.0.0.0", port=8000, log_level="info")

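# Example request against the endpoint above, assuming the API is running locally
# on its default port 8000:
#   curl -X POST http://localhost:8000/api/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages": [{"role": "user", "content": "Write a Python hello world"}], "temperature": 0.1, "max_tokens": 256}'
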
# ============================================================================
# TUNNEL SUPPORT
# ============================================================================
def start_ngrok_tunnel(port: int = 8000) -> Optional[str]:
    try:
        subprocess.run(["which", "ngrok"], capture_output=True, check=True)
        subprocess.Popen(["ngrok", "http", str(port)], stdout=subprocess.PIPE)
        time.sleep(3)

        # ngrok exposes its active tunnels via a local inspection API
        response = requests.get("http://127.0.0.1:4040/api/tunnels", timeout=5)
        tunnels = response.json()
        if tunnels.get("tunnels"):
            url = tunnels["tunnels"][0]["public_url"]
            print(f"βœ… Tunnel: {url}")
            return url
    except Exception:
        print("❌ ngrok not found. Install: brew install ngrok")
    return None

def start_cloudflare_tunnel(port: int = 8000) -> Optional[str]:
    try:
        subprocess.run(["which", "cloudflared"], capture_output=True, check=True)
        proc = subprocess.Popen(["cloudflared", "tunnel", "--url", f"http://localhost:{port}"],
                                stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        time.sleep(3)

        # Scan cloudflared's output for the generated *.trycloudflare.com URL
        for _ in range(30):
            line = proc.stdout.readline()
            if "trycloudflare.com" in line:
                import re
                urls = re.findall(r'https://[^\s]+\.trycloudflare\.com', line)
                if urls:
                    print(f"βœ… Tunnel: {urls[0]}")
                    return urls[0]
            time.sleep(1)
    except Exception:
        print("❌ cloudflared not found. Install: brew install cloudflared")
    return None

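# Note: both helpers locate their binaries with `which` and suggest Homebrew installs,
# so as written they assume a macOS/Linux environment.
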
# ============================================================================
# MAIN LAUNCHER
# ============================================================================
def main():
    parser = argparse.ArgumentParser(description="SuperCoder - All-in-One AI Coding Assistant")
    parser.add_argument("--mode", choices=["gradio", "api", "both"], default="gradio",
                        help="Run mode: gradio (UI), api (server), or both")
    parser.add_argument("--tunnel", choices=["ngrok", "cloudflare"],
                        help="Start tunnel for public access")
    parser.add_argument("--no-server", action="store_true",
                        help="Don't start llama.cpp server (assume already running)")

    args = parser.parse_args()

    print("╔════════════════════════════════════════════════╗")
    print("β•‘          SuperCoder - Unified Launcher         β•‘")
    print("β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•")

    # Start llama.cpp server
    if not args.no_server:
        success = start_llama_server()
        error_msg = None if success else "Failed to start llama.cpp server"
    else:
        error_msg = None

    # Run selected mode
    if args.mode == "gradio":
        print(f"\nπŸ“Œ Mode: Gradio UI\n🌐 Access: http://localhost:{SERVER_PORT}\n")
        demo = create_gradio_interface(error_msg)
        demo.launch(server_name=SERVER_NAME, server_port=SERVER_PORT)

    elif args.mode == "api":
        print(f"\nπŸ“Œ Mode: API Server\nπŸ“‘ API: http://localhost:8000/api/chat\n")

        if args.tunnel:
            api_proc = Process(target=run_api_server)
            api_proc.start()
            time.sleep(3)

            if args.tunnel == "ngrok":
                start_ngrok_tunnel(8000)
            else:
                start_cloudflare_tunnel(8000)

            try:
                api_proc.join()
            except KeyboardInterrupt:
                api_proc.terminate()
        else:
            run_api_server()

    elif args.mode == "both":
        print(f"\nπŸ“Œ Mode: Both Gradio + API\n🎨 UI: http://localhost:{SERVER_PORT}\nπŸ“‘ API: http://localhost:8000\n")

        # NOTE: a lambda Process target only works with the 'fork' start method (the
        # Linux default); under 'spawn' (macOS/Windows) it cannot be pickled.
        gradio_proc = Process(target=lambda: create_gradio_interface(error_msg).launch(
            server_name=SERVER_NAME, server_port=SERVER_PORT))
        api_proc = Process(target=run_api_server)

        gradio_proc.start()
        api_proc.start()

        if args.tunnel:
            time.sleep(3)
            if args.tunnel == "ngrok":
                start_ngrok_tunnel(8000)
            else:
                start_cloudflare_tunnel(8000)

        try:
            gradio_proc.join()
            api_proc.join()
        except KeyboardInterrupt:
            gradio_proc.terminate()
            api_proc.terminate()

if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\nπŸ‘‹ Shutting down...")
        stop_llama_server()