File size: 3,049 Bytes
5c49760
 
 
2f4692b
8ec661c
5c49760
8ec661c
 
 
 
2f4692b
8ec661c
5c49760
8ec661c
 
 
 
 
5c49760
8ec661c
 
5c49760
8ec661c
 
5c49760
8ec661c
 
 
5c49760
8ec661c
5c49760
 
 
8ec661c
 
 
 
 
 
 
5c49760
 
8ec661c
5c49760
8ec661c
 
 
 
5c49760
8ec661c
5c49760
8ec661c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c49760
 
8ec661c
 
 
 
 
5c49760
8ec661c
 
97183a3
5c49760
 
8ec661c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/env python
# coding: utf-8

import gradio as gr
import numpy as np
import requests
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
from langdetect import detect
from matplotlib import pyplot as plt
import imageio

# Load the tokenizer and the sequence-classification model from "saved_model"
# (presumably a local directory next to this script -- confirm), then wrap
# both in a single text-classification pipeline used by the UI callback.
_MODEL_DIR = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_DIR)
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)

# Shared French -> German translation pipeline. It is created lazily on the
# first French request and then reused, so the translation model is not
# reloaded on every call.
_fr_de_translator = None


def _get_fr_de_translator():
    """Return the shared French-to-German translation pipeline, creating it on first use."""
    global _fr_de_translator
    if _fr_de_translator is None:
        _fr_de_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
    return _fr_de_translator


def _format_percent(rate):
    """Format a probability in [0, 1] as a whole-number percentage (0.5 -> "50%")."""
    # Slicing str(rate) breaks on short reprs ("0.5" -> "5") and on
    # scientific notation, so use real percent formatting instead.
    return "{:.0%}".format(float(rate))


# Function called by the UI
def attribution(text):
    """Classify a German or French text and report the most likely labels.

    The language of ``text`` is detected first. French input is translated to
    German (only its first 1000 characters) before classification; any
    language other than German or French is rejected. The classifier is run
    on the first 1000 characters of the (possibly translated) text.

    Args:
        text: The text entered in the UI.

    Returns:
        A ``(message, image)`` tuple. ``message`` is either the single best
        label with its score (score >= 60%), or a list of every label scoring
        at least 10% (best score < 60%), written in the detected language.
        ``image`` is the bar chart of all scores, read back from
        ``rates.png``. If the language cannot be detected or is neither
        German nor French, ``message`` is an error text and ``image`` is
        ``None``.
    """
    # Local import: only the exception type is needed, and only here.
    from langdetect.lang_detect_exception import LangDetectException

    unsupported = "The language is not recognized, it must be either in German or in French."

    # Empty / whitespace-only input has no language to detect.
    if not text or not text.strip():
        return unsupported, None

    # Detect the language; detect() raises when the text has no usable features.
    try:
        language = detect(text)
    except LangDetectException:
        return unsupported, None

    # Translate the input to German if necessary
    if language == 'fr':
        translated = _get_fr_de_translator()(text[0:1000])
        text = translated[0]["translation_text"]
    elif language != 'de':
        return unsupported, None

    # Labels of the bars, in the language of the original input.
    # NOTE(review): bars[i] is assumed to match the model's class index i --
    # confirm against the label mapping used for training.
    if language == 'fr':
        bars = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
    else:
        bars = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")

    # Make the prediction with the first 1000 characters
    results = pipe(text[0:1000], return_all_scores=True)
    rates = [row["score"] for row in results[0]]

    # Bar chart, drawn on a cleared current figure
    plt.clf()
    y_pos = np.arange(len(bars))
    plt.barh(y_pos, rates)
    plt.yticks(y_pos, bars)

    best = int(np.argmax(rates))

    if rates[best] < 0.6:
        # ML model not sure (highest probability < 60%): list every label
        # with a probability of at least 10%, best first.
        if language == 'de':
            header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
        else:
            header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
        # The sort is stable, so equal scores keep their original index order.
        ranked = sorted(range(len(rates)), key=lambda i: rates[i], reverse=True)
        lines = [
            "\t" + _format_percent(rates[i]) + "\t\t\t\t\t" + bars[i] + "\n"
            for i in ranked
            if rates[i] >= 0.1
        ]
        name = header + "".join(lines)
    else:
        # ML model pretty sure, show only one label
        name = _format_percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]

    # Save the bar chart as png and load it (enables better display)
    plt.savefig('rates.png')
    im = imageio.imread('rates.png')

    return name, im
 

# Build and display the UI: one 20-line text box as input; the callback's
# message and bar-chart image as the two outputs.
interface = gr.Interface(
    fn=attribution,
    inputs=[
        gr.inputs.Textbox(
            lines=20,
            placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.",
        )
    ],
    outputs=['text', 'image'],
)
interface.launch()