Spaces:
Sleeping
Sleeping
Update app/main.py (#1)
Browse files- Update app/main.py (42b08f6dae4d60a90bd3bda5d51c3962a35b8899)
- app/main.py +26 -60
app/main.py
CHANGED
@@ -13,9 +13,8 @@ from typing import Optional
|
|
13 |
# MODEL LOADING, FUNCTIONS, AND TESTING
|
14 |
|
15 |
print("Loading model...")
|
16 |
-
|
17 |
-
|
18 |
-
# WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
|
19 |
# n_gpu_layers=28, # Uncomment to use GPU acceleration
|
20 |
# seed=1337, # Uncomment to set a specific seed
|
21 |
# n_ctx=2048, # Uncomment to increase the context window
|
@@ -28,31 +27,18 @@ def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
|
|
28 |
prompt = f"""###User: {question}\n###Assistant:"""
|
29 |
result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
|
30 |
return result
|
31 |
-
|
32 |
-
def check_sentiment(text):
|
33 |
-
prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
|
34 |
-
response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
|
35 |
-
# print(response)
|
36 |
-
result = extract_restext(response)
|
37 |
-
if "positive" in result:
|
38 |
-
return "positive"
|
39 |
-
elif "negative" in result:
|
40 |
-
return "negative"
|
41 |
-
else:
|
42 |
-
return "unknown"
|
43 |
|
44 |
# TESTING THE MODEL
|
45 |
print("Testing model...")
|
46 |
-
assert "
|
47 |
-
assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
|
48 |
-
# assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
|
49 |
print("Ready.")
|
50 |
|
51 |
|
52 |
# START OF FASTAPI APP
|
53 |
app = FastAPI(
|
54 |
title = "Gemma Finetuned API",
|
55 |
-
description="Gemma Finetuned API for
|
56 |
version="1.0.0",
|
57 |
)
|
58 |
|
@@ -67,10 +53,6 @@ app.add_middleware(
|
|
67 |
|
68 |
|
69 |
# API DATA CLASSES
|
70 |
-
class SA_Result(str, Enum):
|
71 |
-
positive = "positive"
|
72 |
-
negative = "negative"
|
73 |
-
unknown = "unknown"
|
74 |
|
75 |
class SAResponse(BaseModel):
|
76 |
code: int = 200
|
@@ -90,60 +72,44 @@ def docs():
|
|
90 |
"Redirects the user from the main page to the docs."
|
91 |
return responses.RedirectResponse('./docs')
|
92 |
|
93 |
-
@app.post('/
|
94 |
-
async def
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if prompt:
|
97 |
try:
|
98 |
-
print(f
|
99 |
-
result =
|
100 |
print(f"Result: {result}")
|
101 |
-
return
|
102 |
except Exception as e:
|
103 |
-
return HTTPException(500,
|
104 |
else:
|
105 |
-
return HTTPException(400,
|
106 |
-
|
107 |
|
108 |
-
@app.post('/questions/
|
109 |
-
async def
|
110 |
-
prompt: str = Body(..., embed=True, example="
|
111 |
temperature: float = Body(0.5, embed=True),
|
112 |
max_new_tokens: int = Body(200, embed=True)
|
113 |
) -> QuestionResponse:
|
114 |
"""
|
115 |
-
Ask a finetuned Gemma
|
116 |
NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
|
117 |
"""
|
118 |
if prompt:
|
119 |
try:
|
120 |
-
print(f'Asking
|
121 |
-
result = ask_llm(
|
122 |
print(f"Result: {result}")
|
123 |
return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
|
124 |
except Exception as e:
|
125 |
return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
|
126 |
else:
|
127 |
return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
|
128 |
-
|
129 |
-
|
130 |
-
# @app.post('/questions/open-ended')
|
131 |
-
# async def ask_gemmaWild(
|
132 |
-
# prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
|
133 |
-
# temperature: float = Body(0.5, embed=True),
|
134 |
-
# max_new_tokens: int = Body(200, embed=True)
|
135 |
-
# ) -> QuestionResponse:
|
136 |
-
# """
|
137 |
-
# Ask a finetuned Gemma an open-ended question..
|
138 |
-
# NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
|
139 |
-
# """
|
140 |
-
# if prompt:
|
141 |
-
# try:
|
142 |
-
# print(f'Asking GemmaWild with the question "{prompt}"')
|
143 |
-
# result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
|
144 |
-
# print(f"Result: {result}")
|
145 |
-
# return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
|
146 |
-
# except Exception as e:
|
147 |
-
# return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
|
148 |
-
# else:
|
149 |
-
# return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
|
|
|
13 |
# MODEL LOADING, FUNCTIONS, AND TESTING
|
14 |
|
15 |
print("Loading model...")
|
16 |
+
WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
|
17 |
+
COllm = Llama(model_path="/models/TunaCodes-Q8_0.gguf", use_mmap=False, use_mlock=True)
|
|
|
18 |
# n_gpu_layers=28, # Uncomment to use GPU acceleration
|
19 |
# seed=1337, # Uncomment to set a specific seed
|
20 |
# n_ctx=2048, # Uncomment to increase the context window
|
|
|
27 |
prompt = f"""###User: {question}\n###Assistant:"""
|
28 |
result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
|
29 |
return result
|
30 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
# TESTING THE MODEL
|
33 |
print("Testing model...")
|
34 |
+
assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
|
|
|
|
|
35 |
print("Ready.")
|
36 |
|
37 |
|
38 |
# START OF FASTAPI APP
|
39 |
app = FastAPI(
|
40 |
title = "Gemma Finetuned API",
|
41 |
+
description="Gemma Finetuned API for Open-ended and Coding questions.",
|
42 |
version="1.0.0",
|
43 |
)
|
44 |
|
|
|
53 |
|
54 |
|
55 |
# API DATA CLASSES
|
|
|
|
|
|
|
|
|
56 |
|
57 |
class SAResponse(BaseModel):
|
58 |
code: int = 200
|
|
|
72 |
"Redirects the user from the main page to the docs."
|
73 |
return responses.RedirectResponse('./docs')
|
74 |
|
75 |
+
@app.post('/questions/open-ended')
|
76 |
+
async def ask_gemmaWild(
|
77 |
+
prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
|
78 |
+
temperature: float = Body(0.5, embed=True),
|
79 |
+
max_new_tokens: int = Body(200, embed=True)
|
80 |
+
) -> QuestionResponse:
|
81 |
+
"""
|
82 |
+
Ask a finetuned Gemma an open-ended question..
|
83 |
+
NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
|
84 |
+
"""
|
85 |
if prompt:
|
86 |
try:
|
87 |
+
print(f'Asking GemmaWild with the question "{prompt}"')
|
88 |
+
result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
|
89 |
print(f"Result: {result}")
|
90 |
+
return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
|
91 |
except Exception as e:
|
92 |
+
return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
|
93 |
else:
|
94 |
+
return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
|
|
|
95 |
|
96 |
+
@app.post('/questions/coding')
|
97 |
+
async def ask_gemmaCode(
|
98 |
+
prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
|
99 |
temperature: float = Body(0.5, embed=True),
|
100 |
max_new_tokens: int = Body(200, embed=True)
|
101 |
) -> QuestionResponse:
|
102 |
"""
|
103 |
+
Ask a finetuned Gemma an open-ended question..
|
104 |
NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
|
105 |
"""
|
106 |
if prompt:
|
107 |
try:
|
108 |
+
print(f'Asking GemmaCode with the question "{prompt}"')
|
109 |
+
result = ask_llm(COllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
|
110 |
print(f"Result: {result}")
|
111 |
return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
|
112 |
except Exception as e:
|
113 |
return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
|
114 |
else:
|
115 |
return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|