PawinC committed
Commit f5fdf38
1 Parent(s): 4d3709e

Upload 4 files

Files changed (4):
  1. Dockerfile +9 -2
  2. app/main.py +61 -71
  3. pythainlp-data/gitkeep +0 -0
  4. requirements.txt +7 -1
Dockerfile CHANGED
@@ -8,11 +8,18 @@ COPY requirements.txt /requirements.txt
 
 RUN pip install -r requirements.txt
 
-COPY app /app
+RUN useradd -m -u 1000 user
+USER user
 
-COPY models /models
+COPY --chown=user:user app /app
+
+COPY --chown=user:user models /models
 #DO NOT FORGET TO UNCOMMENT THE ABOVE WHEN PUSHING TO HF!!!!
 
+COPY --chown=user:user pythainlp-data /pythainlp-data
+
+RUN sha256sum /models/final-Physics_llama3.gguf
+
 # EXPOSE 7860
 
 ENV PYTHONUNBUFFERED=1
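
The new RUN sha256sum /models/final-Physics_llama3.gguf layer simply prints the model's checksum into the build log so the uploaded GGUF can be sanity-checked. A minimal sketch of reproducing that digest locally with Python's hashlib (the local path is an assumption; compare the output against the value printed during the Docker build):

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    # Stream in 1 MiB chunks so a multi-gigabyte GGUF never has to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Hypothetical local copy of the model file; the digest should match the
# line that `sha256sum` prints in the build log.
print(sha256_of("models/final-Physics_llama3.gguf"))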
app/main.py CHANGED
@@ -8,51 +8,47 @@ from llama_cpp import Llama
 
 from pydantic import BaseModel
 from enum import Enum
-from typing import Optional
+from typing import Optional, Literal, Dict, List
 
 # MODEL LOADING, FUNCTIONS, AND TESTING
 
 print("Loading model...")
-SAllm = Llama(model_path="/models/final-gemma2b_SA-Q8_0.gguf", use_mmap=False, use_mlock=True)
-FIllm = Llama(model_path="/models/final-gemma7b_FI-Q8_0.gguf", use_mmap=False, use_mlock=True)
-# WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
+PHllm = Llama(model_path="/models/final-Physics_llama3.gguf", use_mmap=False, use_mlock=True)
+# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True)
 # n_gpu_layers=28, # Uncomment to use GPU acceleration
 # seed=1337, # Uncomment to set a specific seed
 # n_ctx=2048, # Uncomment to increase the context window
 #)
 
-def extract_restext(response):
-    return response['choices'][0]['text'].strip()
+print("Loading Translators.")
+from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator
+t = EnThTranslator()
+e = ThEnTranslator()
 
-def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
-    prompt = f"""###User: {question}\n###Assistant:"""
-    result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
+def extract_restext(response, is_chat=False):
+    return response['choices'][0]['text' if is_chat else 'message'].strip()
+
+def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
+    result = extract_restext(llm.create_chat_completion({"role": "user", "content": question}, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
+    return result
+
+def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
+    result = extract_restext(llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]), is_chat=True)
     return result
-
-def check_sentiment(text):
-    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
-    response = SAllm(prompt, max_tokens=3, stop=["\n"], echo=False, temperature=0.5)
-    # print(response)
-    result = extract_restext(response)
-    if "positive" in result:
-        return "positive"
-    elif "negative" in result:
-        return "negative"
-    else:
-        return "unknown"
 
 # TESTING THE MODEL
 print("Testing model...")
-assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
-assert ask_llm(FIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
-# assert ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1) #Just checking that it can run
+assert ask_llama(PHllm, ["Hello!, How are you today?"], max_new_tokens=5) #Just checking that it can run
+print("Checking Translators.")
+assert t.translate("Hello!") == "สวัสดี!"
+assert e.translate("สวัสดี!") == "Hello!"
 print("Ready.")
 
 
 # START OF FASTAPI APP
 app = FastAPI(
     title = "Gemma Finetuned API",
-    description="Gemma Finetuned API for Sentiment Analysis and Finance Questions.",
+    description="Gemma Finetuned API for Thai Open-ended question answering.",
     version="1.0.0",
 )
 
@@ -67,22 +63,22 @@ app.add_middleware(
 
 
 # API DATA CLASSES
-class SA_Result(str, Enum):
-    positive = "positive"
-    negative = "negative"
-    unknown = "unknown"
-
-class SAResponse(BaseModel):
-    code: int = 200
-    text: Optional[str] = None
-    result: SA_Result = None
-
 class QuestionResponse(BaseModel):
     code: int = 200
     question: Optional[str] = None
     answer: str = None
     config: Optional[dict] = None
 
+class ChatHistoryResponse(BaseModel):
+    code: int = 200
+    chat_history: Dict[str] = None
+    answer: str = None
+    config: Optional[dict] = None
+
+class LlamaChatMessage(BaseModel):
+    role: Literal["user", "assistant"]
+    content: str
+
 
 # API ROUTES
 @app.get('/')
@@ -90,60 +86,54 @@ def docs():
     "Redirects the user from the main page to the docs."
     return responses.RedirectResponse('./docs')
 
-@app.post('/classifications/sentiment')
-async def perform_sentiment_analysis(prompt: str = Body(..., embed=True, example="I like eating fried chicken")) -> SAResponse:
-    """Performs a sentiment analysis using a finetuned version of Gemma-7b"""
-    if prompt:
-        try:
-            print(f"Checking sentiment for {prompt}")
-            result = check_sentiment(prompt)
-            print(f"Result: {result}")
-            return SAResponse(result=result, text=prompt)
-        except Exception as e:
-            return HTTPException(500, SAResponse(code=500, result=str(e), text=prompt))
-    else:
-        return HTTPException(400, SAResponse(code=400, result="Request argument 'prompt' not provided."))
-
-
-@app.post('/questions/finance')
-async def ask_gemmaFinanceTH(
-    prompt: str = Body(..., embed=True, example="What's the best way to invest my money"),
+@app.post('/questions/physics')
+async def ask_gemmaPhysics(
+    prompt: str = Body(..., embed=True, example="Why do ice cream melt so fast?"),
     temperature: float = Body(0.5, embed=True),
-    max_new_tokens: int = Body(200, embed=True)
+    repeat_penalty: float = Body(1.0, embed=True),
+    max_new_tokens: int = Body(200, embed=True),
+    translate_from_thai: bool = Body(False, embed=True)
 ) -> QuestionResponse:
     """
-    Ask a finetuned Gemma a finance-related question, just for fun.
-    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
+    Ask a finetuned Gemma an physics question.
+    NOTICE: Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
     """
     if prompt:
         try:
-            print(f'Asking GemmaFinance with the question "{prompt}"')
-            result = ask_llm(FIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
+            print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}')
+            if translate_from_thai:
+                prompt = e.translate(prompt)
+            result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
             print(f"Result: {result}")
-            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
+            if translate_from_thai:
+                result = t.translate(result)
+            return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
         except Exception as e:
             return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
    else:
         return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
-
 
-# @app.post('/questions/open-ended')
-# async def ask_gemmaWild(
-#     prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
+
+# @app.post('/chat/multiturn')
+# async def ask_llama3_Tuna(
+#     chat_history: List[LlamaChatMessage] = Body(..., embed=True),
 #     temperature: float = Body(0.5, embed=True),
+#     repeat_penalty: float = Body(2.0, embed=True),
 #     max_new_tokens: int = Body(200, embed=True)
-# ) -> QuestionResponse:
+# ) -> ChatHistoryResponse:
 #     """
-#     Ask a finetuned Gemma an open-ended question..
-#     NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
+#     Chat with a finetuned Llama-3 model (in Thai).
+#     Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
+#     NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF!
 #     """
-#     if prompt:
+#     if chat_history:
 #         try:
-#             print(f'Asking GemmaWild with the question "{prompt}"')
-#             result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
+#             print(f'Asking Llama3Tuna with the question "{chat_history}"')
+#             result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty)
 #             print(f"Result: {result}")
-#             return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
+#             return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty})
#         except Exception as e:
-#             return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt))
+#             return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=chat_history))
 #     else:
 #         return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))
+
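
The new helpers lean on llama-cpp-python's OpenAI-style chat interface. As a point of reference (this is my reading of that library's API, not part of the commit), create_chat_completion takes messages as a list of role/content dicts and returns the reply under choices[0]["message"]["content"], so a stand-alone version of ask_llama would look roughly like the sketch below. Read against that shape, the committed extract_restext (which indexes 'message' directly and flips the is_chat branch) and the startup assert that passes a list where a string is expected are worth double-checking.

from llama_cpp import Llama

def ask_llama(llm: Llama, question: str, max_new_tokens: int = 200,
              temperature: float = 0.5, repeat_penalty: float = 2.0) -> str:
    # Chat completions take a list of messages; completion-style calls return
    # choices[0]["text"], chat-style calls return choices[0]["message"]["content"].
    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": question}],
        max_tokens=max_new_tokens,
        temperature=temperature,
        repeat_penalty=repeat_penalty,
        stop=["<|eot_id|>", "<|end_of_text|>"],
    )
    return response["choices"][0]["message"]["content"].strip()

# Example usage (assumes the model file from the Dockerfile is present):
# PHllm = Llama(model_path="/models/final-Physics_llama3.gguf", use_mmap=False, use_mlock=True)
# print(ask_llama(PHllm, "Hello! How are you today?", max_new_tokens=5))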
pythainlp-data/gitkeep ADDED
File without changes
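
The empty pythainlp-data directory (kept in the repo via gitkeep and copied into the image with user ownership) presumably gives PyThaiNLP a writable place to cache the downloaded translation models under the non-root user. If I recall correctly, PyThaiNLP honors a PYTHAINLP_DATA_DIR environment variable for this; a sketch under that assumption, pointing the cache at the baked-in directory before the translators are imported:

import os

# Assumption: PyThaiNLP reads PYTHAINLP_DATA_DIR to decide where to store
# downloaded resources; point it at the directory shipped in the image.
os.environ.setdefault("PYTHAINLP_DATA_DIR", "/pythainlp-data")

from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator  # noqa: E402

t = EnThTranslator()
e = ThEnTranslator()
print(t.translate("Hello!"))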
requirements.txt CHANGED
@@ -1,3 +1,9 @@
 uvicorn[standard]
 fastapi
-llama-cpp-python
+llama-cpp-python
+pythainlp
+pandas
+fairseq
+sacremoses
+sentencepiece
+transformers
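
With those dependencies installed and the container running, the new /questions/physics route takes an embedded JSON body, since every parameter is declared with Body(..., embed=True). A minimal client sketch; the host and port are assumptions, as the Dockerfile only hints at 7860 through the commented-out EXPOSE:

import requests

resp = requests.post(
    "http://localhost:7860/questions/physics",  # assumed host/port
    json={
        "prompt": "Why do ice cream melt so fast?",
        "temperature": 0.5,
        "repeat_penalty": 1.0,
        "max_new_tokens": 200,
        "translate_from_thai": False,
    },
    timeout=120,
)
print(resp.json())  # expected shape: {"code": 200, "question": ..., "answer": ..., "config": ...}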