PawinC's picture
Update app/main.py (#1)
ba181f4 verified
raw
history blame
No virus
4.05 kB
#!/usr/bin/env python
# coding: utf-8
from enum import Enum
from os import listdir
from os.path import isdir
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel
# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")

# Both GGUF models are loaded eagerly at import time, each with
# use_mmap=False and use_mlock=True.
# Optional Llama() keyword arguments that can be added to either call:
#   n_gpu_layers=28  -> use GPU acceleration
#   seed=1337        -> set a specific seed
#   n_ctx=2048       -> increase the context window
WIllm = Llama(
    use_mlock=True,
    use_mmap=False,
    model_path="/models/final-GemmaWild7b-Q8_0.gguf",
)
COllm = Llama(
    use_mlock=True,
    use_mmap=False,
    model_path="/models/TunaCodes-Q8_0.gguf",
)
def extract_restext(response):
    """Return the whitespace-stripped text of the first choice in ``response``.

    ``response`` is indexed as ``response['choices'][0]['text']``; only the
    first entry of ``choices`` is read.
    """
    first_choice = response['choices'][0]
    completion_text = first_choice['text']
    return completion_text.strip()
def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
    """Ask ``llm`` a single question and return the stripped reply text.

    The question is wrapped in a ``###User:`` / ``###Assistant:`` prompt and
    the same two markers are passed as stop sequences. ``max_new_tokens`` is
    forwarded to the model call as ``max_tokens``; ``temperature`` is
    forwarded unchanged. The prompt is not echoed back (``echo=False``).
    """
    turn_markers = ["###User:", "###Assistant:"]
    prompt = f"###User: {question}\n###Assistant:"
    completion = llm(
        prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=turn_markers,
        echo=False,
    )
    return extract_restext(completion)
# TESTING THE MODEL
print("Testing model...")
# Smoke test: run one single-token generation so that a broken model load
# fails at startup (any exception raised by the call propagates).
# The call is deliberately NOT wrapped in `assert`:
#   * under `python -O` assert statements are stripped, which would skip the
#     generation entirely;
#   * with max_new_tokens=1 the stripped reply can legitimately be an empty
#     string, which would abort startup even though the model ran fine.
ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1)
print("Ready.")
# START OF FASTAPI APP
app = FastAPI(
    version="1.0.0",
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Open-ended and Coding questions.",
)

# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins are combined with allow_credentials=True;
# the FastAPI CORS documentation advises listing origins explicitly when
# credentials are allowed -- confirm this configuration is intended.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)
# API DATA CLASSES
class SAResponse(BaseModel):
    """Response body carrying a status code, a text and a result.

    NOTE(review): no route visible in this module returns this model.
    """

    # Status code reported in the response body; defaults to 200.
    code: int = 200
    # presumably the text the result refers to -- TODO confirm
    text: Optional[str] = None
    # The original annotation was `SA_Result`, a name that is neither defined
    # nor imported in this module, so evaluating the class body raised
    # NameError at import time. Typed as an optional string until the
    # intended type is restored -- TODO confirm the real result type.
    result: Optional[str] = None
class QuestionResponse(BaseModel):
    """Response body used by the question endpoints."""

    # Status code reported in the response body; defaults to 200.
    code: int = 200
    # The prompt that was asked, when available.
    question: Optional[str] = None
    # The model's reply, or an error message in error payloads. Declared
    # Optional because the default is None, matching the sibling fields.
    answer: Optional[str] = None
    # Generation settings used for the request (the routes store
    # "temperature" and "max_new_tokens" here).
    config: Optional[dict] = None
# API ROUTES
@app.get('/')
def docs():
    """Redirects the user from the main page to the docs."""
    # Relative target, so the redirect resolves against the current path.
    docs_redirect = responses.RedirectResponse('./docs')
    return docs_redirect
@app.post('/questions/open-ended')
async def ask_gemmaWild(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma an open-ended question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    # Body fields: `prompt` (required), `temperature` and `max_new_tokens`
    # (optional, forwarded to ask_llm). On success returns a QuestionResponse
    # holding the answer, the prompt and the generation settings.
    #
    # Errors are raised, not returned: a returned HTTPException is handled as
    # an ordinary return value instead of producing the intended 400/500
    # response. The detail payload is converted with jsonable_encoder because
    # a pydantic model is not directly JSON-serialisable.
    if not prompt:
        raise HTTPException(
            status_code=400,
            detail=jsonable_encoder(
                QuestionResponse(code=400, answer="Request argument 'prompt' not provided.")
            ),
        )
    try:
        print(f'Asking GemmaWild with the question "{prompt}"')
        result = ask_llm(WIllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
        print(f"Result: {result}")
        return QuestionResponse(
            answer=result,
            question=prompt,
            config={"temperature": temperature, "max_new_tokens": max_new_tokens},
        )
    except Exception as e:
        # Boundary handler: report any inference failure as a 500 whose
        # detail carries the error message and the original prompt.
        raise HTTPException(
            status_code=500,
            detail=jsonable_encoder(
                QuestionResponse(code=500, answer=str(e), question=prompt)
            ),
        ) from e
@app.post('/questions/coding')
async def ask_gemmaCode(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a coding question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    # Body fields: `prompt` (required), `temperature` and `max_new_tokens`
    # (optional, forwarded to ask_llm). On success returns a QuestionResponse
    # holding the answer, the prompt and the generation settings.
    #
    # Errors are raised, not returned: a returned HTTPException is handled as
    # an ordinary return value instead of producing the intended 400/500
    # response. The detail payload is converted with jsonable_encoder because
    # a pydantic model is not directly JSON-serialisable.
    if not prompt:
        raise HTTPException(
            status_code=400,
            detail=jsonable_encoder(
                QuestionResponse(code=400, answer="Request argument 'prompt' not provided.")
            ),
        )
    try:
        print(f'Asking GemmaCode with the question "{prompt}"')
        result = ask_llm(COllm, prompt, max_new_tokens=max_new_tokens, temperature=temperature)
        print(f"Result: {result}")
        return QuestionResponse(
            answer=result,
            question=prompt,
            config={"temperature": temperature, "max_new_tokens": max_new_tokens},
        )
    except Exception as e:
        # Boundary handler: report any inference failure as a 500 whose
        # detail carries the error message and the original prompt.
        raise HTTPException(
            status_code=500,
            detail=jsonable_encoder(
                QuestionResponse(code=500, answer=str(e), question=prompt)
            ),
        ) from e