WwYc commited on
Commit
bb871c3
1 Parent(s): d03353e

Update BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py

Browse files
BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py CHANGED
@@ -37,7 +37,7 @@ class Generator:
37
  one_hot[0, index] = 1
38
  one_hot_vector = one_hot
39
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
40
- one_hot = torch.sum(one_hot.cuda() * output)
41
 
42
  self.model.zero_grad()
43
  one_hot.backward(retain_graph=True)
@@ -70,7 +70,7 @@ class Generator:
70
  one_hot[0, index] = 1
71
  one_hot_vector = one_hot
72
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
73
- one_hot = torch.sum(one_hot.cuda() * output)
74
 
75
  self.model.zero_grad()
76
  one_hot.backward(retain_graph=True)
@@ -94,7 +94,7 @@ class Generator:
94
  one_hot[0, index] = 1
95
  one_hot_vector = one_hot
96
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
97
- one_hot = torch.sum(one_hot.cuda() * output)
98
 
99
  self.model.zero_grad()
100
  one_hot.backward(retain_graph=True)
@@ -136,7 +136,7 @@ class Generator:
136
  one_hot[0, index] = 1
137
  one_hot_vector = one_hot
138
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
139
- one_hot = torch.sum(one_hot.cuda() * output)
140
 
141
  self.model.zero_grad()
142
  one_hot.backward(retain_graph=True)
 
37
  one_hot[0, index] = 1
38
  one_hot_vector = one_hot
39
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
40
+ one_hot = torch.sum(one_hot * output)
41
 
42
  self.model.zero_grad()
43
  one_hot.backward(retain_graph=True)
 
70
  one_hot[0, index] = 1
71
  one_hot_vector = one_hot
72
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
73
+ one_hot = torch.sum(one_hot * output)
74
 
75
  self.model.zero_grad()
76
  one_hot.backward(retain_graph=True)
 
94
  one_hot[0, index] = 1
95
  one_hot_vector = one_hot
96
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
97
+ one_hot = torch.sum(one_hot * output)
98
 
99
  self.model.zero_grad()
100
  one_hot.backward(retain_graph=True)
 
136
  one_hot[0, index] = 1
137
  one_hot_vector = one_hot
138
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
139
+ one_hot = torch.sum(one_hot * output)
140
 
141
  self.model.zero_grad()
142
  one_hot.backward(retain_graph=True)