File size: 9,158 Bytes
f657d03
71775e2
33a2a6e
71775e2
33a2a6e
f657d03
ccfea75
 
71775e2
bb5868c
b3a537f
 
 
 
 
71775e2
 
f5b762d
 
 
 
 
 
71775e2
 
83f2778
 
b3a537f
83f2778
 
 
 
 
 
 
 
 
 
b3a537f
83f2778
 
 
8010151
33a2a6e
ccfea75
b3a537f
 
ccfea75
83f2778
 
 
 
2a26975
83f2778
 
2a26975
83f2778
 
2a26975
83f2778
 
2a26975
83f2778
8d3edad
 
 
2a26975
83f2778
 
2a26975
83f2778
 
2a26975
83f2778
 
 
 
2a26975
83f2778
 
 
2a26975
83f2778
 
 
 
 
 
 
2a26975
83f2778
8d3edad
2a26975
83f2778
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71775e2
 
 
fbdb9bf
 
71775e2
 
8010151
 
 
 
71775e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97a23a7
71775e2
 
 
97a23a7
71775e2
 
 
 
 
 
 
 
 
f657d03
 
83f2778
71775e2
3639664
97a23a7
71775e2
f657d03
0d18f32
f657d03
0d18f32
f657d03
71775e2
f657d03
 
71775e2
ccfea75
 
0188bfe
ccfea75
 
 
f5b762d
bb5868c
ccfea75
0188bfe
ccfea75
 
 
 
 
 
 
 
71775e2
f657d03
 
 
71775e2
ad2aa16
f657d03
83f2778
71775e2
f657d03
83f2778
2a26975
 
d157f96
4c431ad
83f2778
 
f657d03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import datetime
import gradio as gr
from huggingface_hub import hf_hub_download
from langdetect import detect, DetectorFactory, detect_langs
import fasttext
from transformers import pipeline
from transformers_interpret import ZeroShotClassificationExplainer
import string, nltk

# Hugging Face checkpoint used for zero-shot classification, keyed by language
# code. Entries that are commented out are currently disabled.
models = {
    # English (commented-out alternatives: 'Narsil/deberta-large-mnli-zero-cls',
    # 'microsoft/deberta-xlarge-mnli')
    'en': 'facebook/bart-large-mnli',
    # 'de': 'Sahajtomar/German_Zeroshot',  # German
    # 'es': 'Recognai/zeroshot_selectra_medium',  # Spanish
    # 'it': 'joeddav/xlm-roberta-large-xnli',  # Italian
    # 'ru': 'DeepPavlov/xlm-roberta-large-en-ru-mnli',  # Russian
    # 'tr': 'vicgalle/xlm-roberta-large-xnli-anli',  # Turkish
    # Norwegian
    'no': 'NbAiLab/nb-bert-base-mnli',
}

# Hypothesis sentence handed to the zero-shot pipeline for each language code;
# "{}" is the slot for a candidate label. Entries that are commented out are
# currently disabled.
hypothesis_templates = {
    # English
    'en': 'This passage talks about {}.',
    # 'de': 'Dieses beispiel ist {}.',  # German
    # 'es': 'Este ejemplo es {}.',  # Spanish
    # 'it': 'Questo esempio è {}.',  # Italian
    # 'ru': 'Этот пример {}.',  # Russian
    # 'tr': 'Bu örnek {}.',  # Turkish
    # Norwegian
    'no': 'Dette eksempelet er {}.',
}

# One zero-shot classification pipeline per enabled language, keyed by the
# plain language code. The keys must be exactly the codes found in `models`,
# because the pipeline is later looked up by detected language code.
#
# The mapping is derived from `models` instead of being spelled out by hand:
# a triple-quoted string placed inside a dict literal to "comment out" entries
# is not a comment -- it is concatenated with the next string literal and
# silently becomes part of that key. To enable another language, add it to
# both `models` and `hypothesis_templates`.
classifiers = {
    lang: pipeline("zero-shot-classification",
                   hypothesis_template=hypothesis_templates[lang],
                   model=model_name)
    for lang, model_name in models.items()
}

# Language-identification model used by detect_lang(): obtain "lid.176.bin"
# from the "julien-c/fasttext-language-id" Hub repository, then load it with
# fastText.
_lid_model_path = hf_hub_download(repo_id="julien-c/fasttext-language-id",
                                  filename="lid.176.bin")
fasttext_model = fasttext.load_model(_lid_model_path)

