import functools

import gradio as gr
import imageio
import numpy as np
# import requests
from langdetect import detect
from matplotlib import pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
def greet(name):
    """Return a greeting for *name*, e.g. ``greet("Ada") == "Hello Ada!!"``."""
    return "".join(("Hello ", name, "!!"))
# NOTE(review): this demo interface around `greet` is never launched or referenced
# in the visible code; it looks like a leftover from a template.
iface = gr.Interface(fn=greet, inputs="text", outputs="text")

# Local checkpoint directory holding the sequence-classification model and its tokenizer.
_CHECKPOINT_DIR = "saved_model"
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_DIR)
# Classification pipeline that `attribution` uses to score every department.
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)
# Function called by the UI
def attribution(text):
# Clean the plot
# Detect the language
language = detect(text)
# Translate the input in german if necessary
if language == 'fr':
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
translatedText = translator(text)
text = translatedText[0]["translation_text"]
# Set the bars of the bar chart
bars = ""
if language == 'fr':
bars = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
bars = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")
# Make the prediction with the 512 first characters
results = pipe(text[0:511], return_all_scores=True)
rates = [row["score"] for row in results[0]]
# Bar chart
y_pos = np.arange(len(bars))
plt.barh(y_pos, rates)
plt.yticks(y_pos, bars)
# Set the output text
name = ""
maxRate = np.max(rates)
maxIndex = np.argmax(rates)
# ML model not sure if highest probability < 60%
if maxRate < 0.6:
# de / fr
if language == 'de':
name = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
name = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
i = 0
# Show each department that has a probability > 10%
while i == 0:
if rates[maxIndex] >= 0.1:
name = name + "\t" + str(rates[maxIndex])[2:4] + "%" + "\t\t\t\t\t" + bars[maxIndex] + "\n"
rates[maxIndex] = 0
maxIndex = np.argmax(rates)
i = 1
# ML model pretty sure, show only one department
name = str(maxRate)[2:4] + "%" + "\t\t\t\t\t\t" + bars[maxIndex]
# Save the bar chart as png and load it (enables better display)
im = imageio.imread('rates.png')
return name, im
# Build the UI: one multi-line text box in; the predicted department(s) and the
# bar chart out. `layout=` and `gr.inputs` are Gradio 2.x APIs, matching the rest
# of this call.
# NOTE(review): the `lines` value is a reconstruction; the original input line is
# missing. The interface must still be launched (interface.launch()) — not visible here.
interface = gr.Interface(
    fn=attribution,
    layout="vertical",
    inputs=[
        gr.inputs.Textbox(
            lines=10,
            placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstosses ein.\nVeuillez entrer le titre et le Submitted Text de la requête.",
        )
    ],
    outputs=["text", "image"],
)