#!/usr/bin/env python
# coding: utf-8
"""Gradio demo that predicts a department label for a German or French text.

The script loads a sequence-classification model from ``saved_model``,
exposes :func:`attribution` through a Gradio interface and launches it.
French input is translated to German before classification.
"""

from functools import lru_cache

import gradio as gr
import imageio
import numpy as np
import requests
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from matplotlib import pyplot as plt
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
    pipeline,
)

# Only this many leading characters are translated / classified.
MAX_CHARS = 1000
# Below this top score the model is reported as "not sure".
CONFIDENT_RATE = 0.6
# In the "not sure" case, only labels with at least this score are listed.
LISTED_RATE = 0.1

UNSUPPORTED_LANGUAGE_MESSAGE = (
    "The language is not recognized, it must be either in German or in French."
)

# Bar labels, one per model output, in the order the pipeline returns scores.
BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC",
           "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK",
           "EDA", "Parl", "BK", "EFD", "BV", "BGer")

# Load the model
model = AutoModelForSequenceClassification.from_pretrained("saved_model")
tokenizer = AutoTokenizer.from_pretrained("saved_model")
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)


@lru_cache(maxsize=1)
def _get_translator():
    """Build the French-to-German translation pipeline once, on first use."""
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def _percent(rate):
    """Format a score in [0, 1] as a whole-number percentage, truncated.

    Slicing ``str(rate)`` would turn 0.6 into "6%" and 1.0 into "0%"; this
    converts numerically instead. Rounding to 6 decimals first removes
    float noise such as ``0.29 * 100 == 28.999999999999996``.
    """
    return str(int(round(float(rate) * 100, 6))) + "%"


def _summarise(rates, bars, language):
    """Return the text shown next to the chart for the given scores.

    If the top score reaches ``CONFIDENT_RATE`` only that label is shown;
    otherwise every label scoring at least ``LISTED_RATE`` is listed in
    descending order, under a German or French header.
    """
    # Stable sort: equal scores keep their original (lowest-index-first) order.
    ranked = sorted(range(len(rates)), key=lambda idx: rates[idx], reverse=True)
    best = ranked[0]

    if rates[best] >= CONFIDENT_RATE:
        return _percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]

    if language == 'de':
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
    else:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"

    listed = [
        "\t" + _percent(rates[idx]) + "\t\t\t\t\t" + bars[idx] + "\n"
        for idx in ranked
        if rates[idx] >= LISTED_RATE
    ]
    return header + "".join(listed)


# Function called by the UI
def attribution(text):
    """Classify ``text`` and return ``(summary, chart_image)``.

    Args:
        text: The user input. Only the first ``MAX_CHARS`` characters are
            translated and classified.

    Returns:
        A ``(str, image)`` pair: the summary text and the bar chart read
        back from ``rates.png``. If the language cannot be detected or is
        neither German nor French, returns an error message and ``None``.
    """
    # Detect the language; langdetect raises on empty / feature-less input.
    try:
        language = detect(text)
    except LangDetectException:
        return UNSUPPORTED_LANGUAGE_MESSAGE, None

    # Translate the input to German if necessary
    if language == 'fr':
        translated = _get_translator()(text[0:MAX_CHARS])
        text = translated[0]["translation_text"]
    elif language != 'de':
        return UNSUPPORTED_LANGUAGE_MESSAGE, None

    # Labels follow the language of the original input, not the translation.
    bars = BARS_FR if language == 'fr' else BARS_DE

    # Make the prediction with the first MAX_CHARS characters.
    # return_all_scores=True keeps the scores in the model's label order,
    # which the fixed `bars` tuples rely on.
    results = pipe(text[0:MAX_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]

    # Bar chart
    # NOTE(review): pyplot's current figure and 'rates.png' are shared state;
    # if the UI serves requests concurrently they could interfere - confirm.
    plt.clf()
    y_pos = np.arange(len(bars))
    plt.barh(y_pos, rates)
    plt.yticks(y_pos, bars)

    name = _summarise(rates, bars, language)

    # Save the bar chart as png and load it (enables better display)
    plt.savefig('rates.png')
    im = imageio.imread('rates.png')

    return name, im


# display the UI
interface = gr.Interface(
    fn=attribution,
    inputs=[gr.inputs.Textbox(lines=20, placeholder="Geben Sie bitte den Titel und den Sumbmitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.")],
    outputs=['text', 'image'],
)
interface.launch()