# Fetch the NLTK stop-word corpus at import time; sequence_to_classify() reads
# nltk.corpus.stopwords when it filters the word attributions.
_ = nltk.download('stopwords', quiet=True)
# The WordNet and Punkt downloads below are currently disabled.
#_ = nltk.download('wordnet', quiet=True)
#_ = nltk.download('punkt', quiet=True)

def prep_examples():
    """Return the sample inputs offered in the UI.

    Long texts are assembled with implicit string-literal concatenation.
    A backslash continuation *inside* a string literal would embed the
    indentation of the following source line into the text, producing runs
    of spaces in the middle of the examples.

    Returns:
        A list of ``[text, labels]`` pairs, where ``labels`` is a single
        string with the candidate labels separated by ``;;``.
    """
    return [
        [
            "Coronavirus disease (COVID-19) is an infectious disease caused by the SARS-CoV-2 virus. Most "
            "people who fall sick with COVID-19 will experience mild to moderate symptoms and recover "
            "without special treatment. "
            "However, some will become seriously ill and require medical attention.",
            "business;;health related;;politics;;climate change",
        ],
        [
            "Elephants are",
            "big;;small;;strong;;fast;;carnivorous",
        ],
        [
            "Elephants",
            "are big;;can be very small;;generally not strong enough;;are faster than you think",
        ],
        [
            "Dogs are man's best friend",
            "positive;;negative;;neutral",
        ],
        [
            "Şampiyonlar Ligi’nde 5. hafta oynanan karşılaşmaların ardından sona erdi. Real Madrid, "
            "Inter ve Sporting oynadıkları mücadeleler sonrasında Son 16 turuna yükselmeyi başardı. "
            "Gecenin dev mücadelesinde ise Manchester City, PSG’yi yenerek liderliği garantiledi.",
            "dünya;;ekonomi;;kültür;;siyaset;;spor;;teknoloji",
        ],
        [
            "Letzte Woche gab es einen Selbstmord in einer nahe gelegenen kolonie",
            "verbrechen;;tragödie;;stehlen",
        ],
        [
            "El autor se perfila, a los 50 años de su muerte, como uno de los grandes de su siglo",
            "cultura;;sociedad;;economia;;salud;;deportes",
        ],
        [
            "Россия в среду заявила, что военные учения в аннексированном Москвой Крыму закончились "
            "и что солдаты возвращаются в свои гарнизоны, на следующий день после того, как она "
            "объявила о первом выводе "
            "войск от границ Украины.",
            "новости;;комедия",
        ],
        [
            "I quattro registi - Federico Fellini, Pier Paolo Pasolini, Bernardo Bertolucci e Vittorio De Sica - "
            "hanno utilizzato stili di ripresa diversi, ma hanno fortemente influenzato le giovani "
            "generazioni di registi.",
            "cinema;;politica;;cibo",
        ],
        [
            "Ja, vi elsker dette landet, "
            "som det stiger frem, "
            "furet, værbitt over vannet, "
            "med de tusen hjem. "
            "Og som fedres kamp har hevet "
            "det av nød til seir",
            "helse;;sport;;religion;;mat;;patriotisme og nasjonalisme",
        ],
        [
            "Amar sonar bangla ami tomay bhalobasi",
            "bhalo;;kharap",
        ],
    ]

def detect_lang(sequence, labels):
    """Choose the language code of the classifier to run on ``sequence``.

    The language of the text and of the candidate labels is predicted with
    the module-level fastText model and the outcome is logged with ``print``.
    A mismatch between the two languages is only reported; the language of
    the text always decides.

    Args:
        sequence: Text to classify.
        labels: Candidate labels (the caller passes them as one string).

    Returns:
        The language code predicted for ``sequence`` when a classifier is
        registered for it, otherwise ``'en'``. ``'en'`` is also returned when
        detection fails.
    """
    # Only relevant to langdetect, which is no longer called here; kept so the
    # module-level state this function sets stays the same.
    DetectorFactory.seed = 0
    seq_lang = 'en'
    lbl_lang = None  # stays None when detection fails, so nothing is compared

    # Newlines are removed before prediction: fastText's predict() works on a
    # single line and rejects input that contains '\n'.
    sequence = sequence.replace('\n', ' ')
    flat_labels = str(labels).replace('\n', ' ')

    try:
        seq_lang = fasttext_model.predict(sequence, k=1)[0][0].split("__label__")[1]
        lbl_lang = fasttext_model.predict(flat_labels, k=1)[0][0].split("__label__")[1]
    except Exception as err:
        # Best effort: fall back to the defaults above instead of failing the
        # request. Every placeholder of the message is filled, so the handler
        # cannot raise on its own.
        print("Language detection failed!",
              "Date:{}, Sequence:{}, Labels:{}, Error:{}".format(
                  str(datetime.datetime.now()),
                  sequence,
                  labels,
                  err))

    if lbl_lang is not None and seq_lang != lbl_lang:
        print("Different languages detected for sequence and labels!",
              "Date:{}, Sequence:{}, Labels:{}, Sequence Language:{}, Label Language:{}".format(
                  str(datetime.datetime.now()),
                  sequence,
                  labels,
                  seq_lang,
                  lbl_lang))

    # The returned code is used to index `classifiers`, so that mapping is the
    # one that decides whether a language is supported.
    if seq_lang in classifiers:
        print("Sequence Language detected.",
              "Date:{}, Sequence:{}, Sequence Language:{}".format(
                  str(datetime.datetime.now()),
                  sequence,
                  seq_lang))
    else:
        print("Language not supported. Defaulting to English!",
              "Date:{}, Sequence:{}, Sequence Language:{}".format(
                  str(datetime.datetime.now()),
                  sequence,
                  seq_lang))
        seq_lang = 'en'

    return seq_lang

