sohomghosh commited on
Commit
ed4339d
1 Parent(s): ed29696

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -95,7 +95,7 @@ class BERTClass(torch.nn.Module):
95
  output = self.classifier(pooler)
96
  return output
97
 
98
- def do_predict(model, tokenizer):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
@@ -119,7 +119,7 @@ model_sustain.to(device)
119
  model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
122
- actual_predictions_sus = do_predict(model_sustain, tokenizer_sus)
123
 
124
  test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
125
  ```
 
95
  output = self.classifier(pooler)
96
  return output
97
 
98
+ def do_predict(model, tokenizer, test_df):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
 
119
  model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
122
+ actual_predictions_sus = do_predict(model_sustain, tokenizer_sus, test_df)
123
 
124
  test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
125
  ```