|
import fastapi |
|
from fastapi.responses import JSONResponse |
|
from time import time |
|
|
|
import logging |
|
import llama_cpp |
|
import llama_cpp.llama_tokenizer |
|
|
|
# Load the 4-bit quantized Qwen1.5 chat model once at import time; the request
# handlers below share this single instance.
llama = llama_cpp.Llama.from_pretrained(
    repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
    filename="*q4_0.gguf",
    tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
        "Qwen/Qwen1.5-0.5B"
    ),
    verbose=False,
    n_ctx=4096,
    n_gpu_layers=0,
    # Qwen1.5 chat models expect ChatML-formatted prompts, so use the matching
    # chat format rather than a Llama-2 style template.
    chat_format="chatml",
)
|
|
|
# Configure root logging at INFO and create the module-level logger used by
# the request handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Disabled alternative model loaders, kept for reference only. They are real
# comments (not a bare string literal) so nothing is evaluated at import time.
#
# try:
#     llm = Llama.from_pretrained(
#         repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
#         filename="*q4_0.gguf",
#         verbose=False,
#         n_ctx=4096,
#         n_threads=4,
#         n_gpu_layers=0,
#     )
#
#     llm = Llama(
#         model_path=MODEL_PATH,
#         chat_format="llama-2",
#         n_ctx=4096,
#         n_threads=8,
#         n_gpu_layers=0,
#     )
#
# except Exception as e:
#     logger.error(f"Failed to load model: {e}")
#     raise
|
|
|
# ASGI application object; the route decorators below register on it.
app = fastapi.FastAPI()
|
|
|
|
|
@app.get("/")
def index():
    """Redirect requests for the root path to ``/docs``."""
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect
|
|
|
|
|
@app.get("/health")
def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
|
@app.get("/generate")
async def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    """Answer ``question`` with one chat completion from the shared model.

    Args:
        question: Text sent as the user message.
        system: Text sent as the system message.
        temperature: Sampling temperature forwarded to the model call.
        seed: Seed forwarded to the model call.

    Returns:
        The completion returned by ``llama.create_chat_completion`` with an
        added ``"time"`` key holding the elapsed wall-clock seconds. If any
        exception is raised, a ``JSONResponse`` with status 500 and a generic
        message is returned instead.
    """
    try:
        started = time()
        # Request a non-streamed completion: the result is then a single
        # mapping that can carry the "time" key and be returned as the
        # response body. A streamed call yields chunks and cannot be indexed.
        # NOTE(review): this call is synchronous, so the coroutine does not
        # yield to the event loop while the model is generating.
        output = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
        )
        output["time"] = time() - started
        return output
    except Exception as e:
        # Log with traceback; the client only sees a generic 500 message.
        logger.exception("Error in /generate endpoint: %s", e)
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
|
|
|
|
|
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")