File size: 5,263 Bytes
41c3c5a
 
 
 
42207df
41c3c5a
 
 
 
 
 
 
 
 
 
 
 
b94d9cd
41c3c5a
b94d9cd
 
 
41c3c5a
 
6dabd3f
 
a3a025c
5253665
6dabd3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41c3c5a
 
 
 
 
 
 
 
 
 
 
6dabd3f
 
 
 
 
 
 
 
41c3c5a
 
6dabd3f
 
 
 
41c3c5a
fe01253
41c3c5a
 
 
 
 
 
6dabd3f
 
b94d9cd
41c3c5a
 
6dabd3f
41c3c5a
 
 
 
 
 
 
b94d9cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6dabd3f
 
 
 
b94d9cd
6dabd3f
5253665
6dabd3f
 
 
 
 
 
 
 
 
 
 
 
 
b94d9cd
 
 
6dabd3f
69faebb
5253665
41c3c5a
b94d9cd
41c3c5a
b94d9cd
28913e1
6dabd3f
 
 
 
 
28913e1
48bb59c
6dabd3f
08f58ed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# """
# Author: Amir Hossein Kargaran
# Date: August, 2023

# Description: This code applies LIME (Local Interpretable Model-Agnostic Explanations) on language identification models.

# MIT License

# Some part of the code is adopted from here: https://gist.github.com/ageitgey/60a8b556a9047a4ca91d6034376e5980
# """

import gradio as gr
from io import BytesIO
from fasttext.FastText import _FastText
import re
import lime.lime_text
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
import os


# Define a dictionary to map model choices to their respective paths
# Each entry maps a UI-facing model name to [Hugging Face repo id, filename within that repo].
model_paths = {
    "OpenLID": ["laurievb/OpenLID", 'model.bin'],
    "GlotLID": ["cis-lmu/glotlid", 'model.bin'],
    "NLLB": ["facebook/fasttext-language-identification", 'model.bin']
}

# Create a dictionary to cache classifiers
# Maps model name -> loaded _FastText instance; populated by load_classifier().
cached_classifiers = {}

def load_classifier(model_choice):
    """Return the fastText classifier for *model_choice*, loading and caching it on first use."""
    # EAFP cache lookup: the common case after warm-up is a hit.
    try:
        return cached_classifiers[model_choice]
    except KeyError:
        pass

    # Download (or reuse the local Hugging Face cache of) the model binary.
    repo_id, filename = model_paths[model_choice]
    local_path = hf_hub_download(repo_id=repo_id, filename=filename)

    # Wrap the binary in a fastText classifier and remember it for next time.
    classifier = _FastText(local_path)
    cached_classifiers[model_choice] = classifier
    return classifier

# Eagerly load every model up front so the UI never waits on a first-time download.
for choice in model_paths:
    load_classifier(choice)


def remove_label_prefix(item):
    """Drop every occurrence of fastText's '__label__' marker from *item*."""
    label_marker = '__label__'
    return item.replace(label_marker, '')

def remove_label_prefix_list(input_list):
    """Strip the '__label__' prefix from a flat or nested sequence of label strings.

    Args:
        input_list: either a list of label strings, or a list of lists/tuples
            of label strings (the shape fastText returns for batch prediction).

    Returns:
        A list (or list of lists) mirroring the input with prefixes removed;
        an empty input yields an empty list.
    """
    # Guard the empty case: the original peeked at input_list[0] and raised
    # IndexError on an empty list.
    if not input_list:
        return []
    # Accept tuples as well as lists for the nested case, since fastText-style
    # APIs may return either container type per row.
    if isinstance(input_list[0], (list, tuple)):
        return [[remove_label_prefix(item) for item in inner_list] for inner_list in input_list]
    return [remove_label_prefix(item) for item in input_list]


def tokenize_string(sentence, n=None):
    """Tokenize *sentence*.

    With n=None, split on whitespace; otherwise return all character
    n-grams of length *n* (empty list when the sentence is shorter than n).
    """
    if n is None:
        return sentence.split()
    return [sentence[start:start + n] for start in range(len(sentence) - n + 1)]


