moriire committed on
Commit
c625a8c
·
verified ·
1 Parent(s): bfb952c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fastapi
2
+ from fastapi.responses import JSONResponse
3
+ from llama_cpp import Llama
4
+ from time import time
5
+ import logging
6
+
7
+
8
# Path of the local GGUF model file used by the (disabled) initialisation below.
MODEL_PATH = "./qwen1_5-0_5b-chat-q4_0.gguf"

# Logger setup: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize Llama model
#
# NOTE: the initialisation is disabled, so no module-level ``llm`` object is
# created here. It is kept for reference exactly as it was written:
#
# try:
#     llm = Llama.from_pretrained(
#         repo_id="Qwen/Qwen1.5-0.5B-Chat-GGUF",
#         filename="*q4_0.gguf",
#         verbose=False,
#         n_ctx=4096,
#         n_threads=4,
#         n_gpu_layers=0,
#     )
#
#     llm = Llama(
#         model_path=MODEL_PATH,
#         chat_format="llama-2",
#         n_ctx=4096,
#         n_threads=8,
#         n_gpu_layers=0,
#     )
#
# except Exception as e:
#     logger.error(f"Failed to load model: {e}")
#     raise

# The ASGI application object that the route decorators attach to.
app = fastapi.FastAPI()
40
+
41
+
42
@app.get("/")
def index():
    """Redirect requests for the root path to ``/docs``."""
    docs_redirect = fastapi.responses.RedirectResponse(url="/docs")
    return docs_redirect
45
+
46
+
47
@app.get("/health")
def health():
    """Return a fixed status payload so callers can check the app responds."""
    payload = {"status": "ok"}
    return payload
50
+
51
+
52
# Chat Completion API
def _get_llm():
    """Return the shared Llama model, creating it on first use.

    A module-level ``llm`` is reused when one already exists; otherwise the
    model is built from ``MODEL_PATH`` and cached under that same global name.
    """
    global llm
    try:
        return llm
    except NameError:
        # No module-level model exists (its initialisation is disabled in this
        # file), so build it here with the settings of that disabled code.
        # NOTE(review): Qwen chat models normally use a ChatML-style template;
        # confirm that chat_format="llama-2" is really intended.
        llm = Llama(
            model_path=MODEL_PATH,
            chat_format="llama-2",
            n_ctx=4096,
            n_threads=8,
            n_gpu_layers=0,
        )
        return llm


@app.get("/generate")
async def complete(
    question: str,
    system: str = "You are a story writing assistant.",
    temperature: float = 0.7,
    seed: int = 42,
) -> dict:
    """Answer ``question`` with a chat completion from the Llama model.

    Args:
        question: Text sent as the user message.
        system: Text sent as the system message.
        temperature: Sampling temperature forwarded to the model.
        seed: Sampling seed forwarded to the model.

    Returns:
        The object returned by ``create_chat_completion`` with an extra
        ``"time"`` key holding the elapsed generation time in seconds. If
        anything fails, a ``JSONResponse`` with status 500 is returned instead.
    """
    try:
        # Resolve the model before starting the clock so that "time" covers
        # generation only, not a first-request model load.
        model = _get_llm()
        started = time()
        # NOTE(review): this is a synchronous call inside an async handler, so
        # the event loop is blocked while the model generates.
        output = model.create_chat_completion(
            messages=[
                {"role": "system", "content": system},
                {"role": "user", "content": question},
            ],
            temperature=temperature,
            seed=seed,
        )
        output["time"] = time() - started
        return output
    except Exception:
        # Top-level boundary for this route: log the full traceback and hide
        # the details from the client.
        logger.exception("Error in /generate endpoint")
        return JSONResponse(
            status_code=500, content={"message": "Internal Server Error"}
        )
78
+
79
# Script entry point, currently disabled. Kept for reference as written:
#
# if __name__ == "__main__":
#     import uvicorn
#
#     uvicorn.run(app, host="0.0.0.0", port=8000)