neuralbroker committed
Commit b0bcd7a · verified · 1 Parent(s): eeb54ea

Upload server.py with huggingface_hub

Files changed (1): server.py +365 -0
server.py ADDED
@@ -0,0 +1,365 @@
+ #!/usr/bin/env python3
+ """
+ BlitzKode backend server.
+
+ Serves the bundled frontend and proxies prompts to a local GGUF model
+ through llama.cpp. The model is loaded lazily so the module stays importable
+ in tests and in environments where the model artifact is not yet present.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import logging
+ import os
+ import time
+ from concurrent.futures import ThreadPoolExecutor
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from pathlib import Path
+ from threading import Lock
+ from typing import Iterator
+
+ import llama_cpp
+ import uvicorn
+ from fastapi import FastAPI, HTTPException, Request
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import FileResponse, JSONResponse, StreamingResponse
+ from pydantic import BaseModel, Field
+
+ APP_NAME = "BlitzKode"
+ APP_VERSION = "2.0"
+ CREATOR = "Sajad"
+ ROOT_DIR = Path(__file__).resolve().parent
+ DEFAULT_MODEL_PATH = ROOT_DIR / "blitzkode.gguf"
+ DEFAULT_FRONTEND_PATH = ROOT_DIR / "frontend" / "index.html"
+ DEFAULT_CONTEXT = 2048
+ DEFAULT_MAX_PROMPT_LENGTH = 4000
+ DEFAULT_MAX_TOKENS = 512
+ STOP_TOKENS = ["<|im_end|>", "<|im_start|>user"]
+
+ SYSTEM_PROMPT = (
+     "<|im_start|>system\n"
+     "You are BlitzKode, an AI coding assistant created by Sajad. "
+     "You are an expert in Python, JavaScript, Java, C++, and other programming languages. "
+     "Write clean, efficient, and well-documented code. Keep responses concise and practical.<|im_end|>"
+ )
+
+ logger = logging.getLogger("blitzkode")
+
+
+ def _bool_from_env(name: str, default: bool = False) -> bool:
+     value = os.getenv(name)
+     if value is None:
+         return default
+     return value.strip().lower() in {"1", "true", "yes", "on"}
+
+
+ def _int_from_env(name: str, default: int) -> int:
+     value = os.getenv(name)
+     if not value:
+         return default
+     try:
+         return int(value)
+     except ValueError:
+         return default
+
+
+ def _validate_prompt(prompt: str, max_length: int) -> tuple[str, JSONResponse | None]:
+     prompt = prompt.strip()
+     if not prompt:
+         return prompt, JSONResponse({"error": "Prompt is required"}, status_code=400)
+     if len(prompt) > max_length:
+         return prompt, JSONResponse(
+             {"error": f"Prompt too long. Max {max_length} chars."},
+             status_code=400,
+         )
+     return prompt, None
+
+
+ @dataclass(slots=True)
+ class Settings:
+     root_dir: Path = ROOT_DIR
+     model_path: Path = Path(os.getenv("BLITZKODE_MODEL_PATH", DEFAULT_MODEL_PATH))
+     frontend_path: Path = Path(os.getenv("BLITZKODE_FRONTEND_PATH", DEFAULT_FRONTEND_PATH))
+     host: str = os.getenv("BLITZKODE_HOST", "0.0.0.0")
+     port: int = _int_from_env("BLITZKODE_PORT", 7860)
+     n_gpu_layers: int = _int_from_env("BLITZKODE_GPU_LAYERS", 0)
+     n_ctx: int = _int_from_env("BLITZKODE_N_CTX", DEFAULT_CONTEXT)
+     n_threads: int = _int_from_env("BLITZKODE_THREADS", max(1, min(8, os.cpu_count() or 1)))
+     n_batch: int = _int_from_env("BLITZKODE_BATCH", 128)
+     max_prompt_length: int = _int_from_env("BLITZKODE_MAX_PROMPT_LENGTH", DEFAULT_MAX_PROMPT_LENGTH)
+     preload_model: bool = _bool_from_env("BLITZKODE_PRELOAD_MODEL", default=False)
+     workers: int = _int_from_env("BLITZKODE_WORKERS", 2)
+     cors_origins: str = os.getenv("BLITZKODE_CORS_ORIGINS", "*")
+     api_key: str = os.getenv("BLITZKODE_API_KEY", "")
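+
+ # Every Settings field above is driven by an environment variable; a
+ # hypothetical GPU deployment might be configured like this before launch
+ # (values are illustrative, not defaults):
+ #
+ #   export BLITZKODE_MODEL_PATH=/models/blitzkode.gguf
+ #   export BLITZKODE_GPU_LAYERS=35
+ #   export BLITZKODE_N_CTX=4096
+ #   export BLITZKODE_API_KEY=change-me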
+
+
+ class MessageItem(BaseModel):
+     role: str
+     content: str
+
+
+ class GenerateRequest(BaseModel):
+     prompt: str
+     messages: list[MessageItem] = Field(default_factory=list)
+     temperature: float = Field(default=0.5, ge=0.0, le=2.0)
+     max_tokens: int = Field(default=256, ge=1, le=DEFAULT_MAX_TOKENS)
+     top_p: float = Field(default=0.95, gt=0.0, le=1.0)
+     top_k: int = Field(default=20, ge=1, le=200)
+     repeat_penalty: float = Field(default=1.05, ge=0.8, le=2.0)
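+
+ # A minimal request body that validates against GenerateRequest; every field
+ # except "prompt" is optional and falls back to the defaults above (values
+ # here are illustrative only):
+ #
+ #   {
+ #     "prompt": "Write a Python function that reverses a string",
+ #     "messages": [{"role": "user", "content": "Hi"},
+ #                  {"role": "assistant", "content": "Hello! How can I help?"}],
+ #     "temperature": 0.7,
+ #     "max_tokens": 256
+ #   }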
+
+
+ class ModelService:
+     def __init__(self, settings: Settings):
+         self.settings = settings
+         self._llm = None
+         self._lock = Lock()
+         self._load_time_seconds: float | None = None
+         self._last_error: str | None = None
+
+     @property
+     def model_loaded(self) -> bool:
+         return self._llm is not None
+
+     @property
+     def model_exists(self) -> bool:
+         return self.settings.model_path.exists()
+
+     @property
+     def last_error(self) -> str | None:
+         return self._last_error
+
+     @property
+     def load_time_seconds(self) -> float | None:
+         return self._load_time_seconds
+
+     def load_model(self):
+         if self._llm is not None:
+             return self._llm
+
+         with self._lock:
+             if self._llm is not None:
+                 return self._llm
+
+             if not self.model_exists:
+                 self._last_error = f"Model not found at {self.settings.model_path}"
+                 raise FileNotFoundError(self._last_error)
+
+             start_time = time.perf_counter()
+             try:
+                 self._llm = llama_cpp.Llama(
+                     model_path=str(self.settings.model_path),
+                     n_gpu_layers=self.settings.n_gpu_layers,
+                     n_ctx=self.settings.n_ctx,
+                     n_threads=self.settings.n_threads,
+                     n_batch=self.settings.n_batch,
+                     verbose=False,
+                     use_mmap=True,
+                     use_mlock=False,
+                     seed=-1,
+                 )
+                 self._load_time_seconds = time.perf_counter() - start_time
+                 self._last_error = None
+                 logger.info("Model loaded in %.2fs (gpu_layers=%d)", self._load_time_seconds, self.settings.n_gpu_layers)
+             except Exception as exc:
+                 self._last_error = str(exc)
+                 logger.error("Model load failed: %s", exc)
+                 raise
+
+         return self._llm
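+
+     # load_model uses double-checked locking: the first check skips the lock
+     # on the hot path once the model is resident, while the second check
+     # inside the lock stops two threads from both loading the model when they
+     # race on a cold start.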
+
+     def build_prompt(self, req: GenerateRequest) -> str:
+         parts = [SYSTEM_PROMPT]
+         for msg in req.messages:
+             if msg.role in ("user", "assistant"):
+                 parts.append(f"<|im_start|>{msg.role}\n{msg.content}<|im_end|>")
+         parts.append(f"<|im_start|>user\n{req.prompt}<|im_end|>")
+         parts.append("<|im_start|>assistant\n")
+         return "\n".join(parts)
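+
+     # For a request carrying one prior exchange, build_prompt assembles ChatML
+     # like the following (message text abbreviated and illustrative):
+     #
+     #   <|im_start|>system
+     #   You are BlitzKode, ...<|im_end|>
+     #   <|im_start|>user
+     #   Hi<|im_end|>
+     #   <|im_start|>assistant
+     #   Hello!<|im_end|>
+     #   <|im_start|>user
+     #   <current prompt><|im_end|>
+     #   <|im_start|>assistant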
+
+     def _gen_params(self, req: GenerateRequest) -> dict:
+         return dict(
+             max_tokens=req.max_tokens,
+             temperature=req.temperature,
+             top_p=req.top_p,
+             top_k=req.top_k,
+             repeat_penalty=req.repeat_penalty,
+             frequency_penalty=0.0,
+             presence_penalty=0.0,
+             stop=STOP_TOKENS,
+         )
+
+     def generate_once(self, req: GenerateRequest) -> dict[str, object]:
+         llm = self.load_model()
+         start = time.perf_counter()
+
+         result = llm(self.build_prompt(req), **self._gen_params(req))
+         response = result["choices"][0]["text"].strip()
+         elapsed = time.perf_counter() - start
+         logger.info("Generated %d chars in %.2fs", len(response), elapsed)
+
+         return {"response": response, "creator": CREATOR, "model": APP_NAME, "version": APP_VERSION}
+
+     def stream_tokens(self, req: GenerateRequest) -> Iterator[str]:
+         llm = self.load_model()
+         start = time.perf_counter()
+         token_count = 0
+
+         try:
+             for token in llm(self.build_prompt(req), stream=True, **self._gen_params(req)):
+                 if not token.get("choices"):
+                     continue
+                 text = token["choices"][0].get("text", "")
+                 if text:
+                     token_count += 1
+                     yield f"data: {json.dumps({'token': text})}\n\n"
+             elapsed = time.perf_counter() - start
+             logger.info("Streamed %d tokens in %.2fs", token_count, elapsed)
+             yield "data: [DONE]\n\n"
+         except Exception as exc:
+             logger.error("Stream error: %s", exc)
+             yield f"data: {json.dumps({'error': str(exc)})}\n\n"
+
+
+ def _check_api_key(request: Request, settings: Settings) -> JSONResponse | None:
+     if not settings.api_key:
+         return None
+     auth = request.headers.get("Authorization", "")
+     token = auth[7:] if auth.startswith("Bearer ") else auth
+     if token != settings.api_key:
+         return JSONResponse({"error": "Unauthorized"}, status_code=401)
+     return None
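+
+ # When BLITZKODE_API_KEY is set, clients may send the key either bare or as a
+ # Bearer token; for example (key value is a placeholder):
+ #
+ #   curl -H "Authorization: Bearer change-me" ...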
+
+
+ def create_app(settings: Settings | None = None) -> FastAPI:
+     settings = settings or Settings()
+     model_service = ModelService(settings)
+     executor = ThreadPoolExecutor(max_workers=settings.workers)
+
+     logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
+
+     @asynccontextmanager
+     async def lifespan(_: FastAPI):
+         if settings.preload_model:
+             try:
+                 await asyncio.to_thread(model_service.load_model)
+             except Exception:
+                 pass
+         try:
+             yield
+         finally:
+             executor.shutdown(wait=False, cancel_futures=True)
+
+     app = FastAPI(title=f"{APP_NAME} API", version=APP_VERSION, lifespan=lifespan)
+     app.state.settings = settings
+     app.state.model_service = model_service
+     app.state.executor = executor
+
+     cors_origins = [o.strip() for o in settings.cors_origins.split(",") if o.strip()]
+     app.add_middleware(CORSMiddleware, allow_origins=cors_origins, allow_methods=["*"], allow_headers=["*"])
+
+     @app.get("/")
+     async def root():
+         if not settings.frontend_path.exists():
+             raise HTTPException(status_code=404, detail="Frontend file is missing.")
+         return FileResponse(str(settings.frontend_path))
+
+     @app.get("/health")
+     async def health():
+         status = "healthy"
+         if not settings.frontend_path.exists() or not model_service.model_exists:
+             status = "degraded"
+         return JSONResponse({
+             "status": status,
+             "model_loaded": model_service.model_loaded,
+             "model_path": str(settings.model_path),
+             "model_exists": model_service.model_exists,
+             "frontend_exists": settings.frontend_path.exists(),
+             "version": APP_VERSION,
+             "gpu_layers": settings.n_gpu_layers,
+             "last_error": model_service.last_error,
+         })
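+
+     # On default settings, a healthy instance that has not yet served a
+     # prompt would answer roughly as follows (path abbreviated):
+     #
+     #   {"status": "healthy", "model_loaded": false, "model_exists": true,
+     #    "frontend_exists": true, "version": "2.0", "gpu_layers": 0,
+     #    "last_error": null, "model_path": "..."}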
+
+     @app.post("/generate")
+     async def generate(req: GenerateRequest, request: Request):
+         auth_err = _check_api_key(request, settings)
+         if auth_err:
+             return auth_err
+
+         prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
+         if err:
+             return err
+
+         try:
+             sanitized = req.model_copy(update={"prompt": prompt})
+             payload = await asyncio.get_running_loop().run_in_executor(executor, model_service.generate_once, sanitized)
+             return JSONResponse(payload)
+         except FileNotFoundError as exc:
+             return JSONResponse({"error": str(exc)}, status_code=503)
+         except Exception as exc:
+             return JSONResponse({"error": str(exc)}, status_code=500)
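+
+     # Blocking generation can be exercised end to end with plain curl against
+     # the default port (prompt text is illustrative):
+     #
+     #   curl -X POST http://localhost:7860/generate \
+     #        -H "Content-Type: application/json" \
+     #        -d '{"prompt": "Write a haiku about Python"}'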
+
+     @app.post("/generate/stream")
+     async def generate_stream(req: GenerateRequest, request: Request):
+         auth_err = _check_api_key(request, settings)
+         if auth_err:
+             return auth_err
+
+         prompt, err = _validate_prompt(req.prompt, settings.max_prompt_length)
+         if err:
+             return err
+
+         if not model_service.model_exists:
+             return JSONResponse({"error": f"Model not found at {settings.model_path}"}, status_code=503)
+
+         sanitized = req.model_copy(update={"prompt": prompt})
+         return StreamingResponse(
+             model_service.stream_tokens(sanitized),
+             media_type="text/event-stream",
+             headers={"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"},
+         )
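+
+     # The event stream is easiest to watch with curl's no-buffering flag
+     # (prompt text is illustrative):
+     #
+     #   curl -N -X POST http://localhost:7860/generate/stream \
+     #        -H "Content-Type: application/json" \
+     #        -d '{"prompt": "Explain list comprehensions"}'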
+
+     @app.get("/info")
+     async def info():
+         return JSONResponse({
+             "name": APP_NAME,
+             "creator": CREATOR,
+             "version": APP_VERSION,
+             "status": "ready" if model_service.model_exists else "model-missing",
+             "mode": f"{'GPU' if settings.n_gpu_layers > 0 else 'CPU'} (llama.cpp)",
+             "gpu_layers": settings.n_gpu_layers,
+             "context_window": settings.n_ctx,
+             "model_loaded": model_service.model_loaded,
+             "load_time_seconds": model_service.load_time_seconds,
+             "endpoints": {
+                 "generate": "POST /generate",
+                 "stream": "POST /generate/stream",
+                 "health": "GET /health",
+                 "info": "GET /info",
+             },
+         })
+
+     return app
+
+
+ app = create_app()
+
+
+ def main() -> None:
+     s = app.state.settings
+     print(f"\n{'=' * 50}")
+     print(f"{APP_NAME.upper()} v{APP_VERSION}")
+     print(f"Creator: {CREATOR}")
+     print(f"{'=' * 50}")
+     print(f"Model: {s.model_path}")
+     print(f"GPU: {s.n_gpu_layers} layers")
+     print(f"Ctx: {s.n_ctx} | Threads: {s.n_threads} | Workers: {s.workers}")
+     print(f"URL: http://localhost:{s.port}\n")
+
+     uvicorn.run(app, host=s.host, port=s.port, log_level="warning")
+
+
+ if __name__ == "__main__":
+     main()
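+
+ # Because `app` is created at module level, the server can also be run under
+ # the uvicorn CLI instead of main()'s banner, e.g.:
+ #
+ #   uvicorn server:app --host 0.0.0.0 --port 7860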