# File size: 2,525 Bytes
# f831e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import functools
import io

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from transformers import pipeline

@functools.lru_cache(maxsize=1)
def _get_classifier():
    """Return the zero-shot classification pipeline, building it only once.

    Constructing the pipeline loads the facebook/bart-large-mnli model, which is
    far too expensive to repeat on every request, so the instance is cached.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


def _parse_labels(labels):
    """Split a comma-separated string into stripped, non-empty candidate labels."""
    return [label.strip() for label in labels.split(",") if label.strip()]


def _figure_to_array(fig):
    """Render *fig* as PNG, close it, and return the image as a NumPy array."""
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Close this specific figure (not pyplot's "current" one), even on error.
        plt.close(fig)
    buf.seek(0)
    with Image.open(buf) as image:
        return np.array(image)


def classify_and_plot(text, labels):
    """Classify *text* against user-supplied labels and chart the scores.

    Args:
        text: The text to classify.
        labels: Comma-separated candidate labels, e.g. ``"sports, politics"``.
            Whitespace around each label is stripped and empty entries
            (e.g. from a trailing comma) are ignored.

    Returns:
        A ``(pie_chart, bar_chart)`` tuple, each an image as a NumPy array,
        suitable for Gradio image outputs.

    Raises:
        ValueError: If *labels* contains no non-empty label.
    """
    candidate_labels = _parse_labels(labels)
    if not candidate_labels:
        raise ValueError("Please enter at least one comma-separated label.")

    result = _get_classifier()(text, candidate_labels)
    ranked_labels = result["labels"]
    scores = result["scores"]

    # One distinct colour per label for the pie wedges.
    colors = plt.cm.viridis(np.linspace(0, 1, len(ranked_labels)))

    # Pie chart; percentages go in a legend beside the pie rather than on it.
    pie_fig, pie_ax = plt.subplots()
    wedges, _ = pie_ax.pie(scores, startangle=140, colors=colors)
    pie_ax.axis("equal")  # Equal aspect ratio keeps the pie circular.
    pie_ax.set_title("Pie Chart")
    legend_labels = [
        f"{label} - {score * 100:1.2f} %"
        for label, score in zip(ranked_labels, scores)
    ]
    pie_ax.legend(
        wedges,
        legend_labels,
        title="Labels with Scores",
        loc="center left",
        bbox_to_anchor=(1, 0.5),
    )
    pie_chart_array = _figure_to_array(pie_fig)

    # Bar chart with one bar per label, in the order the pipeline returned them.
    bar_fig, bar_ax = plt.subplots()
    x_pos = np.arange(len(ranked_labels))
    bar_ax.bar(x_pos, scores, align="center", alpha=0.7, color="blue")
    bar_ax.set_xticks(x_pos)
    bar_ax.set_xticklabels(ranked_labels, rotation=45, ha="right")
    bar_ax.set_ylabel("Scores")
    bar_ax.set_title("Bar Chart")
    bar_chart_array = _figure_to_array(bar_fig)

    return pie_chart_array, bar_chart_array

# Gradio UI: a text box and a comma-separated label box in, two chart images out.
iface = gr.Interface(
    fn=classify_and_plot,
    inputs=[
        gr.Textbox(label="Text"),
        gr.Textbox(label="Labels (comma-separated)"),
    ],
    outputs=[
        gr.Image(label="Pie Chart"),
        gr.Image(label="Bar Chart"),
    ],
    title="Zero-Shot Classification with Pie and Bar Charts",
    description="Enter text and comma-separated labels for classification using the facebook/bart-large-mnli model. The outputs will be separate pie and bar charts representing the classification scores."
)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public share link in addition to
    # the local server (see the Gradio `launch` documentation).
    iface.launch(share=True)