LangID-LIME / app.py
kargaranamir's picture
Update app.py
42207df verified
raw history blame
No virus
5.26 kB
# """
# Author: Amir Hossein Kargaran
# Date: August, 2023
# Description: This code applies LIME (Local Interpretable Model-Agnostic Explanations) to language identification models.
# MIT License
# Parts of this code are adapted from: https://gist.github.com/ageitgey/60a8b556a9047a4ca91d6034376e5980
# """
import gradio as gr
from io import BytesIO
from fasttext.FastText import _FastText
import re
import lime.lime_text
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
import os
# Registry of the selectable models. Each display name maps to a two-item
# list: [Hugging Face repo id, filename of the model inside that repo].
model_paths = {
    "OpenLID": ["laurievb/OpenLID", "model.bin"],
    "GlotLID": ["cis-lmu/glotlid", "model.bin"],
    "NLLB": ["facebook/fasttext-language-identification", "model.bin"],
}

# Classifiers that have already been loaded, keyed by the same display names.
cached_classifiers = dict()
def load_classifier(model_choice):
    """Return the FastText classifier registered under ``model_choice``.

    On the first request for a model, its file is fetched with
    ``hf_hub_download`` and wrapped in ``_FastText``; the instance is then
    stored in ``cached_classifiers`` and reused on later calls.

    Args:
        model_choice: A key of ``model_paths``.

    Returns:
        The ``_FastText`` instance for the requested model.

    Raises:
        KeyError: If ``model_choice`` is not cached and is not a key of
            ``model_paths``.
    """
    if model_choice not in cached_classifiers:
        entry = model_paths[model_choice]
        # entry is [repo id, filename]; the download returns a local file path
        local_model_file = hf_hub_download(repo_id=entry[0], filename=entry[1])
        cached_classifiers[model_choice] = _FastText(local_model_file)
    return cached_classifiers[model_choice]
# Warm the cache at import time: every registered model is loaded here,
# before the interface further down is built and launched.
for model_choice in model_paths:
    load_classifier(model_choice)
def remove_label_prefix(item):
    """Return ``item`` with every occurrence of ``__label__`` removed.

    Args:
        item: A label string such as ``__label__eng_Latn``.

    Returns:
        The string without the ``__label__`` marker.
    """
    marker = '__label__'
    return item.replace(marker, '')
def remove_label_prefix_list(input_list):
    """Strip the ``__label__`` marker from a flat or nested list of labels.

    The shape is decided by the first element: if it is a ``list``, the input
    is treated as a list of label lists and the nesting is preserved;
    otherwise it is treated as a flat sequence of label strings.

    Args:
        input_list: A sequence of label strings, or a sequence of lists of
            label strings. May be empty.

    Returns:
        A new list (or list of lists) with the marker removed from each label.
        An empty input yields an empty list.
    """
    # An empty input has no first element to inspect; indexing it would raise
    # IndexError, so answer directly. len() is used instead of truthiness so
    # array-like inputs are handled too.
    if len(input_list) == 0:
        return []
    if isinstance(input_list[0], list):
        return [[remove_label_prefix(item) for item in inner_list] for inner_list in input_list]
    return [remove_label_prefix(item) for item in input_list]
def tokenize_string(sentence, n=None):
    """Split ``sentence`` into tokens.

    Args:
        sentence: The text to tokenize.
        n: If ``None`` (default), split on whitespace. Otherwise return every
            run of ``n`` consecutive characters, moving one character at a
            time.

    Returns:
        A list of token strings. With ``n`` larger than ``len(sentence)`` the
        list is empty.
    """
    if n is not None:
        window_count = len(sentence) - n + 1
        return [sentence[start:start + n] for start in range(window_count)]
    return sentence.split()
def fasttext_prediction_in_sklearn_format(classifier, texts, num_class):
    """Return the classifier's probabilities for ``texts`` as one array.

    ``classifier.predict(texts, -1)`` is called once for the whole batch. For
    each text, the probabilities are reordered so they follow the text's
    labels sorted with ``np.argsort`` (after the ``__label__`` marker has been
    removed), giving every row the same label order.

    Args:
        classifier: Object exposing ``predict(texts, k)`` that returns a
            ``(labels, probabilities)`` pair with one entry per text.
        texts: The batch of texts to score.
        num_class: Not used by this function.

    Returns:
        ``np.array`` built from the reordered probability rows, one row per
        text.
    """
    raw_labels, raw_probs = classifier.predict(texts, -1)
    clean_labels = remove_label_prefix_list(raw_labels)
    rows = [
        row_probs[np.argsort(np.array(row_labels))]
        for row_labels, row_probs, _ in zip(clean_labels, raw_probs, texts)
    ]
    return np.array(rows)
