legal-indobert / app.py
rendy-ad89's picture
updated to 14 class
0f49bdc
raw
history blame
1.86 kB
from itertools import islice

import gradio as gr
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import TextClassificationPipeline
# Base checkpoint that supplies the tokenizer vocabulary and model architecture.
model_name = "indolem/indobert-base-uncased"

# Statute label -> class id. Ids are assigned by position in the tuple below,
# so the order of the labels is significant and must not be changed.
label_dict = {
    label: class_id
    for class_id, label in enumerate(
        (
            "Pasal 112 UU RI No. 35 Thn 2009",
            "Pasal 114 UU RI No. 35 Thn 2009",
            "Pasal 111 UU RI No. 35 Thn 2009",
            "Pasal 127 UU RI No. 35 Thn 2009",
            "Pasal 363 KUHP",
            "Pasal 365 KUHP",
            "Pasal 362 KUHP",
            "Pasal 338 KUHP",
            "Pasal 340 KUHP",
            "Pasal 374 KUHP",
            "Pasal 372 KUHP",
            "Pasal 378 KUHP",
            "Pasal 351 KUHP",
            "Pasal 303 KUHP",
        )
    )
}
# Tokenizer and model architecture come from the base checkpoint; the
# classification head is sized to the number of labels in label_dict.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_dict),
    output_attentions=False,
    output_hidden_states=False,
)

# Load the fine-tuned weights. map_location='cpu' lets a checkpoint that was
# saved from a GPU deserialize on a CPU-only host instead of raising.
# NOTE(review): the file name mentions mean pooling; confirm the checkpoint's
# head matches what BertForSequenceClassification computes at inference.
torch_model = torch.load(
    'FineTune_IndoLEM_BERT_H_Mean_Pooling_LR1E-5_BS2_epoch_9.model',
    map_location='cpu',
)

# The checkpoint stores its output layer under "out.*"; move those entries to
# the "classifier.*" keys before loading them into the model.
torch_model['classifier.weight'] = torch_model.pop('out.weight')
torch_model['classifier.bias'] = torch_model.pop('out.bias')
model.load_state_dict(torch_model)

# predict() indexes the pipeline output as pipe(text)[0] and walks the
# per-class scores, so every class score is requested here.
# NOTE(review): return_all_scores may be deprecated in newer transformers
# releases; verify the output shape before switching to a replacement option.
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
def get_nth_key(dictionary, n=0):
    """Return the key at position *n* in the dictionary's iteration order.

    Negative positions count from the end, as with sequence indexing.

    Args:
        dictionary: Mapping whose keys are walked in iteration order.
        n: Zero-based position of the wanted key; may be negative.

    Returns:
        The key found at position ``n``.

    Raises:
        IndexError: If ``n`` falls outside the dictionary.
    """
    if n < 0:
        n += len(dictionary)
    if n < 0:
        # Still negative after wrapping, so the position is out of range;
        # islice would reject a negative start with a ValueError instead.
        raise IndexError("dictionary index out of range")
    try:
        # Skip the first n keys and take the next one without building a list.
        return next(islice(dictionary, n, None))
    except StopIteration:
        raise IndexError("dictionary index out of range") from None
def predict(text):
    """Return the label whose class receives the highest score for *text*.

    Args:
        text: Input text handed to the classification pipeline.

    Returns:
        The ``label_dict`` key at the position of the best-scoring class.

    Raises:
        ValueError: If the pipeline returns no class scores.
    """
    # The pipeline output is indexed per input; take the scores of the single
    # input passed in.
    predictions = pipe(text)[0]
    if not predictions:
        # Previously an empty result silently fell through to index -1 and
        # returned the last label; fail loudly instead.
        raise ValueError("classification pipeline returned no scores")

    # max() keeps the first entry on ties, matching the earlier strict "<"
    # comparison, and avoids shadowing the builtin with a local named max.
    # Assumes the score entries are ordered by class id, so that the position
    # lines up with label_dict -- TODO confirm against the pipeline output.
    best_index = max(
        range(len(predictions)),
        key=lambda position: predictions[position]['score'],
    )
    return get_nth_key(label_dict, best_index)
# Plain text in, plain text out: the text box content is passed to predict()
# and the returned label is displayed.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)

# Start serving the interface.
iface.launch()