def _without_stop_words(attributions, stop_words):
    """Return the (token, score) entries whose token is not a stop word."""
    return [entry for entry in attributions if entry[0] not in stop_words]


def sequence_to_classify(sequence, labels):
    """Score ``sequence`` against every candidate label (multi-label).

    The classifier is picked from ``classifiers`` by the language that
    ``detect_lang`` reports. Besides the scores, word attributions from
    ``ZeroShotClassificationExplainer`` are computed and printed for
    diagnostics; they are not part of the return value.

    Args:
        sequence: Text to classify.
        labels: Candidate labels in one string, separated by ``;;``.

    Returns:
        A dict mapping each candidate label to its score as a ``float``.
    """
    lang = detect_lang(sequence, labels)
    classifier = classifiers[lang]

    label_clean = str(labels).split(";;")
    response = classifier(sequence, label_clean, multi_label=True)

    predicted_labels = response['labels']
    print(predicted_labels)
    predicted_scores = response['scores']
    print(predicted_scores)
    # Pair labels with scores without consuming the pipeline's score list.
    clean_output = {label: float(score)
                    for label, score in zip(predicted_labels, predicted_scores)}
    print("Date:{}, Sequence:{}, Labels: {}".format(
        str(datetime.datetime.now()),
        sequence,
        predicted_labels))

    # Explain word attributes
    # NOTE(review): only English stop words are filtered, whatever `lang` is.
    stop_words = set(nltk.corpus.stopwords.words('english'))

    model_expl = ZeroShotClassificationExplainer(classifier.model, classifier.tokenizer)
    # Use the same hypothesis template as the classifier chosen for `lang`.
    response_expl = model_expl(sequence, label_clean,
                               hypothesis_template=hypothesis_templates[lang])
    print(model_expl.predicted_label)

    # With a single label, keep only that label's attributions.
    # NOTE(review): assumes the explainer returns a dict keyed by label --
    # confirm against the installed transformers_interpret version.
    if len(predicted_labels) == 1 and isinstance(response_expl, dict):
        response_expl = response_expl[model_expl.predicted_label]

    # Build filtered lists instead of deleting while iterating, which skips
    # the element that follows every removed one.
    if isinstance(response_expl, dict):
        response_expl = {label: _without_stop_words(attributions, stop_words)
                         for label, attributions in response_expl.items()}
    else:
        response_expl = _without_stop_words(response_expl, stop_words)

    print(response_expl)

    return clean_output

# Gradio front end: two text boxes (text to classify and ";;"-separated
# candidate labels) feeding sequence_to_classify(); the result is rendered as
# a label widget showing the top five classes.
# The description lists only the languages that have an enabled entry in
# `models`; detect_lang() routes every other language to the English model.
iface = gr.Interface(
    title="Multilingual Multi-label Zero-shot Classification",
    description="Currently supported languages are English and Norwegian (Norsk). "
                "Text in any other language is classified with the English model.",
    fn=sequence_to_classify,
    inputs=[gr.inputs.Textbox(lines=10,
        label="Please enter the text you would like to classify...",
        placeholder="Text here..."),
        gr.inputs.Textbox(lines=2,
        label="Please enter the candidate labels (separated by 2 consecutive semicolons)...",
        placeholder="Labels here separated by ;;")],
    outputs=gr.outputs.Label(num_top_classes=5),
    #interpretation="default",
    examples=prep_examples())

# Start the web app when the module is executed.
iface.launch()