# Uploaded by Babyloncoder
# Commit: "Create app.py" (f831e78, verified)
import gradio as gr
from transformers import pipeline
import matplotlib.pyplot as plt
import numpy as np
import io
from PIL import Image
# Zero-shot classification handler and the private helpers it relies on.
_MODEL_NAME = "facebook/bart-large-mnli"
_classifier = None  # Created on first use, then shared by every request.


def _get_classifier():
    """Return the shared zero-shot pipeline, loading the model on first use."""
    global _classifier
    if _classifier is None:
        # Building the pipeline loads the model, so do it once instead of
        # once per request.
        _classifier = pipeline("zero-shot-classification", model=_MODEL_NAME)
    return _classifier


def _parse_labels(labels):
    """Split a comma-separated string into stripped, non-empty, unique labels."""
    stripped = (part.strip() for part in (labels or "").split(","))
    # dict.fromkeys drops duplicates while keeping the order the user typed.
    return list(dict.fromkeys(part for part in stripped if part))


def _figure_to_array(fig):
    """Render *fig* to an in-memory PNG and return it as a NumPy array."""
    buffer = io.BytesIO()
    # Save through the figure itself rather than pyplot's "current figure",
    # so the result does not depend on global pyplot state.
    fig.savefig(buffer, format="png", bbox_inches="tight")
    buffer.seek(0)
    return np.array(Image.open(buffer))


def _render_pie_chart(labels, scores):
    """Draw the scores as a pie chart with a percentage legend."""
    fig, ax = plt.subplots()
    try:
        # One colour per label, spread evenly over the viridis colormap.
        colors = plt.cm.viridis(np.linspace(0, 1, len(labels)))
        wedges, *_ = ax.pie(scores, startangle=140, colors=colors)
        ax.axis("equal")  # Equal aspect ratio keeps the pie circular.
        ax.set_title("Pie Chart")
        legend_labels = [
            f"{label} - {score * 100:1.2f} %"
            for label, score in zip(labels, scores)
        ]
        ax.legend(
            wedges,
            legend_labels,
            title="Labels with Scores",
            loc="center left",
            bbox_to_anchor=(1, 0.5),
        )
        return _figure_to_array(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)


def _render_bar_chart(labels, scores):
    """Draw the scores as a bar chart with one bar per label."""
    fig, ax = plt.subplots()
    try:
        positions = np.arange(len(labels))
        ax.bar(positions, scores, align="center", alpha=0.7, color="blue")
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.set_ylabel("Scores")
        ax.set_title("Bar Chart")
        return _figure_to_array(fig)
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)


def classify_and_plot(text, labels):
    """Classify *text* against user-supplied labels and chart the scores.

    Args:
        text: The text to classify.
        labels: Candidate labels as a single comma-separated string.
            Surrounding whitespace, empty entries and duplicates are ignored.

    Returns:
        A ``(pie_chart, bar_chart)`` tuple of NumPy image arrays, each decoded
        from a PNG rendering of the corresponding chart.

    Raises:
        ValueError: If *text* is blank or *labels* contains no usable label.
    """
    # Validate before touching the model so bad input fails fast and clearly.
    if not text or not text.strip():
        raise ValueError("Please enter some text to classify.")
    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise ValueError("Please enter at least one label, separated by commas.")

    result = _get_classifier()(text, candidate_labels)
    # Use the pipeline's own label list so labels and scores stay aligned.
    result_labels = result["labels"]
    scores = result["scores"]

    pie_chart_array = _render_pie_chart(result_labels, scores)
    bar_chart_array = _render_bar_chart(result_labels, scores)
    return pie_chart_array, bar_chart_array
# Web UI wiring: two free-text inputs (the text and the comma-separated
# labels) feed classify_and_plot, and its two return values are shown as
# images.
_APP_TITLE = "Zero-Shot Classification with Pie and Bar Charts"
_APP_DESCRIPTION = (
    "Enter text and comma-separated labels for classification using the "
    "facebook/bart-large-mnli model. The outputs will be separate pie and "
    "bar charts representing the classification scores."
)

iface = gr.Interface(
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
    fn=classify_and_plot,
    inputs=["text", "text"],
    outputs=["image", "image"],
)

# Start the app, passing share=True to launch().
iface.launch(share=True)