PawinC's picture
Upload 6 files
4926347 verified
raw
history blame
2.43 kB
#!/usr/bin/env python
# coding: utf-8
from os import listdir
from os.path import isdir
from fastapi import FastAPI, HTTPException, Request, responses
from fastapi.middleware.cors import CORSMiddleware
from llama_cpp import Llama
print("Loading model...")
# Load the fine-tuned Gemma-2b sentiment-analysis model from a local GGUF file
# (Q5_K quantization, judging by the filename).
# Optional llama.cpp settings that can be passed to Llama(...):
#   n_gpu_layers=28  -> offload layers to the GPU for acceleration
#   seed=1337        -> use a fixed RNG seed
#   n_ctx=2048       -> enlarge the context window
llm = Llama(model_path="/models/final-gemma2b_SA-Q5_K.gguf")
def ask(question, max_new_tokens=200):
    """Run a raw text completion of *question* through the loaded model.

    Args:
        question: Prompt text, passed to the model verbatim.
        max_new_tokens: Maximum number of tokens to generate (default 200).
            Pass ``None`` to generate up to the end of the context window.

    Returns:
        The completion object returned by ``llm(...)``; ``check_sentiment``
        reads the generated text from ``output['choices'][0]['text']``.
    """
    output = llm(
        question,
        max_tokens=max_new_tokens,
        stop=["\n"],      # stop just before the model would emit a newline
        echo=False,       # return only the generated text, not the prompt
        temperature=0.0,  # always take the most likely token -> repeatable labels
    )
    return output
def check_sentiment(text):
    """Ask the model for the sentiment of *text* and return its answer.

    The prompt requests a "positive" or "negative" label; the model's
    completion (at most 3 tokens) is returned stripped of surrounding
    whitespace and is not otherwise validated here.
    """
    prompt = (
        'Analyze the sentiment of the tweet enclosed in square brackets, '
        'determine if it is positive or negative, and return the answer as '
        'the corresponding sentiment label "positive" or "negative" '
        f'[{text}] ='
    )
    completion = ask(prompt, max_new_tokens=3)
    first_choice = completion['choices'][0]
    return first_choice['text'].strip()
print("Testing model...")
# Startup smoke test on a clearly positive Thai sentence
# ("the flowers at this shop are so pretty"). An explicit check is used
# instead of `assert`, because asserts are removed when Python runs with -O
# and the self-test would then be skipped silently.
_smoke_label = check_sentiment("ดอกไม้ร้านนี้สวยจัง")
if "positive" not in _smoke_label:
    raise RuntimeError(
        f"Model self-test failed: expected a 'positive' label, got {_smoke_label!r}"
    )
print("Ready.")
# The FastAPI application; this metadata is shown on the generated docs page.
app = FastAPI(
    version="1.0.0",
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
)
# Accept cross-origin requests from any site, with any method and any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# Starlette echoes the caller's Origin for credentialed requests, so any site
# could make them. Nothing visible here uses cookies or auth, so credentials
# can probably be turned off; confirm with API clients before changing.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.get('/')
def docs():
    """Send visitors of the root URL to the interactive API documentation."""
    docs_url = './docs'
    return responses.RedirectResponse(docs_url)
@app.get('/add/{a}/{b}')
def add(a: int, b: int):
    """Return the sum of the two integer path parameters."""
    total = a + b
    return total
@app.get('/SA')
def perform_sentiment_analysis(request: Request):
    """Perform sentiment analysis using a finetuned version of Gemma-2b.

    Reads the text to classify from the ``prompt`` query parameter.

    Returns:
        ``{'success': True, 'result': <label>}`` on success.

    Raises:
        HTTPException(400): if ``prompt`` is missing or empty.
        HTTPException(500): if the model call fails; the detail is the
            underlying error message.
    """
    prompt = request.query_params.get('prompt')
    if not prompt:
        # Must be *raised*: a returned HTTPException is not turned into an
        # error response by FastAPI.
        raise HTTPException(400, "Request argument 'prompt' not provided.")

    print(f"Checking sentiment for {prompt}")
    try:
        result = check_sentiment(prompt)
    except Exception as e:
        # Boundary handler: report any model failure to the client as a 500.
        raise HTTPException(500, str(e)) from e
    print(f"Result: {result}")
    return {'success': True, 'result': result}