PawinC commited on
Commit
ba181f4
·
verified ·
1 Parent(s): aa8d737

Update app/main.py (#1)

Browse files

- Update app/main.py (42b08f6dae4d60a90bd3bda5d51c3962a35b8899)

Files changed (1) hide show
  1. app/main.py +26 -60
app/main.py CHANGED
@@ -13,9 +13,8 @@ from typing import Optional
13
  # MODEL LOADING, FUNCTIONS, AND TESTING
14
 
15
  print("Loading model...")
16
- SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", use_mmap=False, use_mlock=True)
17
- FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", use_mmap=False, use_mlock=True)
18
- # WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
19
  # n_gpu_layers=28, # Uncomment to use GPU acceleration
20
  # seed=1337, # Uncomment to set a specific seed
21
  # n_ctx=2048, # Uncomment to increase the context window
@@ -28,31 +27,18 @@ def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
28
  prompt = f"""###User: {question}\n###Assistant:"""
29
  result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
30
  return result
31
-
32
- def check_sentiment(text):
33
- prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
34
- response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
35
- # print(response)
36
- result = extract_restext(response)
37
- if "positive" in result:
38
- return "positive"
39
- elif "negative" in result:
40
- return "negative"
41
- else:
42
- return "unknown"
43
 
44
  # TESTING THE MODEL
45
  print("Testing model...")
46
- assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
47
- assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
48
- # assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
49
  print("Ready.")
50
 
51
 
52
  # START OF FASTAPI APP
53
  app = FastAPI(
54
  title = "Gemma Finetuned API",
55
- description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
56
  version="1.0.0",
57
  )
58
 
@@ -67,10 +53,6 @@ app.add_middleware(
67
 
68
 
69
  # API DATA CLASSES
70
- class SA_Result(str, Enum):
71
- positive = "positive"
72
- negative = "negative"
73
- unknown = "unknown"
74
 
75
  class SAResponse(BaseModel):
76
  code: int = 200
@@ -90,60 +72,44 @@ def docs():
90
  "Redirects the user from the main page to the docs."
91
  return responses.RedirectResponse('./docs')
92
 
93
- @app.post('/classifications/sentiment')
94
- async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
95
- """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
 
 
 
 
 
 
 
96
  if prompt:
97
  try:
98
- print(f"Checking sentiment for {prompt}")
99
- result = check_sentiment(prompt)
100
  print(f"Result: {result}")
101
- return SAResponse(result=result, text=prompt)
102
  except Exception as e:
103
- return HTTPException(500, SAResponse(code=500, result=str(e), text=prompt))
104
  else:
105
- return HTTPException(400, SAResponse(code=400, result="Request argument 'prompt' not provided."))
106
-
107
 
108
- @app.post('/questions/finance')
109
- async def ask_gemmaFinanceTH(
110
- prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
111
  temperature: float = Body(0.5, embed=True),
112
  max_new_tokens: int = Body(200, embed=True)
113
  ) -> QuestionResponse:
114
  """
115
- Ask a finetuned Gemma a finance-related question, just for fun.
116
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
117
  """
118
  if prompt:
119
  try:
120
- print(f'Asking GemmaFinance with the question "{prompt}"')
121
- result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
122
  print(f"Result: {result}")
123
  return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
124
  except Exception as e:
125
  return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
126
  else:
127
  return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
128
-
129
-
130
- # @app.post('/questions/open-ended')
131
- # async def ask_gemmaWild(
132
- # prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
133
- # temperature: float = Body(0.5, embed=True),
134
- # max_new_tokens: int = Body(200, embed=True)
135
- # ) -> QuestionResponse:
136
- # """
137
- # Ask a finetuned Gemma an open-ended question..
138
- # NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
139
- # """
140
- # if prompt:
141
- # try:
142
- # print(f'Asking GemmaWild with the question "{prompt}"')
143
- # result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
144
- # print(f"Result: {result}")
145
- # return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
146
- # except Exception as e:
147
- # return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
148
- # else:
149
- # return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
 
13
  # MODEL LOADING, FUNCTIONS, AND TESTING
14
 
15
  print("Loading model...")
16
+ WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
17
+ COllm = Llama(model_path="/models/TunaCodes-Q8_0.gguf", use_mmap=False, use_mlock=True)
 
18
  # n_gpu_layers=28, # Uncomment to use GPU acceleration
19
  # seed=1337, # Uncomment to set a specific seed
20
  # n_ctx=2048, # Uncomment to increase the context window
 
27
  prompt = f"""###User: {question}\n###Assistant:"""
28
  result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
29
  return result
30
+
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  # TESTING THE MODEL
33
  print("Testing model...")
34
+ assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
 
 
35
  print("Ready.")
36
 
37
 
38
  # START OF FASTAPI APP
39
  app = FastAPI(
40
  title = "Gemma Finetuned API",
41
+ description="Gemma Finetuned API for Open-ended and Coding questions.",
42
  version="1.0.0",
43
  )
44
 
 
53
 
54
 
55
  # API DATA CLASSES
 
 
 
 
56
 
57
  class SAResponse(BaseModel):
58
  code: int = 200
 
72
  "Redirects the user from the main page to the docs."
73
  return responses.RedirectResponse('./docs')
74
 
75
+ @app.post('/questions/open-ended')
76
+ async def ask_gemmaWild(
77
+ prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
78
+ temperature: float = Body(0.5, embed=True),
79
+ max_new_tokens: int = Body(200, embed=True)
80
+ ) -> QuestionResponse:
81
+ """
82
+ Ask a finetuned Gemma an open-ended question..
83
+ NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
84
+ """
85
  if prompt:
86
  try:
87
+ print(f'Asking GemmaWild with the question "{prompt}"')
88
+ result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
89
  print(f"Result: {result}")
90
+ return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
91
  except Exception as e:
92
+ return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
93
  else:
94
+ return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
 
95
 
96
+ @app.post('/questions/coding')
97
+ async def ask_gemmaCode(
98
+ prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
99
  temperature: float = Body(0.5, embed=True),
100
  max_new_tokens: int = Body(200, embed=True)
101
  ) -> QuestionResponse:
102
  """
103
+ Ask a finetuned Gemma an open-ended question..
104
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
105
  """
106
  if prompt:
107
  try:
108
+ print(f'Asking GemmaCode with the question "{prompt}"')
109
+ result = ask_llm(COllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
110
  print(f"Result: {result}")
111
  return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
112
  except Exception as e:
113
  return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
114
  else:
115
  return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))