Spaces:
Runtime error
Runtime error
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
import gradio as gr | |
from transformers import TextClassificationPipeline | |
model_name = 'indolem/indobert-base-uncased'

# Class labels listed in class-index order: the label at position i of this
# tuple is assigned index i. predict() maps classifier output i back to the
# label carrying index i, and the model is built with len(label_dict) outputs.
label_dict = {
    label: index
    for index, label in enumerate((
        'Pasal 112 UU RI No. 35 Thn 2009',
        'Pasal 114 UU RI No. 35 Thn 2009',
        'Pasal 111 UU RI No. 35 Thn 2009',
        'Pasal 127 UU RI No. 35 Thn 2009',
        'Pasal 363 KUHP',
        'Pasal 365 KUHP',
        'Pasal 362 KUHP',
        'Pasal 338 KUHP',
        'Pasal 340 KUHP',
        'Pasal 374 KUHP',
        'Pasal 372 KUHP',
        'Pasal 378 KUHP',
        'Pasal 351 KUHP',
        'Pasal 303 KUHP',
    ))
}
tokenizer = BertTokenizer.from_pretrained(model_name)
# Base model with a freshly initialised classification head sized to the label
# set; the fine-tuned weights are loaded over it below.
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_dict),
    output_attentions=False,
    output_hidden_states=False,
)

# Fine-tuned state dict. map_location='cpu' lets a checkpoint that was saved
# from a GPU be deserialized on a CPU-only host; without it torch.load tries to
# restore each tensor onto the device it was saved from and raises when CUDA
# is unavailable. load_state_dict then copies the tensors into the model.
torch_model = torch.load(
    'FineTune_IndoLEM_BERT_H_Mean_Pooling_LR1E-5_BS2_epoch_9.model',
    map_location='cpu',
)
# The checkpoint names its classification head 'out.*'; rename those entries to
# the 'classifier.*' keys that BertForSequenceClassification uses.
torch_model['classifier.weight'] = torch_model.pop('out.weight')
torch_model['classifier.bias'] = torch_model.pop('out.bias')
# NOTE(review): the checkpoint filename mentions "Mean_Pooling", whereas
# BertForSequenceClassification feeds the pooled [CLS] representation to its
# classifier. If the fine-tuned model mean-pooled token embeddings instead,
# predictions here may not match training behaviour — confirm.
model.load_state_dict(torch_model)

# With return_all_scores=True, pipe(text) on a single string returns a
# one-element list whose entry holds one {'label', 'score'} dict per class in
# class-index order; predict() relies on that layout. Newer transformers
# releases deprecate this flag in favour of top_k=None, which sorts by score
# and drops the outer list — update predict() before switching.
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)
def get_nth_key(dictionary, n=0):
    """Return the key found at position *n* of ``dictionary.keys()``.

    Negative positions count back from the end, as with sequence indexing.

    Raises:
        IndexError: If *n* does not address an existing position.
    """
    position = n + len(dictionary) if n < 0 else n
    # Private sentinel so that any real key (including None) can be returned.
    not_found = object()
    key = next(
        (candidate for i, candidate in enumerate(dictionary.keys()) if i == position),
        not_found,
    )
    if key is not_found:
        raise IndexError("dictionary index out of range")
    return key
def predict(text):
    """Return the label in ``label_dict`` that the model scores highest for *text*.

    Args:
        text: Input passage to classify.

    Returns:
        The ``label_dict`` key whose class index received the highest score.
        On ties the lowest class index wins.

    Raises:
        ValueError: If the pipeline returns no class scores.
    """
    # For a single string the pipeline returns a one-element list; its entry is
    # the per-class list of {'label', 'score'} dicts in class-index order.
    scores = pipe(text)[0]
    if not scores:
        raise ValueError("pipeline returned no class scores")
    # Builtin max keeps the first of equal maxima (same tie rule as a strict
    # '<' scan). The old scan started from index -1 with a best score of 0, so
    # when no score exceeded 0 (e.g. all NaN) it silently returned the LAST label.
    best_index = max(range(len(scores)), key=lambda i: scores[i]['score'])
    # Resolve the class id through label_dict's values instead of relying on
    # the dict's insertion order happening to match those values.
    index_to_label = {index: label for label, index in label_dict.items()}
    return index_to_label[best_index]
# Expose predict() through a minimal Gradio UI: a free-text input box and the
# returned label shown as plain text. launch() starts the web server.
iface = gr.Interface(
    fn=predict,
    inputs="text",
    outputs="text",
)
iface.launch()