Spaces:
Runtime error
Runtime error
Update helpers/listeners.py
Browse files- helpers/listeners.py +18 -4
helpers/listeners.py
CHANGED
@@ -182,8 +182,12 @@ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
|
|
182 |
|
183 |
elif (channel is None and nodeX is None and nodeY is None):
|
184 |
gr.Info("Convolutional Layer Specific")
|
185 |
-
|
186 |
-
|
|
|
|
|
|
|
|
|
187 |
|
188 |
# Unknown
|
189 |
else:
|
@@ -196,11 +200,21 @@ def generate(lr, epochs, img_size, channel, nodeX, nodeY, node, layer_sel,
|
|
196 |
obj = objs.channel(layer_sel[0], node)
|
197 |
else:
|
198 |
gr.Info("Linear Layer Specific")
|
199 |
-
|
|
|
|
|
|
|
|
|
|
|
200 |
case _:
|
201 |
gr.Info("Attempting unknown Layer Specific")
|
202 |
transforms = [] # Just in case
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
204 |
|
205 |
thresholds = h_manip.expo_tuple(epochs, 6)
|
206 |
|
|
|
182 |
|
183 |
elif (channel is None and nodeX is None and nodeY is None):
|
184 |
gr.Info("Convolutional Layer Specific")
|
185 |
+
if torch.cuda.is_available():
|
186 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
|
187 |
+
torch.tensor(2).cuda())).cuda()
|
188 |
+
else:
|
189 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
|
190 |
+
torch.tensor(2)))
|
191 |
|
192 |
# Unknown
|
193 |
else:
|
|
|
200 |
obj = objs.channel(layer_sel[0], node)
|
201 |
else:
|
202 |
gr.Info("Linear Layer Specific")
|
203 |
+
if torch.cuda.is_available():
|
204 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
|
205 |
+
torch.tensor(2).cuda())).cuda()
|
206 |
+
else:
|
207 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
|
208 |
+
torch.tensor(2)))
|
209 |
case _:
|
210 |
gr.Info("Attempting unknown Layer Specific")
|
211 |
transforms = [] # Just in case
|
212 |
+
if torch.cuda.is_available():
|
213 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]).cuda(),
|
214 |
+
torch.tensor(2).cuda())).cuda()
|
215 |
+
else:
|
216 |
+
obj = lambda m: torch.mean(torch.pow(-m(layer_sel[0]),
|
217 |
+
torch.tensor(2)))
|
218 |
|
219 |
thresholds = h_manip.expo_tuple(epochs, 6)
|
220 |
|