EC2 Default User
adding App files
6f69d1d
raw
history blame contribute delete
No virus
2.49 kB
from fastapi import FastAPI, Request, Response
import gradio
from transformers import pipeline
from sklearn.metrics import f1_score, precision_score, recall_score
import prometheus_client as prom
# ASGI application that serves both the /metrics endpoint and the Gradio UI.
app = FastAPI()

# Fine-tuned sentiment model, addressed on the model hub as "<user>/<repo>".
username = "yrajm1997"
repo_name = "finetuned-sentiment-model"
repo_path = f"{username}/{repo_name}"
sentiment_model = pipeline(model=repo_path)

# Labelled hold-out reviews; sampled by update_metrics() on each /metrics call.
# NOTE(review): path is relative to the process working directory.
import pandas as pd
test_data = pd.read_csv("test_reviews.csv")

# Quality gauges exported to Prometheus (registration order preserved).
f1_metric = prom.Gauge(
    'sentiment_f1_score',
    'F1 score for random 100 test samples',
)
precision_metric = prom.Gauge(
    'sentiment_precision_score',
    'Precision score for random 100 test samples',
)
recall_metric = prom.Gauge(
    'sentiment_recall_score',
    'Recall score for random 100 test samples',
)
def predict_sentiment(text):
    """Classify one review and return a human-readable sentiment.

    The top prediction from ``sentiment_model`` is inspected: a label whose
    last character is '0' maps to 'Negative'; every other label maps to
    'Positive'.
    """
    top_label = sentiment_model(text)[0]['label']
    return 'Negative' if top_label.endswith('0') else 'Positive'
def update_metrics(n_samples=100):
    """Recompute F1 / precision / recall on a random test sample and publish them.

    Draws up to ``n_samples`` rows from ``test_data`` (without replacement),
    scores them with ``sentiment_model`` in a single batch, and sets the three
    Prometheus gauges to the scores rounded to 3 decimal places.

    Args:
        n_samples: Maximum number of test rows to score. Defaults to 100, which
            matches the gauge help text. The value is capped at ``len(test_data)``
            so a small test file does not make ``DataFrame.sample`` raise.
    """
    sample_size = min(n_samples, len(test_data))
    if sample_size <= 0:
        # Nothing to score; leave the previously published gauge values as-is.
        return

    test = test_data.sample(sample_size)
    # assumes 'Text' holds review strings and 'labels' holds integer class ids
    # (0/1) — TODO confirm against test_reviews.csv
    true_labels = test['labels']
    test_pred = sentiment_model(list(test['Text'].values))
    # Pipeline labels look like "<prefix>_<class id>"; parse the trailing integer.
    # rsplit keeps this working if the prefix itself contains underscores.
    pred_labels = [int(pred['label'].rsplit("_", 1)[-1]) for pred in test_pred]

    # Builtin round() works whether sklearn returns a numpy scalar or a plain
    # Python float (the latter has no .round() method).
    f1_metric.set(round(f1_score(true_labels, pred_labels), 3))
    precision_metric.set(round(precision_score(true_labels, pred_labels), 3))
    recall_metric.set(round(recall_score(true_labels, pred_labels), 3))
@app.get("/metrics")
async def get_metrics():
    """Serve Prometheus metrics, refreshing the sentiment-quality gauges first.

    ``update_metrics`` runs batched model inference synchronously. Calling it
    directly inside this ``async`` handler would block the event loop, and with
    it every other request to the app, including the mounted Gradio UI. It is
    therefore offloaded to FastAPI's worker thread pool.

    Returns:
        The text exposition of the default Prometheus registry, with the
        client library's matching content type.
    """
    from fastapi.concurrency import run_in_threadpool

    await run_in_threadpool(update_metrics)
    return Response(
        media_type=prom.CONTENT_TYPE_LATEST,
        content=prom.generate_latest(),
    )
# Gradio UI: a multi-line review box goes in, a single sentiment label comes out.
in_prompt = gradio.components.Textbox(
    lines=10,
    placeholder=None,
    label='Enter review text',
)
out_response = gradio.components.Textbox(
    type="text",
    label='Sentiment',
)

title = "Sentiment Classification"
description = "Analyse sentiment of the given review"
iface = gradio.Interface(
    fn=predict_sentiment,
    inputs=[in_prompt],
    outputs=[out_response],
    title=title,
    description=description,
)

# Mount the UI at the root of the FastAPI app. Routes registered earlier on
# `app`, such as /metrics, are still served.
app = gradio.mount_gradio_app(app, iface, path="/")
# Standalone alternative (not used here): iface.launch(server_name="0.0.0.0",
# server_port=8001) — parameters: https://www.gradio.app/docs/interface
if __name__ == "__main__":
    # Debug-only entry point: serve the combined FastAPI + Gradio app with
    # uvicorn on all interfaces, port 8001.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8001, host="0.0.0.0")