|
import logging
import threading
from time import perf_counter, time

import fastapi
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from llama_cpp import Llama
|
|
|
|
|
# Local GGUF weights (the filename suggests Qwen1.5-0.5B-Chat, q4_0 quantized).
# The "./" prefix makes this relative to the process's working directory.
MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf"

# Configure the root logger at INFO; basicConfig installs a stderr handler
# only if the root logger has none yet.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
""" |
|
try: |
|
llm = Llama.from_pretrained( |
|
repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF", |
|
filename="*q4_0.gguf", |
|
verbose=False, |
|
n_ctx=4096, |
|
n_threads=4, |
|
n_gpu_layers=0, |
|
) |
|
|
|
llm = Llama( |
|
model_path=MODEL_PATH, |
|
chat_format="llama-2", |
|
n_ctx=4096, |
|
n_threads=8, |
|
n_gpu_layers=0, |
|
) |
|
|
|
except Exception as e: |
|
logger.error(f"Failed to load model: {e}") |
|
raise |
|
""" |
|
|
|
# The ASGI application; the route handlers below register on it via decorators.
app: fastapi.FastAPI = fastapi.FastAPI()
|
|
|
|
|
@app.get("/") |
|
def index(): |
|
return fastapi.responses.RedirectResponse(url="/docs") |
|
|
|
|
|
@app.get("/health") |
|
def health(): |
|
return {"status": "ok"} |
|
|
|
|
|
|
|
@app.get("/generate") |
|
async def complete( |
|
question: str, |
|
system: str = "You are a story writing assistant.", |
|
temperature: float = 0.7, |
|
seed: int = 42, |
|
) -> dict: |
|
try: |
|
st = time() |
|
output = llm.create_chat_completion( |
|
messages=[ |
|
{"role": "system", "content": system}, |
|
{"role": "user", "content": question}, |
|
], |
|
temperature=temperature, |
|
seed=seed, |
|
) |
|
et = time() |
|
output["time"] = et - st |
|
return output |
|
except Exception as e: |
|
logger.error(f"Error in /complete endpoint: {e}") |
|
return JSONResponse( |
|
status_code=500, content={"message": "Internal Server Error"} |
|
) |
|
|
|
""" |
|
if __name__ == "__main__": |
|
import uvicorn |
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8000) |
|
""" |