Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
from unidecode import unidecode | |
import tensorflow as tf | |
import cloudpickle | |
from transformers import DistilBertTokenizerFast | |
import os | |
def load_model(model_dir="models", model_checkpoint="distilbert-base-multilingual-cased"):
    """Load the TFLite language-detection model, its label encoder and tokenizer.

    Args:
        model_dir: Directory containing ``lang_detect_hf_distilbert.tflite`` and
            ``lang_detect_labelencoder.bin``. The default ``"models"`` is resolved
            against the current working directory, exactly as before.
        model_checkpoint: Checkpoint name passed to
            ``DistilBertTokenizerFast.from_pretrained``.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple.
    """
    # Join directory and file name properly; the old call passed one
    # pre-joined string, which made os.path.join a no-op.
    interpreter = tf.lite.Interpreter(
        model_path=os.path.join(model_dir, "lang_detect_hf_distilbert.tflite")
    )

    # SECURITY: cloudpickle.load can execute arbitrary code embedded in the
    # file. Only load label encoders shipped with the app, never user uploads.
    encoder_path = os.path.join(model_dir, "lang_detect_labelencoder.bin")
    with open(encoder_path, "rb") as model_file_obj:
        label_encoder = cloudpickle.load(model_file_obj)

    tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

    # NOTE(review): Streamlit re-executes the whole script on each interaction,
    # so this reload likely happens on every rerun. st.cache_resource would
    # avoid it, but the TFLite Interpreter would then be shared across
    # sessions and would need a lock around set_tensor/invoke. Confirm first.
    return interpreter, label_encoder, tokenizer
# Executed at module level; inference() reads these three globals.
(interpreter, label_encoder, tokenizer) = load_model()
def inference(text):
    """Predict the language of *text* with the module-level TFLite model.

    Reads the module-level ``interpreter``, ``label_encoder`` and ``tokenizer``
    created by ``load_model()``.

    Args:
        text: Input string. ``None``, ``""`` and whitespace-only strings are
            treated as having nothing to classify.

    Returns:
        ``"LABEL (score)"``. LABEL is the upper-cased class decoded by the label
        encoder. score is the model's output for that class, rounded to 3
        decimals. Returns ``"Can't Predict"`` when there is no text.
    """
    # Whitespace-only input carries no language signal, and None used to
    # crash inside the tokenizer; both now get the "no prediction" answer.
    if not text or not text.strip():
        return "Can't Predict"

    # Pad/truncate to 50 tokens. Presumably the exported model was built for
    # this fixed sequence length — TODO confirm against the conversion script.
    tokens = tokenizer(
        text, max_length=50, padding="max_length", truncation=True, return_tensors="np"
    )

    interpreter.allocate_tensors()
    mask_detail, ids_detail = _resolve_inputs(interpreter.get_input_details())
    output_detail = interpreter.get_output_details()[0]

    # set_tensor rejects dtype mismatches, so cast each array to the dtype the
    # model input declares instead of relying on the tokenizer's default.
    interpreter.set_tensor(
        mask_detail["index"],
        np.asarray(tokens["attention_mask"], dtype=mask_detail["dtype"]),
    )
    interpreter.set_tensor(
        ids_detail["index"],
        np.asarray(tokens["input_ids"], dtype=ids_detail["dtype"]),
    )
    interpreter.invoke()

    # Batch size is 1, so take the first row of the output.
    scores = interpreter.get_tensor(output_detail["index"])[0]
    best = np.argmax(scores)
    label = label_encoder.inverse_transform([best])[0]
    # NOTE(review): whether `scores` are probabilities or raw logits depends on
    # the exported model's final layer — confirm before calling it confidence.
    return f"{label.upper()} ({str(np.round(scores[best], 3))})"


def _resolve_inputs(input_details):
    """Return ``(attention_mask_detail, input_ids_detail)`` for the model inputs.

    Inputs are matched by tensor name, because the TFLite converter does not
    promise a stable input order. If the names do not identify both inputs
    unambiguously, fall back to the original positional convention
    (index 0 = attention_mask, index 1 = input_ids).
    """
    mask = next((d for d in input_details if "attention_mask" in d["name"]), None)
    ids = next((d for d in input_details if "input_ids" in d["name"]), None)
    if mask is None or ids is None or mask is ids:
        return input_details[0], input_details[1]
    return mask, ids
def main():
    """Render the Streamlit language-detection page.

    Shows the language codes the model was trained on and a text area. When
    Submit is clicked, writes the result of ``inference`` on the entered text.
    """
    st.title("Language Detection")
    lang_trained = 'eng, rus, ita, tur, epo, ber, deu, kab, fra, por, spa, hun, jpn, heb, ukr, nld, fin, pol, mkd, lit, cmn, mar, ces, dan'.upper()
    # st.write renders strings as Markdown, where a newline preceded by a
    # single space is a soft break shown as a space. Two trailing spaces make
    # it a hard line break, so the language list starts on its own line.
    st.write(f'Model is trained on the following languages  \n{lang_trained}')
    user_text = st.text_area("Enter Text:", "", height=200)
    if st.button("Submit"):
        st.write(inference(user_text))
if __name__ == "__main__":
    # Build the page only when run as a script, not when imported.
    main()