# abcd/engine.py
import asyncio
import logging
from typing import AsyncGenerator, List, Optional

from llama_cpp import Llama

logger = logging.getLogger(__name__)
class BatchInferenceEngine:
    """
    Pure Python batch inference engine using llama-cpp-python.
    Loads the model once and serves multiple concurrent requests.
    """

    def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = 4):
        self.model_path = model_path
        self.n_ctx = n_ctx
        self.n_threads = n_threads
        self._model: Optional[Llama] = None
        # Serializes access to the model: a single Llama instance is not
        # safe for concurrent completion calls.
        self._lock = asyncio.Lock()
    def load(self):
        """Load the model once at startup."""
        logger.info(f"Loading model from {self.model_path}")
        self._model = Llama(
            model_path=self.model_path,
            n_ctx=self.n_ctx,
            n_threads=self.n_threads,
            n_batch=512,
            verbose=False,
        )
        logger.info("Model loaded successfully")
    async def generate_stream(
        self,
        prompt: str,
        max_tokens: int = 256,
        temperature: float = 0.7,
        stop: Optional[List[str]] = None,
    ) -> AsyncGenerator[str, None]:
        """
        Async streaming generator for a single request.
        Runs the blocking llama-cpp calls in the default thread pool so the
        event loop stays responsive.
        """
        if self._model is None:
            raise RuntimeError("Model not loaded")

        loop = asyncio.get_running_loop()

        def _generate():
            return self._model.create_completion(
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                stop=stop or [],
                stream=True,  # Returns a generator of chunks
            )

        # Serialize requests: one generation at a time on the shared model.
        async with self._lock:
            # Get the streaming iterator without blocking the event loop.
            stream = await loop.run_in_executor(None, _generate)

            # Pull each chunk in the thread pool as well, since iterating
            # the llama-cpp stream is also a blocking call.
            _SENTINEL = object()
            while True:
                chunk = await loop.run_in_executor(None, next, stream, _SENTINEL)
                if chunk is _SENTINEL:
                    break
                if chunk.get("choices"):
                    delta = chunk["choices"][0].get("text", "")
                    if delta:
                        yield delta
    async def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 256,
        temperature: float = 0.7,
    ) -> List[str]:
        """
        Process multiple prompts.
        On CPU, prompts are processed sequentially to avoid contention.
        """
        results = []
        for prompt in prompts:
            chunks = []
            async for token in self.generate_stream(prompt, max_tokens, temperature):
                chunks.append(token)
            results.append("".join(chunks))
        return results
# Global singleton instance
_engine: Optional[BatchInferenceEngine] = None


def get_engine() -> BatchInferenceEngine:
    if _engine is None:
        raise RuntimeError("Engine not initialized")
    return _engine


def init_engine(model_path: str, **kwargs) -> BatchInferenceEngine:
    global _engine
    _engine = BatchInferenceEngine(model_path, **kwargs)
    _engine.load()
    return _engine
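

# Minimal usage sketch: initializes the engine, streams one completion, then
# runs a small batch. The model path and prompts below are placeholder
# assumptions; point the path at any local GGUF file.
if __name__ == "__main__":

    async def _demo():
        engine = init_engine("models/model.gguf", n_ctx=2048, n_threads=4)

        # Stream a single completion token by token.
        async for token in engine.generate_stream("Q: What is 2+2?\nA:", max_tokens=32):
            print(token, end="", flush=True)
        print()

        # Run a small batch; prompts are processed sequentially on CPU.
        answers = await engine.generate_batch(
            ["Name a prime number.", "Name a primary color."],
            max_tokens=16,
        )
        for answer in answers:
            print(answer.strip())

    asyncio.run(_demo())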