Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# coding: utf-8 | |
from os import listdir | |
from os.path import isdir | |
from fastapi import FastAPI, HTTPException, Request, responses | |
from fastapi.middleware.cors import CORSMiddleware | |
from llama_cpp import Llama | |
# Load the quantised GGUF model once at import time; the resulting `llm`
# object is shared by every call to ask().
print("Loading model...")
llm = Llama(
    model_path="/models/final-gemma2b_SA-Q5_K.gguf",
    # Optional settings, currently disabled (library defaults apply):
    #   n_gpu_layers=28  -> use GPU acceleration
    #   seed=1337        -> set a specific seed
    #   n_ctx=2048       -> increase the context window
)
def ask(question, max_new_tokens=200):
    """Run a single completion of ``question`` on the module-level ``llm``.

    Args:
        question: Prompt text, handed to the model unchanged.
        max_new_tokens: Upper bound on generated tokens; forwarded as
            ``max_tokens`` (defaults to 200).

    Returns:
        Whatever the ``llm`` call returns, unmodified.
    """
    generation_options = {
        "max_tokens": max_new_tokens,
        "stop": ["\n"],       # cut the completion at the first newline
        "echo": False,        # the prompt is not echoed back in the output
        "temperature": 0.0,
    }
    return llm(question, **generation_options)
def check_sentiment(text):
    """Ask the model for the sentiment label of ``text``.

    The text is placed in square brackets inside a fixed instruction prompt,
    at most 3 new tokens are requested from ``ask``, and the text of the
    first returned choice is handed back with surrounding whitespace removed.

    Args:
        text: The tweet/text to classify.

    Returns:
        The stripped completion text of the first choice.
    """
    prompt = f'Analyze the sentiment of the tweet enclosed in square brackets, determine if it is positive or negative, and return the answer as the corresponding sentiment label "positive" or "negative" [{text}] ='
    completion = ask(prompt, max_new_tokens=3)
    first_choice = completion['choices'][0]
    return first_choice['text'].strip()
# Start-up smoke test: classify one known-positive Thai sentence and abort
# the import if the answer does not contain "positive".
print("Testing model...")
smoke_label = check_sentiment("ดอกไม้ร้านนี้สวยจัง")
assert "positive" in smoke_label
print("Ready.")
# The FastAPI application object; the metadata below is passed straight to
# the FastAPI constructor.
app = FastAPI(
    title="GemmaSA_2b",
    description="A simple sentiment analysis API for the Thai language, powered by a finetuned version of Gemma-2b",
    version="1.0.0",
)

# CORS: the API is open to every origin, method and header.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # Wildcard origins must not be combined with credentialed requests
    # (cookies / Authorization headers): doing so would let any website
    # issue credentialed calls. Nothing in this file reads cookies or auth
    # headers, so credentials are switched off.
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def docs():
    """Redirects the user from the main page to the docs."""
    docs_location = './docs'
    return responses.RedirectResponse(docs_location)
def add(a: int, b: int):
    """Return the sum of ``a`` and ``b``."""
    total = a + b
    return total
def perform_sentiment_analysis(request: Request):
    """Performs a sentiment analysis using a finetuned version of Gemma-2b.

    The text to classify is read from the ``prompt`` query parameter and
    passed to ``check_sentiment``.

    Args:
        request: Incoming request; only its query parameters are read.

    Returns:
        ``{'success': True, 'result': <label>}`` where ``<label>`` is the
        string returned by ``check_sentiment``.

    Raises:
        HTTPException: status 400 when ``prompt`` is missing or empty,
            status 500 when the sentiment check itself raises.
    """
    prompt = request.query_params.get('prompt')
    if not prompt:
        # Must be raised, not returned: a returned HTTPException is treated
        # as an ordinary response value and never sets the error status.
        raise HTTPException(400, "Request argument 'prompt' not provided.")

    try:
        print(f"Checking sentiment for {prompt}")
        result = check_sentiment(prompt)
        print(f"Result: {result}")
    except Exception as e:
        # Top-level request boundary: surface any model failure as a 500
        # while keeping the original exception chained for debugging.
        raise HTTPException(500, str(e)) from e
    return {'success': True, 'result': result}