Spaces:
Runtime error
Runtime error
Update BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py
Browse files
BERT/BERT_explainability/modules/BERT/ExplanationGenerator.py
CHANGED
@@ -37,7 +37,7 @@ class Generator:
|
|
37 |
one_hot[0, index] = 1
|
38 |
one_hot_vector = one_hot
|
39 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
40 |
-
one_hot = torch.sum(one_hot
|
41 |
|
42 |
self.model.zero_grad()
|
43 |
one_hot.backward(retain_graph=True)
|
@@ -70,7 +70,7 @@ class Generator:
|
|
70 |
one_hot[0, index] = 1
|
71 |
one_hot_vector = one_hot
|
72 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
73 |
-
one_hot = torch.sum(one_hot
|
74 |
|
75 |
self.model.zero_grad()
|
76 |
one_hot.backward(retain_graph=True)
|
@@ -94,7 +94,7 @@ class Generator:
|
|
94 |
one_hot[0, index] = 1
|
95 |
one_hot_vector = one_hot
|
96 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
97 |
-
one_hot = torch.sum(one_hot
|
98 |
|
99 |
self.model.zero_grad()
|
100 |
one_hot.backward(retain_graph=True)
|
@@ -136,7 +136,7 @@ class Generator:
|
|
136 |
one_hot[0, index] = 1
|
137 |
one_hot_vector = one_hot
|
138 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
139 |
-
one_hot = torch.sum(one_hot
|
140 |
|
141 |
self.model.zero_grad()
|
142 |
one_hot.backward(retain_graph=True)
|
|
|
37 |
one_hot[0, index] = 1
|
38 |
one_hot_vector = one_hot
|
39 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
40 |
+
one_hot = torch.sum(one_hot * output)
|
41 |
|
42 |
self.model.zero_grad()
|
43 |
one_hot.backward(retain_graph=True)
|
|
|
70 |
one_hot[0, index] = 1
|
71 |
one_hot_vector = one_hot
|
72 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
73 |
+
one_hot = torch.sum(one_hot * output)
|
74 |
|
75 |
self.model.zero_grad()
|
76 |
one_hot.backward(retain_graph=True)
|
|
|
94 |
one_hot[0, index] = 1
|
95 |
one_hot_vector = one_hot
|
96 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
97 |
+
one_hot = torch.sum(one_hot * output)
|
98 |
|
99 |
self.model.zero_grad()
|
100 |
one_hot.backward(retain_graph=True)
|
|
|
136 |
one_hot[0, index] = 1
|
137 |
one_hot_vector = one_hot
|
138 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
139 |
+
one_hot = torch.sum(one_hot * output)
|
140 |
|
141 |
self.model.zero_grad()
|
142 |
one_hot.backward(retain_graph=True)
|