#!/usr/bin/env python
# coding: utf-8
"""HTTP API for sentiment analysis backed by a local GGUF model run through llama_cpp."""

from os import listdir
from os.path import isdir

from fastapi import FastAPI, HTTPException, Request, responses
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama

# The model is loaded once at import time and shared by every request.
print("Loading model...")
llm = Llama(
    model_path="/models/final-gemma2b_SA-Q5_K.gguf",
    # n_gpu_layers=28, # Uncomment to use GPU acceleration
    # seed=1337, # Uncomment to set a specific seed
    # n_ctx=2048, # Uncomment to increase the context window
)


def ask(question, max_new_tokens=200):
    """Run one completion of ``question`` on the loaded model.

    Args:
        question: Prompt text, passed to the model unchanged.
        max_new_tokens: Upper bound on generated tokens (forwarded as
            ``max_tokens``).

    Returns:
        The completion object returned by calling ``llm``.
    """
    output = llm(
        question,  # Prompt
        max_tokens=max_new_tokens,  # Cap on the number of generated tokens
        stop=["\n"],  # Stop generating at the first newline
        echo=False,  # Do not echo the prompt back in the output
        temperature=0.0,
    )
    return output


def check_sentiment(text):
    """Ask the model for the sentiment label of ``text``.

    Args:
        text: The text to classify; it is embedded in square brackets inside
            a fixed instruction prompt.

    Returns:
        The text of the first completion choice with surrounding whitespace
        stripped. Generation is capped at 3 tokens, so this is expected to be
        a short label such as "positive" or "negative".
    """
    prompt = (
        "Analyze the sentiment of the tweet enclosed in square brackets, "
        "determine if it is positive or negative, and return the answer as "
        'the corresponding sentiment label "positive" or "negative" '
        f"[{text}] ="
    )
    result = ask(prompt, max_new_tokens=3)
    return result['choices'][0]['text'].strip()


# Startup smoke test: a known-positive sentence must be labelled "positive".
# NOTE: `assert` statements are skipped when Python runs with -O.
print("Testing model...")
assert "positive" in check_sentiment("ดอกไม้ร้านนี้สวยจัง")
print("Ready.")

app = FastAPI(
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    version="1.0.0",
)

# NOTE(review): every origin, method and header is allowed, with credentials
# enabled -- confirm this is intended before exposing the service publicly.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get('/')
def docs():
    """Redirects the user from the main page to the docs."""
    return responses.RedirectResponse('./docs')


@app.get('/add/{a}/{b}')
def add(a: int, b: int):
    """Return the sum of the two integer path parameters."""
    return a + b


@app.get('/SA')
def perform_sentiment_analysis(request: Request):
    """Performs a sentiment analysis using a finetuned version of Gemma-2b.

    The text to classify is read from the ``prompt`` query parameter.

    Returns:
        ``{'success': True, 'result': <label>}`` when the model call succeeds.

    Raises:
        HTTPException: 400 if ``prompt`` is missing or empty; 500 if checking
            the sentiment raises any exception.
    """
    prompt = request.query_params.get('prompt')
    if not prompt:
        # Raised (not returned) so FastAPI sends a real 400 response.
        raise HTTPException(400, "Request argument 'prompt' not provided.")

    try:
        print(f"Checking sentiment for {prompt}")
        result = check_sentiment(prompt)
        print(f"Result: {result}")
    except Exception as e:
        # Top-level boundary: surface any model failure as a 500 response.
        raise HTTPException(500, str(e)) from e
    return {'success': True, 'result': result}