WwYc commited on
Commit
08d7fd8
1 Parent(s): 9825a89

Update DETR/modules/ExplanationGenerator.py

Browse files
DETR/modules/ExplanationGenerator.py CHANGED
@@ -157,7 +157,7 @@ class Generator:
157
  one_hot[0, target_index, index] = 1
158
  one_hot_vector = one_hot
159
  one_hot.requires_grad_(True)
160
- one_hot = torch.sum(one_hot.cuda() * outputs)
161
 
162
  self.model.zero_grad()
163
  one_hot.backward(retain_graph=True)
 
157
  one_hot[0, target_index, index] = 1
158
  one_hot_vector = one_hot
159
  one_hot.requires_grad_(True)
160
+ one_hot = torch.sum(one_hot * outputs)
161
 
162
  self.model.zero_grad()
163
  one_hot.backward(retain_graph=True)