"""Gradio demo: zero-shot text classification visualised as pie and bar charts."""

import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure
from PIL import Image
from transformers import pipeline

MODEL_NAME = "facebook/bart-large-mnli"


@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Load the zero-shot classification pipeline once and reuse it.

    Building the pipeline loads the model weights, which is far too slow to
    repeat on every request.
    """
    return pipeline("zero-shot-classification", model=MODEL_NAME)


def _parse_labels(labels):
    """Split a comma-separated string into stripped, non-empty, unique labels.

    Order of first appearance is preserved.
    """
    stripped = (label.strip() for label in labels.split(","))
    return list(dict.fromkeys(label for label in stripped if label))


def _figure_to_array(fig):
    """Render a Matplotlib figure to PNG and return it as a NumPy image array."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    with Image.open(buf) as img:
        # np.array forces the pixel data to load before the image is closed.
        return np.array(img)


def classify_and_plot(text, labels):
    """Classify *text* against user-supplied labels and chart the scores.

    Args:
        text: The text to classify.
        labels: Comma-separated candidate labels. Surrounding whitespace,
            empty entries and duplicates are ignored.

    Returns:
        A ``(pie_chart_array, bar_chart_array)`` tuple of NumPy image arrays
        decoded from rendered PNGs.

    Raises:
        gr.Error: If *text* is blank or no usable label is given. Gradio
            shows the message in the UI.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to classify.")
    labels_list = _parse_labels(labels or "")
    if not labels_list:
        raise gr.Error("Please enter at least one comma-separated label.")

    result = _get_classifier()(text, labels_list)
    ranked_labels = result["labels"]
    scores = result["scores"]

    # One colour per label, shared scale for the pie wedges.
    colors = plt.cm.viridis(np.linspace(0, 1, len(ranked_labels)))

    # Figure objects are used directly instead of pyplot so that concurrent
    # Gradio requests (run in worker threads) do not share pyplot's global
    # "current figure" state. Figures not registered with pyplot need no
    # explicit close.
    pie_fig = Figure()
    ax1 = pie_fig.subplots()
    wedges = ax1.pie(scores, startangle=140, colors=colors)[0]
    ax1.axis("equal")  # Equal aspect ratio keeps the pie circular.
    ax1.set_title("Pie Chart")
    legend_labels = [
        f"{label} - {score * 100:1.2f} %"
        for label, score in zip(ranked_labels, scores)
    ]
    ax1.legend(
        wedges,
        legend_labels,
        title="Labels with Scores",
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )
    pie_chart_array = _figure_to_array(pie_fig)

    bar_fig = Figure()
    ax2 = bar_fig.subplots()
    positions = np.arange(len(ranked_labels))
    ax2.bar(positions, scores, align="center", alpha=0.7, color="blue")
    ax2.set_xticks(positions)
    ax2.set_xticklabels(ranked_labels, rotation=45, ha="right")
    ax2.set_ylabel("Scores")
    ax2.set_title("Bar Chart")
    bar_chart_array = _figure_to_array(bar_fig)

    return pie_chart_array, bar_chart_array


# Create a Gradio interface
iface = gr.Interface(
    fn=classify_and_plot,
    inputs=["text", "text"],
    outputs=["image", "image"],
    title="Zero-Shot Classification with Pie and Bar Charts",
    description="Enter text and comma-separated labels for classification using the facebook/bart-large-mnli model. The outputs will be separate pie and bar charts representing the classification scores.",
)

# Launch the interface with the 'share' argument
iface.launch(share=True)