sohomghosh commited on
Commit
9f31d25
1 Parent(s): ee4af77

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -95,7 +95,7 @@ class BERTClass(torch.nn.Module):
95
  output = self.classifier(pooler)
96
  return output
97
 
98
- def do_predict(model, tokenizer):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
@@ -119,7 +119,7 @@ model_read.to(device)
119
  model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
122
- actual_predictions_read = do_predict(model_read, tokenizer_read)
123
 
124
  test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
125
 
 
95
  output = self.classifier(pooler)
96
  return output
97
 
98
+ def do_predict(model, tokenizer, test_df):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
 
119
  model_read.load_stat_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_read = BertTokenizer.from_pretrained('ProsusAI/finbert')
122
+ actual_predictions_read = do_predict(model_read, tokenizer_read, test_df)
123
 
124
  test_df['readability'] = ['readable' if i==1 else 'not_reabale' for i in actual_predictions_read]
125