rendy-ad89 committed
Commit 0f49bdc
1 Parent(s): 28efa8b

updated to 14 classes

Files changed (1)
  1. app.py +21 -8
app.py CHANGED
@@ -1,21 +1,35 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 import torch
 import gradio as gr
+from transformers import TextClassificationPipeline

-model_name='indobenchmark/indobert-base-p1'
+model_name='indolem/indobert-base-uncased'
 label_dict={'Pasal 112 UU RI No. 35 Thn 2009': 0,
-'Pasal 363 KUHP': 1,
-'Pasal 372 KUHP': 2,
-'Pasal 378 KUHP': 3}
+'Pasal 114 UU RI No. 35 Thn 2009': 1,
+'Pasal 111 UU RI No. 35 Thn 2009': 2,
+'Pasal 127 UU RI No. 35 Thn 2009': 3,
+'Pasal 363 KUHP': 4,
+'Pasal 365 KUHP': 5,
+'Pasal 362 KUHP': 6,
+'Pasal 338 KUHP': 7,
+'Pasal 340 KUHP': 8,
+'Pasal 374 KUHP': 9,
+'Pasal 372 KUHP': 10,
+'Pasal 378 KUHP': 11,
+'Pasal 351 KUHP': 12,
+'Pasal 303 KUHP': 13}

 tokenizer = BertTokenizer.from_pretrained(model_name)
 model = BertForSequenceClassification.from_pretrained(model_name,
 num_labels=len(label_dict),
 output_attentions=False,
 output_hidden_states=False)
-model.load_state_dict(torch.load('finetuned_BERT_epoch_9.model', map_location=torch.device('cpu')))

-from transformers import TextClassificationPipeline
+torch_model = torch.load('FineTune_IndoLEM_BERT_H_Mean_Pooling_LR1E-5_BS2_epoch_9.model')
+torch_model['classifier.weight'] = torch_model.pop('out.weight')
+torch_model['classifier.bias'] = torch_model.pop('out.bias')
+model.load_state_dict(torch_model)
+
 pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True)

 def get_nth_key(dictionary, n=0):
@@ -24,11 +38,10 @@ def get_nth_key(dictionary, n=0):
     for i, key in enumerate(dictionary.keys()):
         if i == n:
             return key
-    raise IndexError("dictionary index out of range")
+    raise IndexError("dictionary index out of range")

 def predict(text):
     predictions = pipe(text)[0]
-    print(predictions)
     max = 0
     idx = -1
     for i in range(len(predictions)):
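
The new checkpoint-loading block renames two keys before calling load_state_dict. A plausible reading (an assumption, not stated in the commit) is that the fine-tuned checkpoint was saved from a model whose classification head was named `out`, so its weight and bias are moved to the `classifier.*` names that BertForSequenceClassification expects. The new torch.load call also drops the old map_location argument; the sketch below keeps it so the checkpoint still loads on CPU-only hosts, and assumes `model` is the BertForSequenceClassification instance built earlier in app.py.

import torch

# Sketch of the checkpoint-loading step (assumption: the saved head was named `out`).
state_dict = torch.load(
    'FineTune_IndoLEM_BERT_H_Mean_Pooling_LR1E-5_BS2_epoch_9.model',
    map_location=torch.device('cpu'),  # kept so this also works without a GPU
)
# Rename the custom head's parameters to the names the HF model expects.
state_dict['classifier.weight'] = state_dict.pop('out.weight')
state_dict['classifier.bias'] = state_dict.pop('out.bias')
# strict=False reports any leftover mismatches instead of raising immediately.
result = model.load_state_dict(state_dict, strict=False)
print(result)  # sanity check: missing_keys and unexpected_keys should both be empty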
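Because no id2label mapping is passed to from_pretrained, the pipeline returns generic labels (LABEL_0 ... LABEL_13), which is why the app keeps label_dict and get_nth_key around to translate an index back into a statute name. An alternative sketch, not what this commit does, is to hand the mapping to the config so the pipeline reports the statute names directly:

# Alternative sketch: give the config human-readable label names up front.
id2label = {v: k for k, v in label_dict.items()}
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_dict),
    id2label=id2label,
    label2id=label_dict,
)
# With this, pipe(text) yields e.g. {'label': 'Pasal 363 KUHP', 'score': ...}
# and get_nth_key is no longer needed.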
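The hunk cuts off inside predict, so the rest of the function and the Gradio wiring are not visible in this diff. A hypothetical completion, following the visible max/idx bookkeeping: scan the scores for the best index, map it back to a statute name through get_nth_key (the pipeline lists scores in label-id order, matching label_dict's insertion order), and serve it with gr.Interface. Note that recent transformers releases deprecate return_all_scores in favour of top_k=None.

# Hypothetical completion of predict() and the Gradio interface;
# the actual code after this point is not shown in the diff.
def predict(text):
    predictions = pipe(text)[0]  # list of {'label': 'LABEL_i', 'score': float}
    max_score = 0
    idx = -1
    for i in range(len(predictions)):
        if predictions[i]['score'] > max_score:
            max_score = predictions[i]['score']
            idx = i
    # idx follows label-id order, so it maps onto label_dict via get_nth_key.
    return {get_nth_key(label_dict, idx): max_score}

demo = gr.Interface(fn=predict, inputs='text', outputs='label')
demo.launch()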