sohomghosh commited on
Commit
ed29696
1 Parent(s): 8a57cfb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -9
README.md CHANGED
@@ -43,13 +43,13 @@ class Triage(Dataset):
43
  This is a subclass of the torch package's Dataset class. It processes input to create ids, masks and targets required for model training.
44
  """
45
 
46
- def __init__(self, dataframe, tokenizer, max_len, text_col_name, category_col):
47
  self.len = len(dataframe)
48
  self.data = dataframe
49
  self.tokenizer = tokenizer
50
  self.max_len = max_len
51
  self.text_col_name = text_col_name
52
- self.category_col = category_col
53
 
54
  def __getitem__(self, index):
55
  title = str(self.data[self.text_col_name][index])
@@ -69,14 +69,12 @@ class Triage(Dataset):
69
  return {
70
  "ids": torch.tensor(ids, dtype=torch.long),
71
  "mask": torch.tensor(mask, dtype=torch.long),
72
- "targets": torch.tensor(
73
- self.data[self.category_col][index], dtype=torch.long
74
- ),
75
  }
76
 
77
  def __len__(self):
78
  return self.len
79
-
80
  class BERTClass(torch.nn.Module):
81
  def __init__(self, num_class):
82
  super(BERTClass, self).__init__()
@@ -97,7 +95,7 @@ class BERTClass(torch.nn.Module):
97
  output = self.classifier(pooler)
98
  return output
99
 
100
- def do_predict(tokenizer):
101
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
102
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
103
  test_loader = DataLoader(test_set, **test_params)
@@ -116,12 +114,12 @@ def do_predict(tokenizer):
116
  actual_predictions = [i[0] for i in preds.tolist()]
117
  return actual_predictions
118
 
119
- model_sus = BERTClass(2)
120
  model_sustain.to(device)
121
  model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
122
 
123
  tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
124
- actual_predictions_sus = do_predict(tokenizer_sus)
125
 
126
  test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
127
  ```
 
43
  This is a subclass of the torch package's Dataset class. It processes input to create ids, masks and targets required for model training.
44
  """
45
 
46
+ def __init__(self, dataframe, tokenizer, max_len, text_col_name):
47
  self.len = len(dataframe)
48
  self.data = dataframe
49
  self.tokenizer = tokenizer
50
  self.max_len = max_len
51
  self.text_col_name = text_col_name
52
+
53
 
54
  def __getitem__(self, index):
55
  title = str(self.data[self.text_col_name][index])
 
69
  return {
70
  "ids": torch.tensor(ids, dtype=torch.long),
71
  "mask": torch.tensor(mask, dtype=torch.long),
72
+
 
 
73
  }
74
 
75
  def __len__(self):
76
  return self.len
77
+
78
  class BERTClass(torch.nn.Module):
79
  def __init__(self, num_class):
80
  super(BERTClass, self).__init__()
 
95
  output = self.classifier(pooler)
96
  return output
97
 
98
+ def do_predict(model, tokenizer):
99
  test_set = Triage(test_df, tokenizer, MAX_LEN, text_col_name)
100
  test_params = {'batch_size' : BATCH_SIZE, 'shuffle': False, 'num_workers':0}
101
  test_loader = DataLoader(test_set, **test_params)
 
114
  actual_predictions = [i[0] for i in preds.tolist()]
115
  return actual_predictions
116
 
117
+ model_sustain = BERTClass(2)
118
  model_sustain.to(device)
119
  model_sustain.load_state_dict(torch.load('pytorch_model.bin', map_location=device)['model_state_dict'])
120
 
121
  tokenizer_sus = BertTokenizer.from_pretrained('roberta-base')
122
+ actual_predictions_sus = do_predict(model_sustain, tokenizer_sus)
123
 
124
  test_df['sustainability'] = ['sustainable' if i==0 else 'unsustainable' for i in actual_predictions_read]
125
  ```