#!/usr/bin/env python
# coding: utf-8
from fastapi import FastAPI, HTTPException, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
from typing import Optional, Literal, Dict, List

# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
PHllm = Llama(model_path="/models/final-Physics_llama3.gguf", use_mmap=False, use_mlock=True)
# MIllm = Llama(model_path="/models/final-LlamaTuna_Q8_0.gguf", use_mmap=False, use_mlock=True,
#     n_gpu_layers=28,  # Uncomment to use GPU acceleration
#     seed=1337,        # Uncomment to set a specific seed
#     n_ctx=2048,       # Uncomment to increase the context window
# )

print("Loading Translators.")
from pythainlp.translate.en_th import EnThTranslator, ThEnTranslator
en2th = EnThTranslator()
th2en = ThEnTranslator()


def extract_restext(response, is_chat=False):
    """Pull the generated text out of a llama-cpp-python completion result."""
    choice = response['choices'][0]
    if is_chat:
        return choice['message']['content'].strip()
    return choice['text'].strip()


def ask_llama(llm: Llama, question: str, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
    """Wrap a single question in the Llama-3 instruct template and return the model's answer."""
    prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""
    result = extract_restext(
        llm(prompt, max_tokens=max_new_tokens, temperature=temperature,
            repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"])
    ).replace("<|eot_id|>", "").replace("<|end_of_text|>", "")
    return result


# def chat_llama(llm: Llama, chat_history: dict, max_new_tokens=200, temperature=0.5, repeat_penalty=2.0):
#     result = extract_restext(
#         llm.create_chat_completion(chat_history, max_tokens=max_new_tokens, temperature=temperature,
#                                    repeat_penalty=repeat_penalty, stop=["<|eot_id|>", "<|end_of_text|>"]),
#         is_chat=True)
#     return result


# TESTING THE MODEL
print("Testing model...")
assert ask_llama(PHllm, "Hello! How are you today?", max_new_tokens=5)  # Just checking that it can run
print("Checking Translators.")
assert en2th.translate("Hello!") == "สวัสดี!"
assert th2en.translate("สวัสดี!") == "Hello!"
print("Ready.")

# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Thai open-ended question answering.",
    version="1.0.0",
)

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API DATA CLASSES
class QuestionResponse(BaseModel):
    code: int = 200
    question: Optional[str] = None
    answer: Optional[str] = None
    config: Optional[dict] = None


class ChatHistoryResponse(BaseModel):
    code: int = 200
    chat_history: Optional[Dict[str, str]] = None
    answer: Optional[str] = None
    config: Optional[dict] = None


class LlamaChatMessage(BaseModel):
    role: Literal["user", "assistant"]
    content: str


# API ROUTES
@app.get('/')
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')


@app.post('/questions/physics')
async def ask_gemmaPhysics(
    prompt: str = Body(..., embed=True, example="Why does ice cream melt so fast?"),
    temperature: float = Body(0.5, embed=True),
    repeat_penalty: float = Body(1.0, embed=True),
    max_new_tokens: int = Body(200, embed=True),
    translate_from_thai: bool = Body(False, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a physics question.
    NOTICE: Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything.
""" if prompt: try: print(f'Asking LlamaPhysics with the question "{prompt}", translation is {"enabled" if translate_from_thai else "disabled"}') if translate_from_thai: print("Translating content to EN.") prompt = e.translate(prompt) print(f"Asking the model with the question {prompt}") result = ask_llama(PHllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty) print(f"Got Model Response: {result}") if translate_from_thai: result = t.translate(result) print(f"Translation Result: {result}") return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty}) except Exception as e: return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=prompt)) else: return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided.")) # @app.post('/chat/multiturn') # async def ask_llama3_Tuna( # chat_history: List[LlamaChatMessage] = Body(..., embed=True), # temperature: float = Body(0.5, embed=True), # repeat_penalty: float = Body(2.0, embed=True), # max_new_tokens: int = Body(200, embed=True) # ) -> ChatHistoryResponse: # """ # Chat with a finetuned Llama-3 model (in Thai). # Answers may be random / inaccurate. Always do your research & confirm its responses before doing anything. # NOTICE: YOU MUST APPLY THE LLAMA3 PROMPT YOURSELF! # """ # if chat_history: # try: # print(f'Asking Llama3Tuna with the question "{chat_history}"') # result = chat_llama(MIllm, chat_history, max_new_tokens=max_new_tokens, temperature=temperature, repeat_penalty=repeat_penalty) # print(f"Result: {result}") # return ChatHistoryResponse(answer=result, config={"temperature": temperature, "max_new_tokens": max_new_tokens, "repeat_penalty": repeat_penalty}) # except Exception as e: # return HTTPException(500, QuestionResponse(code=500, answer=str(e), question=chat_history)) # else: # return HTTPException(400, QuestionResponse(code=400, answer="Request argument 'prompt' not provided."))