ksvmuralidhar committed on
Commit
8fb0cad
1 Parent(s): d303714

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from unidecode import unidecode
5
+ import tensorflow as tf
6
+ import cloudpickle
7
+ from transformers import DistilBertTokenizerFast
8
+ import os
9
+
10
def load_model():
    """Load the three artifacts the app needs to run predictions.

    Returns:
        A ``(interpreter, label_encoder, tokenizer)`` tuple:

        * the ``tf.lite.Interpreter`` built from
          ``models/lang_detect_hf_distilbert.tflite``,
        * the object unpickled from ``models/lang_detect_labelencoder.bin``,
        * the ``DistilBertTokenizerFast`` loaded for the
          ``distilbert-base-multilingual-cased`` checkpoint.
    """
    tflite_path = os.path.join("models/lang_detect_hf_distilbert.tflite")
    tflite_interpreter = tf.lite.Interpreter(model_path=tflite_path)

    # NOTE: unpickling can execute arbitrary code, so this file must only
    # ever come from a trusted source.
    with open("models/lang_detect_labelencoder.bin", "rb") as encoder_file:
        encoder = cloudpickle.load(encoder_file)

    fast_tokenizer = DistilBertTokenizerFast.from_pretrained(
        "distilbert-base-multilingual-cased"
    )
    return tflite_interpreter, encoder, fast_tokenizer
18
+
19
# Load the model artifacts at import time; `inference` reads these three
# module-level names directly.
interpreter, label_encoder, tokenizer = load_model()
20
+
21
def inference(text):
    """Predict the language of *text* with the module-level TFLite model.

    Uses the module-level ``tokenizer``, ``interpreter`` and
    ``label_encoder`` objects.

    Args:
        text: Raw text entered by the user.

    Returns:
        A string of the form ``"<LABEL> (<score>)"``, where ``<LABEL>`` is the
        upper-cased class name returned by ``label_encoder`` for the
        highest-scoring output and ``<score>`` is that output value rounded to
        three decimals. Returns ``"Can't Predict"`` when *text* is empty,
        ``None`` or contains only whitespace.
    """
    # Whitespace-only input carries nothing to classify; treat it like the
    # empty string instead of returning a meaningless prediction.
    if not text or not text.strip():
        return "Can't Predict"

    tokens = tokenizer(
        text,
        max_length=50,
        padding="max_length",
        truncation=True,
        return_tensors="tf",
    )

    # tflite model inference
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_index = interpreter.get_output_details()[0]["index"]

    # NOTE(review): relies on the model exposing the attention mask as input 0
    # and the token ids as input 1 -- confirm against the exported model.
    interpreter.set_tensor(input_details[0]["index"], tokens["attention_mask"])
    interpreter.set_tensor(input_details[1]["index"], tokens["input_ids"])
    interpreter.invoke()

    # First (and only) row of the output: one score per class.
    scores = interpreter.get_tensor(output_index)[0]
    best = np.argmax(scores)
    label = label_encoder.inverse_transform([best])[0].upper()
    return f"{label} ({np.round(scores[best], 3)})"
37
+
38
+
39
def main():
    """Render the Streamlit page and show a prediction on submit.

    Draws the title, the list of languages the model was trained on and a
    text area; when the "Submit" button is pressed, the entered text is
    passed to ``inference`` and the result is written to the page.
    """
    st.title("Language Detection")

    lang_trained = 'eng, rus, ita, tur, epo, ber, deu, kab, fra, por, spa, hun, jpn, heb, ukr, nld, fin, pol, mkd, lit, cmn, mar, ces, dan'.upper()
    st.write(f'Model is trained on the following languages \n{lang_trained}')

    user_text = st.text_area("Enter Text:", "", height=200)

    # Nothing to do until the user presses the button.
    if not st.button("Submit"):
        return

    st.write(inference(user_text))
47
+
48
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()