"""Sentiment-classification service.

Serves a Gradio UI (mounted at ``/``) backed by a Hugging Face
text-classification pipeline, plus a Prometheus ``/metrics`` endpoint that
re-scores the model on a random sample of ``test_reviews.csv`` on every scrape.
"""
from fastapi import FastAPI, Request, Response
import gradio
from transformers import pipeline
from sklearn.metrics import f1_score, precision_score, recall_score
import prometheus_client as prom
import pandas as pd

app = FastAPI()

username = "yrajm1997"
repo_name = "finetuned-sentiment-model"
repo_path = username + '/' + repo_name
# Loaded once at import time and shared by the UI and the metrics endpoint.
sentiment_model = pipeline(model=repo_path)

# Held-out reviews used for the live quality metrics. The code below reads the
# 'Text' and 'labels' columns (labels assumed to be 0/1 — TODO confirm).
test_data = pd.read_csv("test_reviews.csv")

# Rows scored per /metrics request (fewer if the CSV is smaller).
METRIC_SAMPLE_SIZE = 100

f1_metric = prom.Gauge('sentiment_f1_score', 'F1 score for random 100 test samples')
precision_metric = prom.Gauge('sentiment_precision_score', 'Precision score for random 100 test samples')
recall_metric = prom.Gauge('sentiment_recall_score', 'Recall score for random 100 test samples')


def _label_index(label):
    """Return the class index encoded in a pipeline label such as ``'LABEL_1'``."""
    return int(label.rsplit("_", 1)[-1])


# Function for response generation
def predict_sentiment(text):
    """Classify *text* and return ``'Negative'`` (class 0) or ``'Positive'`` (any other class)."""
    # truncation=True stops reviews longer than the model's maximum sequence
    # length from raising inside the pipeline.
    result = sentiment_model(text, truncation=True)
    return 'Negative' if _label_index(result[0]['label']) == 0 else 'Positive'


# Function for updating metrics
def update_metrics():
    """Score the model on a random sample of ``test_data`` and update the gauges.

    Samples up to ``METRIC_SAMPLE_SIZE`` rows. If ``test_data`` is empty the
    gauges are left unchanged.
    """
    n_samples = min(METRIC_SAMPLE_SIZE, len(test_data))
    if n_samples == 0:
        return
    test = test_data.sample(n_samples)
    test_pred = sentiment_model(test['Text'].tolist(), truncation=True)
    pred_labels = [_label_index(pred['label']) for pred in test_pred]
    true_labels = test['labels']

    # zero_division=0 yields the same value as sklearn's default ("warn") but
    # avoids an UndefinedMetricWarning when a sample has no positives.
    # Builtin round(): recent scikit-learn returns plain floats (no .round()).
    f1 = round(float(f1_score(true_labels, pred_labels, zero_division=0)), 3)
    precision = round(float(precision_score(true_labels, pred_labels, zero_division=0)), 3)
    recall = round(float(recall_score(true_labels, pred_labels, zero_division=0)), 3)

    f1_metric.set(f1)
    precision_metric.set(precision)
    recall_metric.set(recall)


@app.get("/metrics")
def get_metrics():
    """Refresh the quality gauges, then return all metrics in Prometheus text format."""
    # Plain ``def`` rather than ``async def``: FastAPI runs sync handlers in a
    # threadpool, so the blocking model inference does not stall the event loop.
    update_metrics()
    return Response(media_type=prom.CONTENT_TYPE_LATEST, content=prom.generate_latest())


# Input from user
in_prompt = gradio.components.Textbox(lines=10, placeholder=None, label='Enter review text')

# Output response
out_response = gradio.components.Textbox(type="text", label='Sentiment')

# Gradio interface to generate UI link
title = "Sentiment Classification"
description = "Analyse sentiment of the given review"

iface = gradio.Interface(
    fn=predict_sentiment,
    inputs=[in_prompt],
    outputs=[out_response],
    title=title,
    description=description,
)

# Mounted after the /metrics route is registered so that route still takes
# precedence over the Gradio app at "/".
# Ref. for parameters: https://www.gradio.app/docs/interface
app = gradio.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    # Use this for debugging purposes only
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8001)