Spaces:
Sleeping
Sleeping
Upload main.py
Browse files- app/main.py +16 -23
app/main.py
CHANGED
@@ -10,25 +10,16 @@ from pydantic import BaseModel
|
|
10 |
from enum import Enum
|
11 |
from typing import Optional
|
12 |
|
|
|
|
|
13 |
print("Loading model...")
|
14 |
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
|
|
|
15 |
# n_gpu_layers=28, # Uncomment to use GPU acceleration
|
16 |
# seed=1337, # Uncomment to set a specific seed
|
17 |
# n_ctx=2048, # Uncomment to increase the context window
|
18 |
#)
|
19 |
|
20 |
-
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
|
21 |
-
|
22 |
-
# def ask(question, max_new_tokens=200):
|
23 |
-
# output = llm(
|
24 |
-
# question, # Prompt
|
25 |
-
# max_tokens=max_new_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
|
26 |
-
# stop=["\n"], # Stop generating just before the model would generate a new question
|
27 |
-
# echo=False, # Echo the prompt back in the output
|
28 |
-
# temperature=0.0,
|
29 |
-
# )
|
30 |
-
# return output
|
31 |
-
|
32 |
def extract_restext(response):
    """Return the generated text from a llama.cpp completion response, trimmed."""
    first_choice = response["choices"][0]
    return first_choice["text"].strip()
|
34 |
|
@@ -49,15 +40,17 @@ def check_sentiment(text):
|
|
49 |
else:
|
50 |
return "unknown"
|
51 |
|
52 |
-
|
53 |
print("Testing model...")
|
54 |
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
|
55 |
assert ask_fi("Hello!, How are you today?")
|
56 |
print("Ready.")
|
57 |
|
|
|
|
|
58 |
app = FastAPI(
|
59 |
-
title = "
|
60 |
-
description="
|
61 |
version="1.0.0",
|
62 |
)
|
63 |
|
@@ -70,6 +63,8 @@ app.add_middleware(
|
|
70 |
allow_headers=["*"]
|
71 |
)
|
72 |
|
|
|
|
|
73 |
class SA_Result(str, Enum):
|
74 |
positive = "positive"
|
75 |
negative = "negative"
|
@@ -86,17 +81,15 @@ class FI_Response(BaseModel):
|
|
86 |
answer: str = None
|
87 |
config: Optional[dict] = None
|
88 |
|
|
|
|
|
89 |
@app.get('/')
def docs():
    """Redirect a visitor hitting the bare root URL to the interactive API docs."""
    return responses.RedirectResponse('./docs')
|
93 |
|
94 |
-
@app.
|
95 |
-
def
|
96 |
-
return a + b
|
97 |
-
|
98 |
-
@app.post('/SA')
|
99 |
-
def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
|
100 |
"""Performs a sentiment analysis using a finetuned version of Gemma-7b"""
|
101 |
if prompt:
|
102 |
try:
|
@@ -110,8 +103,8 @@ def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I li
|
|
110 |
return HTTPException(400, SA_Response(code=400, result="Request argument 'prompt' not provided."))
|
111 |
|
112 |
|
113 |
-
@app.post('/
|
114 |
-
def ask_gemmaFinanceTH(
|
115 |
prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
|
116 |
temperature: float = Body(0.5, embed=True),
|
117 |
max_new_tokens: int = Body(200, embed=True)
|
|
|
10 |
from enum import Enum
|
11 |
from typing import Optional
|
12 |
|
13 |
+
# MODEL LOADING, FUNCTIONS, AND TESTING
|
14 |
+
|
15 |
print("Loading model...")
|
16 |
SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
|
17 |
+
FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
|
18 |
# n_gpu_layers=28, # Uncomment to use GPU acceleration
|
19 |
# seed=1337, # Uncomment to set a specific seed
|
20 |
# n_ctx=2048, # Uncomment to increase the context window
|
21 |
#)
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
def extract_restext(response):
    """Strip and return the completion text of the first choice in *response*."""
    return response["choices"][0]["text"].strip()
|
25 |
|
|
|
40 |
else:
|
41 |
return "unknown"
|
42 |
|
43 |
+
# TESTING THE MODEL
|
44 |
print("Testing model...")
|
45 |
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
|
46 |
assert ask_fi("Hello!, How are you today?")
|
47 |
print("Ready.")
|
48 |
|
49 |
+
|
50 |
+
# START OF FASTAPI APP
|
51 |
# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
    version="1.0.0",
)
|
56 |
|
|
|
63 |
allow_headers=["*"]
|
64 |
)
|
65 |
|
66 |
+
|
67 |
+
# API DATA CLASSES
|
68 |
class SA_Result(str, Enum):
|
69 |
positive = "positive"
|
70 |
negative = "negative"
|
|
|
81 |
answer: str = None
|
82 |
config: Optional[dict] = None
|
83 |
|
84 |
+
|
85 |
+
# API ROUTES
|
86 |
@app.get('/')
def docs():
    """Send requests for the main page on to the auto-generated docs page."""
    return responses.RedirectResponse('./docs')
|
90 |
|
91 |
+
@app.post('/classifications/sentiment')
|
92 |
+
async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
|
|
|
|
|
|
|
|
|
93 |
"""Performs a sentiment analysis using a finetuned version of Gemma-7b"""
|
94 |
if prompt:
|
95 |
try:
|
|
|
103 |
return HTTPException(400, SA_Response(code=400, result="Request argument 'prompt' not provided."))
|
104 |
|
105 |
|
106 |
+
@app.post('/questions/finance')
|
107 |
+
async def ask_gemmaFinanceTH(
|
108 |
prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
|
109 |
temperature: float = Body(0.5, embed=True),
|
110 |
max_new_tokens: int = Body(200, embed=True)
|