# Hugging Face Space status when this file was captured: Runtime error
import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline
# Zero-shot classification of user text against user-supplied labels, with the
# resulting scores rendered as a pie chart and a bar chart.

_MODEL_NAME = "facebook/bart-large-mnli"


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Build the zero-shot pipeline once and reuse it for every request.

    Constructing the pipeline loads the full model, so doing it per call
    (as a naive implementation would) makes every request pay that cost.
    """
    return pipeline("zero-shot-classification", model=_MODEL_NAME)


def _parse_labels(labels):
    """Split a comma-separated string into whitespace-trimmed, non-empty labels."""
    return [label.strip() for label in labels.split(",") if label.strip()]


def _figure_to_array(fig):
    """Render *fig* as a PNG and return the decoded image as a NumPy array."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    # The context manager closes the PIL image; np.array forces the pixel
    # data to be decoded before that happens.
    with Image.open(buf) as img:
        return np.array(img)


def classify_and_plot(text, labels):
    """Classify *text* against comma-separated *labels* and chart the scores.

    Args:
        text: The text to classify.
        labels: Comma-separated candidate labels. Surrounding whitespace is
            trimmed and empty entries (e.g. from a trailing comma) are ignored.

    Returns:
        A ``(pie_chart, bar_chart)`` tuple. Each element is a NumPy array
        holding the pixels of a rendered PNG image.

    Raises:
        ValueError: If *text* is blank or *labels* contains no non-empty label.
    """
    if not text or not text.strip():
        raise ValueError("Please enter some text to classify.")
    labels_list = _parse_labels(labels or "")
    if not labels_list:
        raise ValueError("Please enter at least one comma-separated label.")

    result = _get_classifier()(text, labels_list)
    result_labels = result["labels"]
    scores = result["scores"]

    # One distinct colour per label for the pie chart.
    colors = plt.cm.viridis(np.linspace(0, 1, len(result_labels)))

    # Figures are created directly rather than through pyplot. This avoids
    # pyplot's global "current figure" state, which concurrent requests would
    # otherwise share, and never starts a GUI backend.
    pie_fig = Figure()
    pie_ax = pie_fig.subplots()
    wedges, _ = pie_ax.pie(scores, startangle=140, colors=colors)
    pie_ax.axis("equal")  # Equal aspect ratio keeps the pie circular.
    pie_ax.set_title("Pie Chart")
    legend_labels = [
        f"{label} - {score * 100:1.2f} %"
        for label, score in zip(result_labels, scores)
    ]
    pie_ax.legend(
        wedges,
        legend_labels,
        title="Labels with Scores",
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )
    pie_chart_array = _figure_to_array(pie_fig)

    bar_fig = Figure()
    bar_ax = bar_fig.subplots()
    positions = np.arange(len(result_labels))
    bar_ax.bar(positions, scores, align="center", alpha=0.7, color="blue")
    bar_ax.set_xticks(positions)
    bar_ax.set_xticklabels(result_labels, rotation=45, ha="right")
    bar_ax.set_ylabel("Scores")
    bar_ax.set_title("Bar Chart")
    bar_chart_array = _figure_to_array(bar_fig)

    return pie_chart_array, bar_chart_array
# Gradio UI: two text inputs (the text, then the comma-separated labels) are
# passed positionally to classify_and_plot. Its two returned image arrays are
# shown as the pie and bar chart outputs.
iface = gr.Interface(
    fn=classify_and_plot,
    inputs=["text", "text"],
    outputs=["image", "image"],
    title="Zero-Shot Classification with Pie and Bar Charts",
    description="Enter text and comma-separated labels for classification using the facebook/bart-large-mnli model. The outputs will be separate pie and bar charts representing the classification scores."
)

# Only start the server when run as a script. Importing this module (e.g. to
# reuse or test classify_and_plot) must not block on a running web server.
if __name__ == "__main__":
    # share=True asks Gradio for a public share link. NOTE(review): hosted
    # environments such as Hugging Face Spaces may ignore this — confirm.
    iface.launch(share=True)