def fasttext_prediction_in_sklearn_format(classifier, texts, num_class):
    """Run fastText prediction and return an sklearn-style probability matrix.

    For each text, predicts all labels (k=-1) and reorders the probabilities
    so columns follow the alphabetical order of label names — matching the
    sorted class_names handed to LIME. *num_class* is kept for interface
    compatibility with callers.
    """
    labels, probabilities = classifier.predict(texts, -1)
    labels = remove_label_prefix_list(labels)

    rows = []
    for row_labels, row_probs in zip(labels, probabilities):
        # argsort over label names gives the permutation that sorts them
        # alphabetically; apply it to the probabilities for that row.
        alphabetical_order = np.argsort(np.array(row_labels))
        rows.append(row_probs[alphabetical_order])
    return np.array(rows)


def generate_explanation_html(input_sentence, explainer, classifier, num_class):
    """Explain *input_sentence* with LIME and save the result as an HTML file.

    Returns the path of the written HTML file.
    """
    predict_fn = lambda batch: fasttext_prediction_in_sklearn_format(classifier, batch, num_class)
    explanation = explainer.explain_instance(
        input_sentence,
        classifier_fn=predict_fn,
        top_labels=2,
        num_features=20,
    )
    html_path = "explanation.html"
    explanation.save_to_file(html_path)
    return html_path

def take_screenshot(local_html_path):
    """Render a local HTML file in headless Chrome and return a PIL screenshot.

    Args:
        local_html_path: path (relative or absolute) of the HTML file to render.

    Returns:
        A PIL.Image of the rendered page, or a 1x1 placeholder image if the
        browser fails with a WebDriverException.
    """
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')

    # Bug fix: wd must exist before the try block. If webdriver.Chrome()
    # itself raises, the original code hit a NameError on `wd` in `finally`,
    # masking the intended placeholder-image fallback.
    wd = None
    try:
        local_html_path = os.path.abspath(local_html_path)
        wd = webdriver.Chrome(options=options)
        wd.set_window_size(1366, 728)
        wd.get('file://' + local_html_path)
        wd.implicitly_wait(10)
        screenshot = wd.get_screenshot_as_png()
    except WebDriverException:
        return Image.new('RGB', (1, 1))
    finally:
        if wd:
            wd.quit()

    return Image.open(BytesIO(screenshot))


# Gradio callback: ties model loading, LIME explanation, and rendering together.
def merge_function(input_sentence, selected_model):
    """Explain *selected_model*'s language prediction for *input_sentence*.

    Returns a (screenshot image, explanation HTML path) pair for Gradio.
    """
    # Collapse newlines: LIME renders the text as a single line.
    flattened = input_sentence.replace('\n', ' ')

    # Fetch the cached fastText model and its alphabetically sorted label set.
    classifier = load_classifier(selected_model)
    class_names = np.sort(remove_label_prefix_list(classifier.labels))
    num_class = len(class_names)

    # Configure LIME to perturb whitespace tokens (bow=False keeps positions).
    explainer = lime.lime_text.LimeTextExplainer(
        split_expression=tokenize_string,
        bow=False,
        class_names=class_names,
    )

    # Produce the HTML explanation, then rasterize it for the image output.
    html_path = generate_explanation_html(flattened, explainer, classifier, num_class)
    image = take_screenshot(html_path)
    return image, html_path

# Define the Gradio interface.
input_text = gr.Textbox(label="Input Text", value="J'ai visited la beautiful beach avec mes amis for a relaxing journée under the sun.")
model_choice = gr.Radio(choices=["GlotLID", "OpenLID", "NLLB"], label="Select Model",  value='GlotLID')

# Fix: gr.outputs.File belongs to the removed legacy namespace; the rest of
# this block already uses the modern component API (gr.Image, gr.themes.Soft),
# so use gr.File for consistency and Gradio 4.x compatibility.
output_explanation = gr.File(label="Explanation HTML")
output_image = gr.Image(type="pil", height=364, width=683, label="Explanation Image")

iface = gr.Interface(merge_function,
                     inputs=[input_text, model_choice],
                     outputs=[output_image, output_explanation],
                     title="LIME LID",
                     description="This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.",
                     allow_flagging='never',
                     theme=gr.themes.Soft())

iface.launch()