def generate_explanation_html(input_sentence, explainer, classifier, num_class):
    """Explain ``input_sentence`` and write the explanation to an HTML file.

    Args:
        input_sentence: The text to explain.
        explainer: Object exposing ``explain_instance``; it is asked for the
            top 2 labels and up to 20 features.
        classifier: Classifier passed on to
            ``fasttext_prediction_in_sklearn_format``.
        num_class: Number of classes, passed on with the classifier.

    Returns:
        The name of the written file, always ``"explanation.html"``. Because
        the name is fixed, each call overwrites the previous file.
    """
    def predict_proba(texts):
        # Adapter giving the explainer a single-argument prediction function
        return fasttext_prediction_in_sklearn_format(classifier, texts, num_class)

    explanation = explainer.explain_instance(
        input_sentence,
        classifier_fn=predict_proba,
        top_labels=2,
        num_features=20,
    )
    html_path = "explanation.html"
    explanation.save_to_file(html_path)
    return html_path
def take_screenshot(local_html_path):
    """Render a local HTML file in headless Chrome and return a screenshot.

    Args:
        local_html_path: Path to the HTML file; it is made absolute before
            being opened through a ``file://`` URL.

    Returns:
        A PIL image of the rendered page at a 1366x728 window size, or a 1x1
        RGB placeholder image if a ``WebDriverException`` is raised.
    """
    options = webdriver.ChromeOptions()
    options.add_argument('--headless')
    options.add_argument('--no-sandbox')
    options.add_argument('--disable-dev-shm-usage')

    # Bind the name before the try block: if webdriver.Chrome() itself raises,
    # the finally clause must still be able to test it. Without this, the
    # cleanup raised UnboundLocalError and hid the placeholder return value.
    wd = None
    try:
        local_html_path = os.path.abspath(local_html_path)
        wd = webdriver.Chrome(options=options)
        wd.set_window_size(1366, 728)
        wd.get('file://' + local_html_path)
        wd.implicitly_wait(10)
        screenshot = wd.get_screenshot_as_png()
    except WebDriverException:
        # Best effort: hand back a tiny placeholder instead of failing the request
        return Image.new('RGB', (1, 1))
    finally:
        # Always shut the browser down, but only if it was actually started
        if wd is not None:
            wd.quit()
    return Image.open(BytesIO(screenshot))
# Callback wired into the Gradio interface
def merge_function(input_sentence, selected_model):
    """Explain ``input_sentence`` with the selected model.

    Args:
        input_sentence: The text to explain; newlines are replaced by spaces
            (presumably because the classifier expects single-line input —
            confirm against the fastText API).
        selected_model: Model name passed to ``load_classifier``.

    Returns:
        A ``(image, html_filename)`` pair: the screenshot returned by
        ``take_screenshot`` and the name of the explanation HTML file.
    """
    single_line = input_sentence.replace('\n', ' ')
    classifier = load_classifier(selected_model)

    # Sorted class names, with the label marker removed, are handed to LIME
    class_names = np.sort(remove_label_prefix_list(classifier.labels))

    explainer = lime.lime_text.LimeTextExplainer(
        split_expression=tokenize_string,
        bow=False,
        class_names=class_names,
    )

    html_path = generate_explanation_html(single_line, explainer, classifier, len(class_names))
    return take_screenshot(html_path), html_path
# Define the Gradio interface: a text box and a model selector as inputs,
# the explanation screenshot and the explanation HTML file as outputs.
input_text = gr.Textbox(label="Input Text", value="J'ai visited la beautiful beach avec mes amis for a relaxing journée under the sun.")
model_choice = gr.Radio(choices=["GlotLID", "OpenLID", "NLLB"], label="Select Model", value='GlotLID')
# Use the top-level gr.File component, like the other components here; the
# legacy gr.outputs namespace is deprecated and absent from newer Gradio releases.
output_explanation = gr.File(label="Explanation HTML")
iface = gr.Interface(
    merge_function,
    inputs=[input_text, model_choice],
    outputs=[gr.Image(type="pil", height=364, width=683, label="Explanation Image"), output_explanation],
    title="LIME LID",
    description="This code applies LIME (Local Interpretable Model-Agnostic Explanations) on fasttext language identification.",
    allow_flagging='never',
    theme=gr.themes.Soft(),
)
iface.launch()