#!/usr/bin/env python
# coding: utf-8
"""Gemma Finetuned API.

Serves two llama.cpp models behind FastAPI: one finetuned for open-ended
questions (GemmaWild) and one for coding questions (GemmaCode).
"""

from os import listdir
from os.path import isdir
from enum import Enum
from typing import Optional

from fastapi import FastAPI, HTTPException, Request, responses, Body
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
from pydantic import BaseModel

# MODEL LOADING, FUNCTIONS, AND TESTING
print("Loading model...")
WIllm = Llama(model_path="/models/final-GemmaWild7b-Q8_0.gguf")
COllm = Llama(model_path="/models/TunaCodes-Q8_0.gguf")
# Optional llama.cpp knobs (add to the constructors above as needed):
#   use_mmap=False, use_mlock=True,
#   n_gpu_layers=28,  # GPU acceleration
#   seed=1337,        # deterministic sampling
#   n_ctx=2048,       # larger context window


def extract_restext(response):
    """Return the stripped text of the first completion choice of a llama.cpp result dict."""
    return response['choices'][0]['text'].strip()


def ask_llm(llm, question, max_new_tokens=200, temperature=0.5):
    """Run `question` through `llm` using the ###User/###Assistant template.

    Stops generation at the next turn marker and returns the stripped
    completion text.
    """
    prompt = f"""###User: {question}\n###Assistant:"""
    result = extract_restext(llm(
        prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        stop=["###User:", "###Assistant:"],
        echo=False,
    ))
    return result


# Startup smoke test: confirm the model can emit at least one token.
# NOTE: was `assert ...`, which is silently stripped under `python -O`;
# an explicit raise keeps the check in optimized runs.
print("Testing model...")
if not ask_llm(WIllm, "Hello!, How are you today?", max_new_tokens=1):
    raise RuntimeError("Model smoke test failed: empty response from WIllm.")
print("Ready.")


# START OF FASTAPI APP
app = FastAPI(
    title="Gemma Finetuned API",
    description="Gemma Finetuned API for Open-ended and Coding questions.",
    version="1.0.0",
)

origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# API DATA CLASSES
class QuestionResponse(BaseModel):
    """Standard response envelope for the question endpoints."""
    code: int = 200                   # HTTP-style status echoed in the body
    question: Optional[str] = None    # prompt as received
    answer: Optional[str] = None      # fix: was `str = None`, invalid in pydantic v2
    config: Optional[dict] = None     # generation parameters actually used


def _answer_with(llm, model_name: str, prompt: str,
                 temperature: float, max_new_tokens: int) -> QuestionResponse:
    """Shared body of both question endpoints: validate, run inference, wrap.

    Raises HTTPException (400 for a missing prompt, 500 on inference failure).
    The original code *returned* the HTTPException, which FastAPI would have
    rejected against the `QuestionResponse` response model instead of sending
    the intended error status to the client.
    """
    if not prompt:
        raise HTTPException(400, "Request argument 'prompt' not provided.")
    try:
        print(f'Asking {model_name} with the question "{prompt}"')
        result = ask_llm(llm, prompt,
                         max_new_tokens=max_new_tokens, temperature=temperature)
        print(f"Result: {result}")
        return QuestionResponse(
            answer=result,
            question=prompt,
            config={"temperature": temperature, "max_new_tokens": max_new_tokens},
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, str(e))


# API ROUTES
@app.get('/')
def docs():
    "Redirects the user from the main page to the docs."
    return responses.RedirectResponse('./docs')


@app.post('/questions/open-ended')
async def ask_gemmaWild(
    prompt: str = Body(..., embed=True, example="Why is ice cream so delicious?"),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma an open-ended question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    return _answer_with(WIllm, "GemmaWild", prompt, temperature, max_new_tokens)


@app.post('/questions/coding')
async def ask_gemmaCode(
    prompt: str = Body(..., embed=True, example="Write a Python function that reverses a string."),
    temperature: float = Body(0.5, embed=True),
    max_new_tokens: int = Body(200, embed=True)
) -> QuestionResponse:
    """
    Ask a finetuned Gemma a coding question.
    NOTICE: IT MAY PRODUCE RANDOM/INACCURATE ANSWERS. PLEASE SEEK PROFESSIONAL ADVICE BEFORE DOING ANYTHING SERIOUS.
    """
    return _answer_with(COllm, "GemmaCode", prompt, temperature, max_new_tokens)