#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama

from pydantic import BaseModel
from enum import Enum
from typing import Optional

# MODEL LOADING, FUNCTIONS, AND TESTING

print("Loading model...")
WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf", use_mmap=False, use_mlock=True)
COllm = Llama(model_path="/models/TunaCodes-Q8_0.gguf", use_mmap=False, use_mlock=True)
# Additional Llama() options that can be passed above:
#   n_gpu_layers=28,  # Use GPU acceleration
#   seed=1337,        # Set a specific seed
#   n_ctx=2048,       # Increase the context window

def extract_restext(response):
  """Pull the generated text out of a llama_cpp completion response."""
  return response['choices'][0]['text'].strip()

def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
  """Format the question with the ###User/###Assistant prompt template and return the model's reply."""
  prompt = f"""###User: {question}\n###Assistant:"""
  result = extract_restext(llm(prompt, max_tokens=max_new_tokens, temperature=temperature, stop=["###User:", "###Assistant:"], echo=False))
  return result


# TESTING THE MODEL
print("Testing model...")
assert ask_llm(WIllm, "Hello! How are you today?", max_new_tokens=1)  # Just checking that it can run
print("Ready.")


# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Open-ended and Coding questions.",
    version="1.0.0",
)

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"]
)


# API DATA CLASSES

class SAResponse(BaseModel):
  code: int = 200
  text: Optional[str] = None
  result: Optional[str] = None

class QuestionResponse(BaseModel):
  code: int = 200
  question: Optional[str] = None
  answer: Optional[str] = None
  config: Optional[dict] = None


# API ROUTES
@app.get('/')
def docs():
  "Redirects the user from the main page to the docs."
  return responses.RedirectResponse('./docs')

@app.post('/questions/open-ended')
async def ask_gemmaWild(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True), 
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
  """
  Ask a finetuned Gemma an open-ended question.
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
  """
  if prompt:
    try:
      print(f'Asking GemmaWild with the question "{prompt}"')
      result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
      print(f"Result: {result}")
      return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
    except Exception as e:
      raise HTTPException(status_code=500, detail={"code": 500, "answer": str(e), "question": prompt})
  else:
    raise HTTPException(status_code=400, detail={"code": 400, "answer": "Request argument 'prompt' not provided."})
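
# Example request for /questions/open-ended (a sketch; assumes the server is reachable at http://localhost:8000):
#   import requests
#   r = requests.post(
#       "http://localhost:8000/questions/open-ended",
#       json={"prompt": "Why is ice cream so delicious?", "temperature": 0.5, "max_new_tokens": 200},
#   )
#   print(r.json()["answer"])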

@app.post('/questions/coding')
async def ask_gemmaCode(
    prompt: str = Body(..., embed=True, example="Write a Python function that reverses a string."),
    temperature: float = Body(0.5, embed=True), 
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
  """
  Ask a finetuned Gemma a coding question.
  NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
  """
  if prompt:
    try:
      print(f'Asking GemmaCode with the question "{prompt}"')
      result = ask_llm(COllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
      print(f"Result: {result}")
      return QuestionResponse(answer=result, question=prompt, config={"temperature": temperature, "max_new_tokens": max_new_tokens})
    except Exception as e:
      raise HTTPException(status_code=500, detail={"code": 500, "answer": str(e), "question": prompt})
  else:
    raise HTTPException(status_code=400, detail={"code": 400, "answer": "Request argument 'prompt' not provided."})
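
# Local development entrypoint (a minimal sketch; in deployment the app is more likely
# started externally, e.g. `uvicorn app:app`, where the module name is an assumption).
if __name__ == "__main__":
  import uvicorn

  uvicorn.run(app, host="0.0.0.0", port=8000)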