PawinC committed on
Commit
ce36f28
1 Parent(s): cf5892c

Upload main.py

Files changed (1)
  1. app/main.py +16 -23
app/main.py CHANGED
@@ -10,25 +10,16 @@ from pydantic import BaseModel
 from enum import Enum
 from typing import Optional

+# MODEL LOADING, FUNCTIONS, AND TESTING
+
 print("Loading model...")
 SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", mmap=False, mlock=True)
+FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
 # n_gpu_layers=28, # Uncomment to use GPU acceleration
 # seed=1337, # Uncomment to set a specific seed
 # n_ctx=2048, # Uncomment to increase the context window
 #)

-FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", mmap=False, mlock=True)
-
-# def ask(question, max_new_tokens=200):
-#     output = llm(
-#         question, # Prompt
-#         max_tokens=max_new_tokens, # Generate up to 32 tokens, set to None to generate up to the end of the context window
-#         stop=["\n"], # Stop generating just before the model would generate a new question
-#         echo=False, # Echo the prompt back in the output
-#         temperature=0.0,
-#     )
-#     return output
-
 def extract_restext(response):
     return response['choices'][0]['text'].strip()

@@ -49,15 +40,17 @@ def check_sentiment(text):
     else:
         return "unknown"

-
+# TESTING THE MODEL
 print("Testing model...")
 assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
 assert ask_fi("Hello!, How are you today?")
 print("Ready.")

+
+# START OF FASTAPI APP
 app = FastAPI(
-    title = "GemmaSA_2b",
-    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
+    title = "Gemma Finetuned API",
+    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
     version="1.0.0",
 )

@@ -70,6 +63,8 @@ app.add_middleware(
     allow_headers=["*"]
 )

+
+# API DATA CLASSES
 class SA_Result(str, Enum):
     positive = "positive"
     negative = "negative"
@@ -86,17 +81,15 @@ class FI_Response(BaseModel):
     answer: str = None
     config: Optional[dict] = None

+
+# API ROUTES
 @app.get('/')
 def docs():
     "Redirects the user from the main page to the docs."
     return responses.RedirectResponse('./docs')

-@app.get('/add/{a}/{b}')
-def add(a: int,b: int):
-    return a + b
-
-@app.post('/SA')
-def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
+@app.post('/classifications/sentiment')
+async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SA_Response:
     """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
     if prompt:
         try:
@@ -110,8 +103,8 @@ def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I li
         return HTTPException(400, SA_Response(code=400, result="Request argument 'prompt' not provided."))


-@app.post('/FI')
-def ask_gemmaFinanceTH(
+@app.post('/questions/finance')
+async def ask_gemmaFinanceTH(
     prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
     temperature: float = Body(0.5, embed=True),
     max_new_tokens: int = Body(200, embed=True)
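
For reference, the renamed routes (formerly POST /SA and POST /FI) can be exercised with a minimal client sketch like the one below. The base URL is an assumption (a local uvicorn instance on port 8000), not part of this commit; the request bodies follow the embed=True Body parameters shown in the diff.

# Minimal client sketch for the renamed endpoints.
# Assumption: the FastAPI app is served at http://localhost:8000.
import requests

BASE = "http://localhost:8000"

# Sentiment classification (previously POST /SA)
r = requests.post(f"{BASE}/classifications/sentiment",
                  json={"prompt": "I like eating fried chicken"})
print(r.json())

# Finance question answering (previously POST /FI)
r = requests.post(f"{BASE}/questions/finance",
                  json={"prompt": "What's the best way to invest my money",
                        "temperature": 0.5,
                        "max_new_tokens": 200})
print(r.json())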