File size: 2,153 Bytes
c625a8c 654eaa0 c625a8c 654eaa0 30b9c64 654eaa0 c625a8c 654eaa0 c625a8c 654eaa0 c625a8c 6c8cc78 c625a8c 6c8cc78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
import fastapi
"""
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from time import time
"""
#MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf" #"./qwen1_5-0_5b-chat-q4_0.gguf"
import logging
import llama_cpp
import llama_cpp.llama_tokenizer
# Logger setup. This is done before loading the model so that a load failure
# is reported through the same logging configuration as request errors.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the Llama model from the Qwen1.5 0.5B chat GGUF repo, selecting
# the q4_0-quantized file. This runs at import time, so the app only starts
# serving once the model has loaded. Inference is CPU-only (n_gpu_layers=0).
try:
    llama = llama_cpp.Llama.from_pretrained(
        repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
        filename="*q4_0.gguf",
        # NOTE(review): the tokenizer comes from the base (non-Chat) repo;
        # presumably its vocabulary matches the Chat GGUF -- confirm.
        tokenizer=llama_cpp.llama_tokenizer.LlamaHFTokenizer.from_pretrained(
            "Qwen/Qwen1.5-0.5B"
        ),
        verbose=False,
        n_ctx=4096,  # context window in tokens
        n_threads=4,
        n_gpu_layers=0,
    )
except Exception:
    # Log with traceback, then re-raise: the app is useless without a model.
    logger.exception("Failed to load model")
    raise

app = fastapi.FastAPI()
@app.get("/")
def index():
    # The bare root URL has no content of its own; send visitors to the
    # interactive API docs instead.
    target = "/docs"
    return fastapi.responses.RedirectResponse(url=target)
@app.get("/health")
def health():
    # Liveness probe: always answers with a constant payload and performs
    # no check of the model itself.
    return dict(status="ok")
# Chat Completion API
@app.get("/generate")
async def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    """Generate a chat completion for `question` with the loaded model.

    Query parameters:
    - question: the user message.
    - system: system prompt placed before the user message.
    - temperature: sampling temperature forwarded to the model.
    - seed: RNG seed forwarded to the model.

    Returns the llama_cpp chat-completion dict plus a `time` key holding the
    elapsed generation time in seconds. On any failure the error is logged
    with its traceback and a 500 response with
    `{"message": "Internal Server Error"}` is returned.
    """
    # Imported locally: the module-level `from time import time` sits inside
    # a string literal and never executes, so `time` is not otherwise in scope.
    from time import perf_counter

    try:
        started = perf_counter()
        # NOTE(review): this call is not awaited, so it runs synchronously on
        # the event loop and blocks other requests (including /health) until
        # it returns. That also means model calls never overlap. Moving it to
        # a threadpool would need a lock around `llama`.
        output = llama.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
        )
        output["time"] = perf_counter() - started
        return output
    except Exception:
        # Top-level request boundary: keep the traceback in the log and hide
        # details from the client.
        logger.exception("Error in /generate endpoint")
        return fastapi.responses.JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
if __name__ == "__main__":
    # uvicorn is only needed when this file is run directly as a script.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)