omaryasserhassan committed on
Commit ccef136 · verified · 1 Parent(s): 33f19bf

Update app.py

Files changed (1)
  1. app.py +117 -59
app.py CHANGED
@@ -1,26 +1,43 @@
+# app.py
 import os
 import time
+import threading
+from typing import Optional
+
 from fastapi import FastAPI, HTTPException
-from fastapi.responses import StreamingResponse, JSONResponse
-from pydantic import BaseModel
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
 from huggingface_hub import hf_hub_download
 from llama_cpp import Llama

-# ---------------- Config ----------------
-REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
-FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
-CACHE_DIR = "/app/models" # match your Dockerfile prefetch if you use it
+# ---------------- Config (fixed defaults; can be overridden by env) ----------------
+REPO_ID = os.getenv("REPO_ID", "bartowski/Llama-3.2-3B-Instruct-GGUF")
+FILENAME = os.getenv("FILENAME", "Llama-3.2-3B-Instruct-Q4_K_M.gguf")
+CACHE_DIR = os.getenv("CACHE_DIR", "/app/models")
+
+# Inference knobs (fixed for the Space; override via env only if needed)
+N_THREADS = int(os.getenv("N_THREADS", str(min(4, (os.cpu_count() or 2)))))
+N_BATCH = int(os.getenv("N_BATCH", "64"))
+N_CTX = int(os.getenv("N_CTX", "2048"))
+
+# Fixed sampling
+MAX_TOKENS = int(os.getenv("MAX_TOKENS", "256"))
+TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
+TOP_P = float(os.getenv("TOP_P", "0.9"))
+STOP_TOKENS = os.getenv("STOP_TOKENS", "</s>,<|eot_id|>").split(",")
+
+# System prompt (optional). Leave empty for pure user prompt.
+SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "").strip()

-# Conservative CPU settings for Spaces (prevents stalls)
-N_THREADS = min(4, (os.cpu_count() or 2)) # don't over-thread on tiny CPUs
-N_BATCH = 64 # modest batch to avoid RAM thrash
-N_CTX = 2048 # enough for short prompts
+# Safety margin for context budgeting (prompt + completion + overhead <= N_CTX)
+CTX_SAFETY = int(os.getenv("CTX_SAFETY", "128"))

-# --------------- FastAPI App ---------------
-app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API")
-_model = None
+# ---------------- App scaffolding ----------------
+app = FastAPI(title="Llama 3.2 3B Instruct (llama.cpp) API - Prompt Only")
+_model: Optional[Llama] = None
+_model_lock = threading.Lock()

-# --------------- Load Model ---------------
+# ---------------- Model loader ----------------
 def get_model() -> Llama:
     global _model
     if _model is not None:
@@ -34,30 +51,43 @@ def get_model() -> Llama:
         local_dir_use_symlinks=False,
     )

-    # IMPORTANT: use Llama-3 chat template
     _model = Llama(
         model_path=local_path,
-        chat_format="llama-3", # <- ensures proper prompt templating
+        chat_format="llama-3", # ensures proper Llama-3 prompt templating
         n_ctx=N_CTX,
         n_threads=N_THREADS,
         n_batch=N_BATCH,
-        verbose=False
+        verbose=False,
     )
     return _model

-# --------------- Schemas ----------------
-class ChatMessage(BaseModel):
-    role: str # "system" | "user" | "assistant"
-    content: str
+@app.on_event("startup")
+def _warm():
+    # Preload to avoid cold-start on first request
+    get_model()

-class ChatRequest(BaseModel):
-    messages: list[ChatMessage]
-    max_tokens: int = 128
-    temperature: float = 0.7
-    top_p: float = 0.9
-    stream: bool = False
+# ---------------- Schemas ----------------
+class GenerateRequest(BaseModel):
+    prompt: str = Field(..., description="User prompt text only.")

-# --------------- Endpoints ---------------
+# ---------------- Helpers ----------------
+def _fit_prompt_to_context(model: Llama, prompt: str) -> str:
+    """
+    Simple context budgeting: ensures tokens(prompt) + MAX_TOKENS + CTX_SAFETY <= N_CTX.
+    If over budget, we truncate the prompt from the start (keep the tail).
+    """
+    toks = model.tokenize(prompt.encode("utf-8"))
+    budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY) # keep some minimal room
+    if len(toks) <= budget:
+        return prompt
+    # Truncate from the front (keep the latest part)
+    kept = model.detokenize(toks[-budget:])
+    try:
+        return kept.decode("utf-8", errors="ignore")
+    except Exception:
+        return kept.decode("utf-8", "ignore")
+
+# ---------------- Endpoints ----------------
 @app.get("/health")
 def health():
     try:
@@ -66,46 +96,74 @@ def health():
     except Exception as e:
         return {"ok": False, "error": str(e)}

+@app.get("/config")
+def config():
+    return {
+        "repo_id": REPO_ID,
+        "filename": FILENAME,
+        "cache_dir": CACHE_DIR,
+        "n_threads": N_THREADS,
+        "n_batch": N_BATCH,
+        "n_ctx": N_CTX,
+        "max_tokens": MAX_TOKENS,
+        "temperature": TEMPERATURE,
+        "top_p": TOP_P,
+        "stop_tokens": STOP_TOKENS,
+        "ctx_safety": CTX_SAFETY,
+        "has_system_prompt": bool(SYSTEM_PROMPT),
+    }
+
 @app.post("/generate")
-def generate(req: ChatRequest):
+def generate(req: GenerateRequest):
     """
-    Chat-completion endpoint with optional server-side streaming.
-    Uses Llama-3 chat template via chat_format="llama-3".
+    Non-streaming chat completion.
+    Accepts ONLY a prompt string; all other params are fixed in code.
     """
     try:
+        if not req.prompt or not req.prompt.strip():
+            raise HTTPException(status_code=400, detail="prompt must be a non-empty string")
+
         model = get_model()

-        # Convert to llama.cpp message format
-        msgs = [{"role": m.role, "content": m.content} for m in req.messages]
+        user_prompt = req.prompt.strip()
+        fitted_prompt = _fit_prompt_to_context(model, user_prompt)

-        if not req.stream:
+        # Build messages (Llama-3 chat format). System is optional.
+        messages = []
+        if SYSTEM_PROMPT:
+            messages.append({"role": "system", "content": SYSTEM_PROMPT})
+        messages.append({"role": "user", "content": fitted_prompt})
+
+        t0 = time.time()
+        with _model_lock:
             out = model.create_chat_completion(
-                messages=msgs,
-                max_tokens=req.max_tokens,
-                temperature=req.temperature,
-                top_p=req.top_p,
+                messages=messages,
+                max_tokens=MAX_TOKENS,
+                temperature=TEMPERATURE,
+                top_p=TOP_P,
+                stop=STOP_TOKENS,
             )
-            text = out["choices"][0]["message"]["content"]
-            return JSONResponse({"ok": True, "response": text})
-
-        # --- Streaming mode ---
-        def token_stream():
-            start = time.time()
-            for chunk in model.create_chat_completion(
-                messages=msgs,
-                max_tokens=req.max_tokens,
-                temperature=req.temperature,
-                top_p=req.top_p,
-                stream=True,
-            ):
-                if "choices" in chunk and chunk["choices"]:
-                    delta = chunk["choices"][0]["delta"].get("content", "")
-                    if delta:
-                        yield delta
-            # small trailer to mark end (optional)
-            yield f"\n\n[done in {time.time()-start:.2f}s]"
-
-        return StreamingResponse(token_stream(), media_type="text/plain")
+        dt = time.time() - t0
+
+        text = out["choices"][0]["message"]["content"]
+        usage = out.get("usage", {}) # may include prompt_tokens/completion_tokens
+
+        return JSONResponse({
+            "ok": True,
+            "response": text,
+            "usage": usage,
+            "timing_sec": round(dt, 3),
+            "params": {
+                "max_tokens": MAX_TOKENS,
+                "temperature": TEMPERATURE,
+                "top_p": TOP_P,
+                "stop": STOP_TOKENS,
+                "n_ctx": N_CTX,
+            },
+            "prompt_truncated": (fitted_prompt != user_prompt),
+        })

+    except HTTPException:
+        raise
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
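
Note on the new context budgeting: with the default values in the config block above (N_CTX=2048, MAX_TOKENS=256, CTX_SAFETY=128), the prompt budget enforced by _fit_prompt_to_context works out to 1664 tokens. A quick check of that arithmetic, not part of the commit:

N_CTX, MAX_TOKENS, CTX_SAFETY = 2048, 256, 128       # defaults from the config block above
budget = max(256, N_CTX - MAX_TOKENS - CTX_SAFETY)   # same formula as the helper
print(budget)  # 1664 -> longer prompts are truncated from the front, keeping the tail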
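
A minimal client sketch for the new prompt-only API. It assumes the requests package is installed and that the Space is reachable at http://localhost:7860 (typical for a Dockerized Space; substitute your own URL); neither the client nor the URL is part of the commit. The response keys it reads ("response", "usage", "timing_sec", "prompt_truncated") match the JSON built in /generate above.

import requests

BASE_URL = "http://localhost:7860"  # assumption: replace with your Space's URL

# The fixed generation settings are exposed read-only at /config
print(requests.get(f"{BASE_URL}/config", timeout=30).json())

# /generate now takes only a prompt; sampling parameters are fixed server-side
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"prompt": "Explain what a GGUF file is in two sentences."},
    timeout=300,  # CPU inference on a small Space can be slow
)
resp.raise_for_status()
data = resp.json()
print(data["response"])
print(data["usage"], data["timing_sec"], data["prompt_truncated"])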