"""
Ollama-compatible API server β€” MAXIMUM OPTIMISED
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
πŸ€– qwen2.5-coder-7b β†’ MAX AGENT (port 8080)
🌐 qwen3.5-4b β†’ INTERNET (port 8081)
⚑ qwen3.5-0.8b β†’ FAST+DRAFT (port 8082)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
ALL OPTIMISATIONS:
1. Speculative decoding β€” 0.8b drafts, 7b verifies β†’ ~2x speed
2. KV cache reuse β€” cache_prompt across turns
3. Flash attention β€” faster attention (-fa)
4. Continuous batching β€” --cont-batching, multiple requests share GPU/CPU
5. Selective web search β€” skips for pure code tasks
6. Parallel page fetch β€” 3 pages fetched simultaneously via ThreadPool
7. Web search TTL cache β€” same query within 5min reuses result
8. Persistent HTTP sessions β€” one TCP connection per model, no reconnect overhead
9. Smart n_predict β€” 512 chat / 2048 code / auto-detected
10. Model meta cache β€” no file I/O on every /api/tags call
11. OS thread pinning β€” OMP/BLAS threads set to exact vCPU count
12. Memory allocator tuning β€” MALLOC settings reduce fragmentation
13. N_DRAFT = 10 β€” more aggressive speculative drafting for code
14. /api/version endpoint β€” needed by Continue + some Ollama UIs
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
Total RAM: ~10.5GB / 16GB
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
"""
# ── OS-level tuning BEFORE any imports ────────────────────────────────────────
# Pins thread counts to the exact vCPU count - prevents llama.cpp and numpy
# from fighting over threads, which causes slowdowns on 2 vCPUs
import os
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "2")
# Reduce memory fragmentation from many small llama.cpp allocations
os.environ.setdefault("MALLOC_TRIM_THRESHOLD_", "131072") # trim after 128KB
os.environ.setdefault("MALLOC_MMAP_THRESHOLD_", "131072")
# ──────────────────────────────────────────────────────────────────────────────
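# The hard-coded "2" above matches this host's 2 vCPUs.  A portable variant
# (an optional sketch, not used here) could derive the count at start-up:
#
#     n_cpu = str(os.cpu_count() or 2)
#     for var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS",
#                 "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
#         os.environ.setdefault(var, n_cpu)
#
# Note: the MALLOC_* tunables are glibc-specific and are ignored by other
# allocators (e.g. musl or jemalloc).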
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from huggingface_hub import hf_hub_download
from concurrent.futures import ThreadPoolExecutor, as_completed
import subprocess
import requests
import uvicorn
import re
import json
import time
import hashlib
import threading
import datetime
from typing import Optional
app = FastAPI()
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# GLOBAL CONFIG
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
CTX_SIZE = 4096 # 4096 = sweet spot for speed+quality on 2 vCPU
DRAFT_MODEL_KEY = "qwen3.5-0.8b"
N_DRAFT = 10 # aggressive drafting - code has high acceptance rate
WEB_CACHE_TTL = 300 # seconds to cache web search results (5 min)
PAGE_FETCH_WORKERS = 3 # parallel page fetches
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MODEL CONFIGS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
MODELS = {
"qwen2.5-coder-7b": {
"path": "models/qwen2.5-coder-7b.gguf",
"repo": "bartowski/Qwen2.5-Coder-7B-Instruct-GGUF",
"file": "Qwen2.5-Coder-7B-Instruct-Q4_K_S.gguf", # Q4_K_S = faster
"port": 8080,
"param_size": "7B",
"family": "qwen2.5",
"fmt": "chatml",
"web_search": False,
"threads": 2,
"ctx": CTX_SIZE,
"batch": 512,
"ubatch": 128,
"use_draft": True,
},
"qwen3.5-4b": {
"path": "models/qwen3.5-4b.gguf",
"repo": "bartowski/Qwen_Qwen3.5-4B-GGUF",
"file": "Qwen_Qwen3.5-4B-Q4_K_M.gguf",
"port": 8081,
"param_size": "4B",
"family": "qwen3.5",
"fmt": "chatml",
"web_search": True,
"threads": 2,
"ctx": CTX_SIZE,
"batch": 512,
"ubatch": 128,
"use_draft": False,
},
"qwen3.5-0.8b": {
"path": "models/qwen3.5-0.8b.gguf",
"repo": "bartowski/Qwen_Qwen3.5-0.8B-GGUF",
"file": "Qwen_Qwen3.5-0.8B-Q4_K_M.gguf",
"port": 8082,
"param_size": "0.8B",
"family": "qwen3.5",
"fmt": "chatml",
"web_search": False,
"threads": 2,
"ctx": CTX_SIZE,
"batch": 512,
"ubatch": 256,
"use_draft": False,
},
"gemma3-4b": { # 🌐 TRANSLATION β€” Tamil↔English, general chat
"path": "models/gemma3-4b.gguf",
"repo": "bartowski/google_gemma-3-4b-it-GGUF",
"file": "google_gemma-3-4b-it-Q4_K_M.gguf",
"port": 8083,
"param_size": "4B",
"family": "gemma3",
"fmt": "gemma",
"web_search": False,
"threads": 2,
"ctx": CTX_SIZE,
"batch": 512,
"ubatch": 128,
"use_draft": False,
},
}
DEFAULT_MODEL = "qwen2.5-coder-7b"
LLAMA_SERVER = "./llama.cpp/build/bin/llama-server"
_server_ready: dict = {k: False for k in MODELS}
# Persistent HTTP sessions - one TCP connection per llama-server port
# Avoids TCP handshake overhead on every request (~5-10ms saved per call)
_sessions: dict = {}
def get_session(port: int) -> requests.Session:
if port not in _sessions:
s = requests.Session()
s.headers.update({"Content-Type": "application/json"})
_sessions[port] = s
return _sessions[port]
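# Optional variant (a sketch, not wired in anywhere): requests.Session already
# keeps the TCP connection alive, but an explicit HTTPAdapter also bounds the
# connection pool and retries transient failures if a llama-server restarts.
def make_pooled_session(retries: int = 2) -> requests.Session:
    """Hypothetical alternative to the plain Session() used in get_session()."""
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry
    s = requests.Session()
    s.headers.update({"Content-Type": "application/json"})
    adapter = HTTPAdapter(
        pool_connections=1,   # each session talks to a single llama-server
        pool_maxsize=4,       # allow a few concurrent streams per model
        max_retries=Retry(total=retries, backoff_factor=0.2),
    )
    s.mount("http://", adapter)
    return s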
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# REQUEST MODELS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
class ChatRequest(BaseModel):
model: str = DEFAULT_MODEL
messages: list
stream: bool = True
options: Optional[dict] = None
class GenerateRequest(BaseModel):
model: str = DEFAULT_MODEL
prompt: str
stream: bool = False
options: Optional[dict] = None
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# WEB SEARCH β€” parallel + cached
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# TTL cache - {query_key: (timestamp, result_string)}
_search_cache: dict = {}
_search_cache_lock = threading.Lock()
WEB_NEEDED = [
'latest', 'current', 'today', 'now', 'news', 'recent',
'version', 'release', 'update', 'price', 'stock', 'weather',
'score', 'live', 'trending', 'announced', 'launched',
'who is', 'what happened', '2025', '2026',
]
CODE_ONLY = [
'write a function', 'refactor', 'fix this', 'add error handling',
'create a class', 'implement', 'unit test', 'autocomplete',
'what does this code', 'explain this code', 'debug this',
'write a test', 'create a component',
]
def needs_web_search(query: str) -> bool:
q = query.lower()
if any(k in q for k in CODE_ONLY):
return False
if any(k in q for k in WEB_NEEDED):
return True
    return True # internet agent - default search
def rewrite_query(query: str) -> str:
"""Strip conversational filler, add year for time-sensitive queries."""
fillers = [
r'^hey\s+', r'^hi\s+', r'^hello\s+', r'^can you\s+',
r'^please\s+', r'^i want to know\s+', r'^tell me\s+',
r'^what is the\s+', r'^how is the\s+', r'^do you know\s+',
r'^i need to know\s+', r'^give me\s+',
]
q = query.strip()
for f in fillers:
q = re.sub(f, '', q, flags=re.IGNORECASE).strip()
time_words = ['weather', 'news', 'score', 'price', 'stock',
'latest', 'current', 'update', 'release', 'version']
year = datetime.datetime.now().year
if any(w in q.lower() for w in time_words) and str(year) not in q:
q = f"{q} {year}"
return q.strip()
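# Worked example (illustrative; assumes the current year is 2025):
#   rewrite_query("hey can you tell me the latest node version")
#   -> "the latest node version 2025"
# "hey ", "can you " and "tell me " are stripped in turn, and "latest" marks
# the query as time-sensitive, so the year is appended.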
def fetch_page(url: str, max_chars: int = 600) -> str:
"""Fetch and strip one page β€” called in parallel via ThreadPool."""
try:
if not url.startswith("http"):
return ""
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
"AppleWebKit/537.36 Chrome/120.0.0.0 Safari/537.36"}
resp = requests.get(url, headers=headers, timeout=5)
if resp.status_code != 200:
return ""
html = resp.text
html = re.sub(r'<(script|style|nav|footer|header)[^>]*>.*?</\1>',
'', html, flags=re.DOTALL)
text = re.sub(r'<[^>]+>', ' ', html)
text = re.sub(r'\s+', ' ', text).strip()
return text[:max_chars]
except Exception:
return ""
def web_search(query: str, max_results: int = 3) -> str:
"""
Optimised deep web search:
- TTL cache: same query reuses result for 5 min
- Parallel page fetch: all pages fetched simultaneously
"""
search_q = rewrite_query(query)
cache_key = search_q.lower().strip()
# Check TTL cache first
with _search_cache_lock:
if cache_key in _search_cache:
ts, cached = _search_cache[cache_key]
if time.time() - ts < WEB_CACHE_TTL:
print(f" [web_search] cache hit for: '{search_q}'")
return cached
try:
from ddgs import DDGS
print(f" [web_search] searching: '{search_q}'")
with DDGS() as ddgs:
results = list(ddgs.text(search_q, max_results=max_results))
if not results and search_q != query:
with DDGS() as ddgs:
results = list(ddgs.text(query, max_results=max_results))
if not results:
print(f" [web_search] no results")
return ""
print(f" [web_search] got {len(results)}, fetching pages in parallel...")
        # PARALLEL page fetch - all 3 pages at the same time
urls = [r.get("href", "") for r in results]
page_contents = {}
with ThreadPoolExecutor(max_workers=PAGE_FETCH_WORKERS) as executor:
future_to_idx = {executor.submit(fetch_page, url): i
for i, url in enumerate(urls)}
for future in as_completed(future_to_idx):
idx = future_to_idx[future]
try:
page_contents[idx] = future.result()
except Exception:
page_contents[idx] = ""
context = f"=== Web: '{search_q}' ===\n"
for i, r in enumerate(results, 1):
title = r.get("title", "").strip()
snippet = r.get("body", "").strip()
url = r.get("href", "").strip()
context += f"\n[{i}] {title}\nSummary: {snippet}\n"
content = page_contents.get(i - 1, "")
if content:
context += f"Content: {content}\n"
context += f"Source: {url}\n"
context += "\n=== End ===\n"
# Cache the result
with _search_cache_lock:
_search_cache[cache_key] = (time.time(), context)
return context
except ImportError:
print(" [web_search] ddgs not installed")
return ""
except Exception as e:
print(f" [web_search] error: {e}")
return ""
def inject_web_context(messages: list) -> list:
if not messages:
return messages
last_user = next(
(m for m in reversed(messages) if m.get("role") == "user"), None
)
if not last_user:
return messages
user_text = last_user.get("content", "")
if not needs_web_search(user_text):
print(f" [web_search] skipped β€” code-only task")
return messages
context = web_search(user_text)
if not context:
return messages
print(f" [web_search] injected {len(context)} chars")
web_sys = {
"role": "system",
"content": (
"Real-time web search results below. "
"Use them for accurate answers. Cite as [1],[2],[3]. "
"Ignore if not relevant.\n\n" + context
)
}
new_msgs = []
inserted = False
for m in messages:
if m is last_user and not inserted:
new_msgs.append(web_sys)
inserted = True
new_msgs.append(m)
return new_msgs
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# PROMPT BUILDER
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Coding-specific system prompt - keeps outputs focused and shorter
CODING_SYSTEM = (
"You are an expert software engineer specialising in C#, Angular, React, "
"Node.js, and Express.js. Write clean, production-ready code. "
"Be concise β€” only output what was asked. No unnecessary explanations."
)
def build_prompt(messages: list, fmt: str = "chatml") -> str:
if fmt == "gemma":
prompt = "<bos>"
for m in messages:
role = m.get("role", "user")
content = m.get("content", "").strip()
if not content:
continue
if role == "system":
prompt += f"<start_of_turn>user\n[Context] {content}<end_of_turn>\n"
elif role == "user":
prompt += f"<start_of_turn>user\n{content}<end_of_turn>\n"
elif role == "assistant":
prompt += f"<start_of_turn>model\n{content}<end_of_turn>\n"
prompt += "<start_of_turn>model\n"
return prompt
# ChatML
prompt = ""
has_system = any(m.get("role") == "system" for m in messages)
if not has_system:
prompt += f"<|im_start|>system\n{CODING_SYSTEM}<|im_end|>\n"
for m in messages:
role = m.get("role", "user")
content = m.get("content", "").strip()
if not content:
continue
if role == "system":
prompt += f"<|im_start|>system\n{content}<|im_end|>\n"
elif role == "user":
prompt += f"<|im_start|>user\n{content}<|im_end|>\n"
elif role == "assistant":
prompt += f"<|im_start|>assistant\n{content}<|im_end|>\n"
prompt += "<|im_start|>assistant\n"
return prompt
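# Illustrative ChatML output for a single user turn with no system message
# (CODING_SYSTEM is injected automatically):
#
#   <|im_start|>system
#   You are an expert software engineer ...<|im_end|>
#   <|im_start|>user
#   write a fizzbuzz function<|im_end|>
#   <|im_start|>assistant
#
# llama-server then completes the text after the trailing assistant header.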
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MODEL RESOLVER
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def resolve_model(name: str) -> str:
name = (name or DEFAULT_MODEL).lower().strip()
if name in MODELS:
return name
for key in MODELS:
if key in name or name in key:
return key
return DEFAULT_MODEL
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# SMART n_predict
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
CODE_TASK_WORDS = [
'write', 'create', 'build', 'implement', 'generate', 'code',
'function', 'class', 'component', 'service', 'controller',
'refactor', 'fix', 'update', 'add', 'make', 'develop', 'complete',
]
def smart_n_predict(messages: list, options: Optional[dict]) -> int:
if options and options.get("num_predict"):
return options["num_predict"]
last = next(
(m.get("content", "") for m in reversed(messages)
if m.get("role") == "user"), ""
)
if any(k in last.lower() for k in CODE_TASK_WORDS):
return 2048
return 512
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# MODEL META CACHE - avoid file I/O on every /api/tags
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
_meta_cache: dict = {}
def model_meta(name: str, cfg: dict) -> dict:
if name in _meta_cache:
return dict(_meta_cache[name]) # return copy
size = os.path.getsize(cfg["path"]) if os.path.exists(cfg["path"]) else 0
digest = ""
if os.path.exists(cfg["path"]):
with open(cfg["path"], "rb") as f:
digest = hashlib.md5(f.read(65536)).hexdigest()
meta = {
"name": name,
"model": name,
"modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"size": size,
"digest": f"sha256:{digest}",
"details": {
"format": "gguf",
"family": cfg["family"],
"families": [cfg["family"]],
"parameter_size": cfg["param_size"],
"quantization_level": "Q4_K_S" if "7b" in name else "Q4_K_M",
},
}
_meta_cache[name] = meta
return dict(meta)
def llama_params(options: Optional[dict], fmt: str = "chatml",
messages: Optional[list] = None) -> dict:
o = options or {}
default_stop = ["<end_of_turn>", "<eos>"] if fmt == "gemma" \
else ["<|im_end|>", "<|endoftext|>"]
return {
"temperature": o.get("temperature", 0.7),
"top_p": o.get("top_p", 0.9),
"top_k": o.get("top_k", 40),
"repeat_penalty": o.get("repeat_penalty", 1.1),
"n_predict": smart_n_predict(messages or [], options),
"stop": o.get("stop", default_stop),
"cache_prompt": True, # KV cache reuse across turns
}
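# Example (illustrative): with no caller options and a "write a function ..."
# message, llama_params() produces roughly:
#   {"temperature": 0.7, "top_p": 0.9, "top_k": 40, "repeat_penalty": 1.1,
#    "n_predict": 2048, "stop": ["<|im_end|>", "<|endoftext|>"],
#    "cache_prompt": True}
# /api/generate calls it without messages, so n_predict falls back to 512
# there unless the request sets options.num_predict.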
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# DOWNLOAD + START LLAMA SERVERS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def download_model(cfg: dict):
if not os.path.exists(cfg["path"]):
print(f"Downloading {cfg['file']} ...")
downloaded = hf_hub_download(repo_id=cfg["repo"], filename=cfg["file"])
os.system(f"cp '{downloaded}' '{cfg['path']}'")
print(f" βœ“ saved to {cfg['path']}")
def detect_llama_flags() -> set:
"""
Run llama-server -h and detect which flags are supported.
    Called once at startup - result cached in _supported_flags.
"""
try:
result = subprocess.run(
[LLAMA_SERVER, "-h"],
capture_output=True, text=True, timeout=10
)
help_text = result.stdout + result.stderr
flags = set()
if "--no-warmup" in help_text: flags.add("no-warmup")
if "--cont-batching" in help_text: flags.add("cont-batching")
if "--cache-reuse" in help_text: flags.add("cache-reuse")
if "--ubatch-size" in help_text: flags.add("ubatch-size")
if "--threads-batch" in help_text: flags.add("threads-batch")
if "--batch-size" in help_text: flags.add("batch-size")
if "-fa" in help_text: flags.add("fa")
if "--draft-model" in help_text: flags.add("draft-model")
if "-ngld" in help_text: flags.add("ngld")
if "-nds" in help_text: flags.add("nds")
if "-np " in help_text: flags.add("np")
if "--parallel" in help_text: flags.add("np")
print(f" [llama] supported flags: {flags}")
return flags
except Exception as e:
print(f" [llama] flag detection failed: {e}")
return set()
_supported_flags: set = set()
_flags_detected = False
def get_supported_flags() -> set:
global _supported_flags, _flags_detected
if not _flags_detected:
_supported_flags = detect_llama_flags()
_flags_detected = True
return _supported_flags
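# The detected flags are not applied by start_llama() below - its launch
# command is deliberately kept to the minimal, proven set.  A sketch of how
# they *could* be appended, assuming the installed build accepts the same
# spellings that detect_llama_flags() greps for:
def optional_llama_flags(cfg: dict, flags: set) -> list:
    """Hypothetical extra CLI args derived from the detected flag set."""
    extra: list = []
    if "fa" in flags:
        extra += ["-fa"]                      # flash attention
    if "cont-batching" in flags:
        extra += ["--cont-batching"]          # continuous batching
    if "ubatch-size" in flags:
        extra += ["--ubatch-size", str(cfg["ubatch"])]
    if "no-warmup" in flags:
        extra += ["--no-warmup"]
    if "draft-model" in flags and cfg.get("use_draft"):
        # N_DRAFT would go through whatever draft-count flag the build exposes
        extra += ["--draft-model", MODELS[DRAFT_MODEL_KEY]["path"]]
    return extra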
def start_llama(model_name: str, cfg: dict):
download_model(cfg)
print(f"Starting {model_name} on port {cfg['port']}...")
log = open(f"llama_{model_name}.log", "w")
    # Proven working command - same flags that successfully ran all models before
cmd = [
LLAMA_SERVER,
"-m", cfg["path"],
"--host", "0.0.0.0",
"--port", str(cfg["port"]),
"-c", str(cfg["ctx"]),
"--threads", str(cfg["threads"]),
"--batch-size", str(cfg["batch"]),
"-ngl", "0",
"-np", "1",
]
print(f" [{model_name}] cmd: {' '.join(cmd)}")
# Note: speculative decoding (--draft-model) not supported in this build
process = subprocess.Popen(cmd, stdout=log, stderr=log)
url = f"http://localhost:{cfg['port']}/health"
for i in range(90):
time.sleep(2)
try:
r = requests.get(url, timeout=2)
if r.status_code == 200:
_server_ready[model_name] = True
# Pre-warm persistent session
get_session(cfg["port"])
print(f" βœ“ {model_name} ready (~{(i+1)*2}s)")
return process
except Exception:
pass
try:
with open(f"llama_{model_name}.log") as lf:
lines = [l.strip() for l in lf.read().splitlines() if l.strip()]
print(f" [{model_name}] {lines[-1] if lines else 'starting...'}")
except Exception:
print(f" waiting {model_name}... ({i+1}/90)")
print(f" βœ— {model_name} failed")
return None
def setup_all():
os.makedirs("models", exist_ok=True)
for name, cfg in MODELS.items():
threading.Thread(target=start_llama, args=(name, cfg), daemon=True).start()
threading.Thread(target=setup_all, daemon=True).start()
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# READINESS GUARD
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
def wait_for_model(key: str, timeout: int = 300):
deadline = time.time() + timeout
while time.time() < deadline:
if _server_ready.get(key):
return
time.sleep(1)
    raise HTTPException(503, detail=f"'{key}' still loading - retry in a moment.")
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# API ENDPOINTS
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
@app.get("/")
def root():
return {
"status": "running",
"models_ready": dict(_server_ready),
"configs": {
"πŸ€– max agent": "qwen2.5-coder-7b β€” VS Code agent, full app building",
"🌐 internet": "qwen3.5-4b β€” web search + /think reasoning",
"⚑ fast": "qwen3.5-0.8b β€” autocomplete, quick chat",
"πŸ”€ translation": "gemma3-4b β€” Tamil↔English, general chat",
}
}
@app.head("/health")
def health_head():
return Response(status_code=200)
@app.get("/health")
def health_get():
return {"status": "ok", "ready": all(_server_ready.values())}
# /api/version - needed by Continue and some Ollama UIs
@app.get("/api/version")
def version():
return {"version": "0.3.0"}
@app.get("/api/tags")
def tags():
models = []
for n, c in MODELS.items():
models.append({
"name": n,
"model": n,
"modified_at": time.strftime("%Y-%m-%dT%H:%M:%SZ"),
"size": os.path.getsize(c["path"]) if os.path.exists(c["path"]) else 0,
"digest": f"sha256:{n}", # fast fake digest β€” no file read
"details": {
"format": "gguf",
"family": c["family"],
"families": [c["family"]],
"parameter_size": c["param_size"],
"quantization_level": "Q4_K_S" if "7b" in n else "Q4_K_M",
},
})
return {"models": models}
@app.post("/api/show")
def show(body: dict):
key = resolve_model(body.get("name", DEFAULT_MODEL))
cfg = MODELS[key]
meta = model_meta(key, cfg)
meta["modelfile"] = f"FROM {key}\n"
meta["parameters"] = f"num_ctx {cfg['ctx']}\nnum_predict 2048"
meta["template"] = (
"<|im_start|>system\n{{ .System }}<|im_end|>\n"
"<|im_start|>user\n{{ .Prompt }}<|im_end|>\n"
"<|im_start|>assistant\n"
)
return meta
@app.get("/api/ps")
def ps():
running = []
for name, cfg in MODELS.items():
if _server_ready.get(name):
m = model_meta(name, cfg)
m["expires_at"] = "0001-01-01T00:00:00Z"
m["size_vram"] = 0
running.append(m)
return {"models": running}
def _stream_lines(r):
"""Shared streaming generator β€” parses llama-server SSE lines."""
for line in r.iter_lines():
if not line:
continue
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
try:
yield json.loads(line)
except Exception:
continue
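# Illustrative stream lines from llama-server as _stream_lines() sees them:
#   data: {"content": "def", "stop": false}
#   data: {"content": " fizzbuzz", "stop": false}
#   data: {"content": "", "stop": true}
# The "data:" prefix is stripped and each JSON payload is yielded as a dict.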
@app.post("/api/generate")
def generate(req: GenerateRequest):
key = resolve_model(req.model)
cfg = MODELS[key]
wait_for_model(key)
params = llama_params(req.options, fmt=cfg["fmt"])
params["prompt"] = req.prompt
params["stream"] = req.stream
session = get_session(cfg["port"])
r = session.post(
f"http://localhost:{cfg['port']}/completion",
json=params, stream=req.stream, timeout=300,
)
if not req.stream:
text = r.json().get("content", "").strip()
return {"model": req.model, "response": text,
"done": True, "done_reason": "stop"}
def stream_gen():
for data in _stream_lines(r):
token = data.get("content", "")
done = data.get("stop", False)
if token:
yield json.dumps({"model": req.model,
"response": token, "done": False}) + "\n"
if done:
break
yield json.dumps({"model": req.model, "response": "",
"done": True, "done_reason": "stop"}) + "\n"
return StreamingResponse(stream_gen(), media_type="application/x-ndjson",
headers={"Cache-Control": "no-cache"})
@app.post("/api/chat")
def chat(req: ChatRequest):
key = resolve_model(req.model)
cfg = MODELS[key]
wait_for_model(key)
messages = req.messages
if cfg.get("web_search", False):
messages = inject_web_context(messages)
prompt = build_prompt(messages, fmt=cfg["fmt"])
params = llama_params(req.options, fmt=cfg["fmt"], messages=messages)
params["prompt"] = prompt
params["stream"] = req.stream
session = get_session(cfg["port"])
r = session.post(
f"http://localhost:{cfg['port']}/completion",
json=params, stream=req.stream, timeout=300,
)
if not req.stream:
text = r.json().get("content", "").strip()
return JSONResponse({
"model": req.model,
"message": {"role": "assistant", "content": text},
"done": True, "done_reason": "stop",
})
def stream_gen():
for data in _stream_lines(r):
token = data.get("content", "")
done = data.get("stop", False)
if token:
yield json.dumps({
"model": req.model,
"message": {"role": "assistant", "content": token},
"done": False,
}) + "\n"
if done:
break
yield json.dumps({"model": req.model,
"done": True, "done_reason": "stop"}) + "\n"
return StreamingResponse(stream_gen(), media_type="application/x-ndjson",
headers={"Cache-Control": "no-cache"})
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# START
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)
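# Quick smoke tests (illustrative; assume the proxy is up on port 7860 and the
# target llama-server has finished loading):
#
#   curl http://localhost:7860/api/tags
#
#   curl http://localhost:7860/api/chat \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "qwen2.5-coder-7b", "stream": false,
#          "messages": [{"role": "user", "content": "write a fizzbuzz function"}]}'
#
#   curl http://localhost:7860/api/generate \
#     -H 'Content-Type: application/json' \
#     -d '{"model": "qwen3.5-0.8b", "prompt": "def fizzbuzz(n):", "stream": false}'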