# tamilatis / app.py
# Author: seanbenhur -- commit 61697c7 ("add models").
# Hugging Face file-page metadata: malware scan "No virus", file size 1.15 kB.
import functools
from pathlib import Path

import gradio as gr
import numpy as np
from sklearn.preprocessing import LabelEncoder

from tamilatis.model import JointATISModel
from tamilatis.predict import TamilATISPredictor
# Hugging Face model id used for both the encoder and the tokenizer.
model_name = "microsoft/xlm-align-base"
tokenizer_name = "microsoft/xlm-align-base"

# Output sizes passed to JointATISModel: slot (NER) tags and intent classes.
# Presumably these must equal the number of classes stored in
# ner_classes.npy / intent_classes.npy -- TODO confirm.
num_labels = 78
num_intents = 23

# Resolve model artifacts relative to this file instead of the current working
# directory, so the app also starts when launched from another directory.
# Values stay plain strings, as before.
_MODELS_DIR = Path(__file__).resolve().parent / "tamilatis" / "models"
checkpoint_path = str(_MODELS_DIR / "xlm_align_base.bin")
intent_encoder_path = str(_MODELS_DIR / "intent_classes.npy")
ner_encoder_path = str(_MODELS_DIR / "ner_classes.npy")
@functools.lru_cache(maxsize=1)
def _get_predictor():
    """Build the joint slot/intent predictor once and reuse it for later calls.

    Loads the NER-tag and intent label encoders from their ``.npy`` class
    files, constructs ``JointATISModel`` and wraps it in
    ``TamilATISPredictor`` together with the checkpoint path and tokenizer
    name. Previously this ran on every request. It is now done lazily on the
    first prediction and cached for the life of the process.
    """
    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.load(ner_encoder_path)
    intent_encoder = LabelEncoder()
    intent_encoder.classes_ = np.load(intent_encoder_path)
    model = JointATISModel(model_name, num_labels, num_intents)
    return TamilATISPredictor(
        model,
        checkpoint_path,
        tokenizer_name,
        label_encoder,
        intent_encoder,
        num_labels,
    )


def predict_function(text):
    """Predict slot tags and the intent for one input utterance.

    Args:
        text: The utterance entered in the Gradio text box.

    Returns:
        A ``(slot_prediction, intent_preds)`` pair, exactly as returned by
        ``TamilATISPredictor.get_predictions``.
    """
    # NOTE(review): assumes get_predictions is safe to call repeatedly on a
    # single shared predictor instance -- TODO confirm.
    predictor = _get_predictor()
    slot_prediction, intent_preds = predictor.get_predictions(text)
    return slot_prediction, intent_preds
title = "Tamil ATIS"

# One text box in; two text boxes out, matching the
# (slot_prediction, intent_preds) pair returned by predict_function.
iface = gr.Interface(
    fn=predict_function,
    inputs="text",
    title=title,
    theme="huggingface",
    outputs=["text", "text"],
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tests or tooling) does not start a server. share=True requests a
    # temporary public Gradio link in addition to the local server.
    iface.launch(share=True)