"""
@author: Caglar Aytekin
contact: caglar@deepcause.ai 
"""
import torch
import torch.nn as nn
import random 
import numpy as np 
import pandas as pd
import copy
class CustomEncodingFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, tau, alpha):
        ctx.save_for_backward(x, tau)
        # Soft encoding: tanh of the input shifted by the threshold tau
        y = torch.tanh(x + tau)
        # Forward output blends the binarized signal with the soft one:
        # alpha=1 gives hard (-1, 1) outputs, alpha=0 the plain tanh
        forward_output = alpha * (2 * torch.round((y + 1) / 2) - 1) + (1 - alpha) * y
        return forward_output

    @staticmethod
    def backward(ctx, grad_output):
        x, tau = ctx.saved_tensors
        # Straight-through estimator: ignore the rounding and use the
        # derivative of tanh for the backward pass, 1 - tanh^2(x + tau)
        grad_input = grad_output * (1 - torch.tanh(x + tau) ** 2)
        # x and tau receive the same gradient; alpha gets none
        return grad_input, grad_input, None

# Wrapping the custom function in an nn.Module for easier use
class EncodingLayer(nn.Module):
    def __init__(self):
        super(EncodingLayer, self).__init__()

    def forward(self, x, tau, alpha):
        return CustomEncodingFunction.apply(x, tau, alpha)
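
# Minimal sanity check for the encoding (illustrative only; shapes are
# arbitrary). With alpha=1.0 the outputs are exactly in {-1, +1}, yet
# gradients still flow to both x and tau through the tanh surrogate:
#   enc = EncodingLayer()
#   x = torch.randn(4, 3, requires_grad=True)
#   tau = torch.zeros(4, 3, requires_grad=True)
#   out = enc(x, tau, 1.0)          # values in {-1.0, +1.0}
#   out.sum().backward()            # x.grad equals tau.grad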
    
class LEURN(nn.Module):
    def __init__(self, preprocessor,depth,droprate):
        """
        Initializes the model.
        
        Parameters:
        - preprocessor: A class containing useful info about the dataset 
            - Including: attribute names, categorical features details, suggested embedding size for each category, output type, output dimension, transformation information
        - depth: Depth of the network
        - droprate: dropout rate
        """
        super(LEURN, self).__init__()
        
        self.alpha = 1.0
        self.preprocessor = preprocessor
        self.attribute_names = preprocessor.attribute_names
        self.label_encoders = preprocessor.encoders_for_nn
        # Find categorical indices and the number of categories for each
        self.categorical_indices = [info[0] for info in preprocessor.category_details]
        self.num_categories = [info[1] for info in preprocessor.category_details]

        # If suggested_embeddings is an integer, use it for every categorical feature
        if isinstance(preprocessor.suggested_embeddings, int):
            embedding_sizes = [preprocessor.suggested_embeddings] * len(self.categorical_indices)
        else:
            assert len(preprocessor.suggested_embeddings) == len(self.categorical_indices), "Length of embedding_size must match number of categorical features"
            embedding_sizes = preprocessor.suggested_embeddings
        
        self.embedding_sizes=embedding_sizes
        
        #Embedding layers for categorical features
        self.embeddings = nn.ModuleList([
            nn.Embedding(num_categories, embedding_dim) 
            for num_categories, embedding_dim in zip(self.num_categories, embedding_sizes)
        ])
        
        for embedding_now in self.embeddings:
            nn.init.uniform_(embedding_now.weight, -1.0, 1.0)
        
        self.total_embedding_size = sum(embedding_sizes)  # total embedded width of the categorical features
        self.non_cat_input_dim = len(self.attribute_names) - len(self.categorical_indices)  # number of numerical features
        self.nn_input_dim = self.total_embedding_size + self.non_cat_input_dim  # width of the NN input
        

        #LAYERS
        
        self.tau_initial = nn.Parameter(torch.zeros(1,self.nn_input_dim))  # Initial tau as a learnable parameter
        self.layers = nn.ModuleList()
        self.depth = depth
        self.output_type=preprocessor.output_type
        
        for d_now in range(depth):
            # Each iteration adds an encoding layer followed by a dropout and then a linear layer
            self.layers.append(EncodingLayer())
            self.layers.append(nn.Dropout1d(droprate))
            linear_layer = nn.Linear((d_now + 1) * self.nn_input_dim, self.nn_input_dim)
            self._init_weights(linear_layer,d_now+1) #special layer initialization
            self.layers.append(linear_layer)
        
        
        # Final stage: dropout and linear layer
        self.final_dropout=nn.Dropout1d(droprate)
        self.final_linear = nn.Linear(depth * self.nn_input_dim, self.preprocessor.output_dim)
        self._init_weights(self.final_linear, depth)
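
    # Illustrative construction (hedged sketch: `preprocessor` stands in for
    # whatever dataset-preprocessing object accompanies this model):
    #   model = LEURN(preprocessor, depth=4, droprate=0.1)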

    def set_alpha(self, alpha):
        """Updates the blend between soft (tanh) and hard (binary) encodings;
        e.g. alpha can be annealed from 0 toward 1 over training."""
        self.alpha = alpha

    def _init_weights(self, layer, depth_now):
        # Custom initialization.
        # Inputs are (approximately) binary in (-1, 1), so keeping a row of
        # dim weights in the (-1/dim, 1/dim) range bounds the output:
        # |w . x| <= sum_i |w_i| <= dim * (1/dim) = 1.
        # Since the input is roughly in the (-1, 1) range, this is a good
        # initialization for the layers that produce tau.
        
        if self.embedding_sizes:
            init_tensor = torch.tensor([1/size for size in self.embedding_sizes for _ in range(size)])
            if init_tensor.shape[0]<self.nn_input_dim: #Means we have numericals too
                init_tensor=torch.cat((init_tensor, torch.ones(self.non_cat_input_dim)), dim=0)
        else:
            init_tensor = torch.ones(self.non_cat_input_dim)
            
        init_tensor=init_tensor/((depth_now+1)*torch.tensor(len(self.attribute_names)))
        init_tensor=init_tensor.unsqueeze(0).repeat_interleave(repeats=layer.weight.shape[0],dim=0).repeat_interleave(repeats=depth_now,dim=1)
        layer.weight.data.uniform_(-1, 1)
        layer.weight=torch.nn.Parameter(layer.weight*init_tensor)

    
    def forward(self, x):
        # Forward pass for preprocessed input: embeds categoricals, combines
        # them with the (already normalized) numericals, and feeds the network.

        # Separate categorical and numerical features for easier handling
        cat_features = [x[:, i].long() for i in self.categorical_indices]
        non_cat_features = [x[:, i] for i in range(x.size(1)) if i not in self.categorical_indices]
        non_cat_features = torch.stack(non_cat_features, dim=1) if non_cat_features else x.new_empty(x.size(0), 0)
        
        # Embed categoricals
        embedded_features = [embedding(cat_feature) for embedding, cat_feature in zip(self.embeddings, cat_features)]
        # Combine categoricals and numericals
        if embedded_features:
            embedded_features = torch.cat(embedded_features, dim=1)
            nninput = torch.cat([embedded_features, non_cat_features], dim=1)
        else:
            # No categorical features: the network input is numerical only
            nninput = non_cat_features
        
        self.nninput=nninput
        
        # Forward pass neural network
        output=self.forward_from_embeddings(self.nninput)
        self.output=output
        return output

    def forward_from_embeddings(self,x):
        # Forward function for normalized numericals and embedded categoricals
        tau=self.tau_initial 
        tau=torch.repeat_interleave(tau,x.shape[0],0)  #tau is 1xF, cast it for batch
        # For each depth
        for i in range(0, self.depth * 3, 3):
            # encode, drop and find next tau
            encoding_layer = self.layers[i]
            dropout_layer = self.layers[i + 1]
            linear_layer = self.layers[i + 2]
            #encode and drop
            encoded_x =dropout_layer( encoding_layer(x, tau,self.alpha))
            #save encodings and thresholds
            #notice that threshold is -tau, not tau since we binarize x+tau
            if i==0:
                encodings=encoded_x
                taus=-tau
            else:
                encodings=torch.cat((encodings,encoded_x),dim=-1)
                taus=torch.cat((taus,-tau),dim=-1)
            #find next thresholds
            tau = linear_layer(encodings) #not used, redundant for last layer
        
        self.encodings=encodings
        self.taus=taus
        #Final layer: drop and linear
        output=self.final_linear(self.final_dropout(encodings))

        return output
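
    # After a forward pass, self.encodings has shape
    # [batch, depth * nn_input_dim] and self.taus stores the matching
    # per-depth thresholds (note the sign: the threshold on x is -tau,
    # since the layer binarizes x + tau).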
    
    
    def find_boundaries(self, x):
        """
        Given an input sample, finds the interval boundaries for numerical
        features and (via the boundaries on the embeddings) the valid
        categories for categorical features.
        Expects the same preprocessed format as forward().
        """
        # Ensure x has shape [1, input_dim]
        if x.ndim == 1:
            x = x.unsqueeze(0)  # Add batch dimension if not present

        # Perform a forward pass to update self.encodings and self.taus
        self(x)
        
        # self.taus has the shape [1, depth * input_dim]
        # reshape to [depth, input_dim] for easier boundary finding
        taus_reshaped = self.taus.view(self.depth, self.nn_input_dim) 
        
        # embedded and normalized input
        embedded_x=self.nninput
        
        # Initialize boundaries - numericals are in (-1,1) range and categoricals are from embeddings.
        # So -100,100 is safe min and max. -inf,+inf is not chosen since problematic for later sampling
        upper_boundaries = torch.full((embedded_x.size(1),), 100.0)
        lower_boundaries = torch.full((embedded_x.size(1),), -100.0)
        
        # Compare each threshold in self.taus with the corresponding input value
        for feature_index in range(self.nn_input_dim):
            for depth_index in range(self.depth):
                threshold = taus_reshaped[depth_index, feature_index]
                input_value = embedded_x[0, feature_index]
                
                # If the threshold is greater than the input value and less than the current upper boundary, update the upper boundary
                if threshold > input_value and threshold < upper_boundaries[feature_index]:
                    upper_boundaries[feature_index] = threshold
                
                # If the threshold is less than the input value and greater than the current lower boundary, update the lower boundary
                if threshold < input_value and threshold > lower_boundaries[feature_index]:
                    lower_boundaries[feature_index] = threshold
        
        # Convert boundaries to a list of tuples [(lower, upper), ...] for each feature
        boundaries = list(zip(lower_boundaries.tolist(), upper_boundaries.tolist()))
        
        
        self.upper_boundaries=upper_boundaries
        self.lower_boundaries=lower_boundaries
        

        return boundaries
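
    # Illustrative use (x is a preprocessed sample tensor, as in forward()):
    #   bounds = model.find_boundaries(x)  # [(lower, upper), ...] per NN input
    # Categorical features get boundaries in embedding space; use
    # categories_within_boundaries() to translate them into category indices.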
    
    def categories_within_boundaries(self):
        """
        For each categorical feature, checks if embedding weights fall within the specified upper and lower boundaries.
        Returns a dictionary with categorical feature indices as keys and lists of category indices that fall within the boundaries.
        """
        categories_within_bounds = {}
        emb_st=0
        for cat_index, emb_layer in zip(range(len(self.categorical_indices)), self.embeddings):
            # Extract upper and lower boundaries for this categorical feature
            lower_bound=self.lower_boundaries[emb_st:emb_st+self.embedding_sizes[cat_index]]
            upper_bound=self.upper_boundaries[emb_st:emb_st+self.embedding_sizes[cat_index]]
            emb_st=emb_st+self.embedding_sizes[cat_index]
            # Initialize list to hold categories that fall within boundaries
            categories_within = []
    
            # Iterate over each embedding vector in the layer
            for i, weight in enumerate(emb_layer.weight):
                # Check if the embedding weight falls within the boundaries
                if torch.all(weight >= lower_bound) and torch.all(weight <= upper_bound):
                    categories_within.append(i)  # Using index i as category identifier
            
            # Store the categories that fall within the boundaries for this feature
            categories_within_bounds[cat_index] = categories_within
    
        return categories_within_bounds
   
    def global_importance(self):
        final_layer_weight=torch.clone(self.final_linear.weight).detach().numpy()
        importances=np.sum(np.abs(final_layer_weight),0)
        importances=importances.reshape(importances.shape[0]//self.nn_input_dim,self.nn_input_dim)
        importances=np.sum(importances,0)
        importances_features=[]
        st=0
        for i in range(len(self.attribute_names)):
            if i < len(self.embedding_sizes):
                # Categorical: sum the importance over its embedding dimensions
                importances_features.append(np.sum(importances[st:st + self.embedding_sizes[i]]))
                st = st + self.embedding_sizes[i]
            else:
                # Numerical: single column
                importances_features.append(importances[st])
                st = st + 1
        return np.argsort(importances_features)[::-1],np.sort(importances_features)[::-1]
        
    def influence_matrix(self):
        """
        Builds a feature-by-feature influence matrix from how each feature
        affects the others' thresholds via the linear (tau-producing) layers.
        """

        def create_block_sum_matrix(sizes, matrix):
            L = len(sizes)
            # Initialize the output matrix with zeros, using PyTorch
            block_sum_matrix = torch.zeros((L, L))
            
            # Define the starting row and column indices for slicing
            start_row = 0
            for i, row_size in enumerate(sizes):
                start_col = 0
                for j, col_size in enumerate(sizes):
                    # Calculate the sum of the current block using PyTorch
                    block_sum = torch.sum(matrix[start_row:start_row+row_size, start_col:start_col+col_size])
                    block_sum_matrix[i, j] = block_sum
                    # Update the starting column index for the next block in the row
                    start_col += col_size
                # Update the starting row index for the next block in the column
                start_row += row_size
            
            return block_sum_matrix
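
        # Example: create_block_sum_matrix([2, 1], M) collapses a 3x3 matrix M
        # into a 2x2 matrix of block sums, e.g. entry (0, 0) is the sum of the
        # top-left 2x2 sub-block of M.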

        def add_ones_until_target(initial_list, target_sum):
            # Continue adding 1s until the sum of the list equals the target sum
            while sum(initial_list) < target_sum:
                initial_list.append(1)
            return initial_list

        for i in range(0, self.depth * 3, 3):
            # Read the weights of the linear (tau-producing) layer at this depth
            weight_now = self.layers[i + 2].weight
            weight_now_reshaped = weight_now.reshape((weight_now.shape[0], weight_now.shape[1] // self.nn_input_dim, self.nn_input_dim))  # shape: output x depth x input
            if i == 0:
                effects = torch.sum(torch.abs(weight_now_reshaped), dim=1) / self.depth
            else:
                effects = effects + torch.sum(torch.abs(weight_now_reshaped), dim=1) / self.depth
            
        effects=effects.t()  #shape: input x output
        
        modified_list = add_ones_until_target(copy.deepcopy(self.embedding_sizes), effects.shape[0])
        
        
        effects=create_block_sum_matrix(modified_list,effects)
        
        return effects
            
    
    def explain_without_causal_effects(self, x):
        """
        Explains the network's decision for an input sample.
        For numericals, extracts the upper and lower boundaries around the sample;
        for categoricals, lists the possible categories.
        Also calculates the contribution of each feature to the final result.
        """
        self.find_boundaries(x) #find upper, lower boundaries for all nn inputs
        
        #find valid categories for categorical features
        valid_categories=self.categories_within_boundaries()
        
        #numerical boundaries
        upper_numerical=self.upper_boundaries[sum(self.embedding_sizes):].detach().numpy()
        lower_numerical=self.lower_boundaries[sum(self.embedding_sizes):].detach().numpy()
        
        # Find the contribution of each feature in the final linear layer, distributing the bias evenly
        contributions = self.encodings * self.final_linear.weight + self.final_linear.bias.unsqueeze(dim=-1) / self.final_linear.weight.shape[1]
        contributions = contributions.detach().reshape((contributions.shape[0], contributions.shape[1] // self.nn_input_dim, self.nn_input_dim))
        contributions = torch.sum(contributions, dim=1)
        
        # Initialize an empty list to store the summed contributions
        summed_contributions = []
        
        # Initialize start index for slicing
        start_idx = 0
        
        #Sum contribution of each categorical within respective embedding
        for size in self.embedding_sizes:
            # Calculate end index for the current chunk
            end_idx = start_idx + size
            
            # Sum the contributions in the current chunk
            chunk_sum = contributions[:, start_idx:end_idx].sum(dim=1, keepdim=True)
            
            # Append the summed chunk to the list
            summed_contributions.append(chunk_sum)
            
            # Update the start index for the next chunk
            start_idx = end_idx
        
        # If there are remaining elements not covered by embedding_sizes, add them as is (numerical features)
        if start_idx < contributions.shape[1]:
            remaining = contributions[:, start_idx:]
            summed_contributions.append(remaining)
        
        # Concatenate the summed contributions back into a tensor
        summed_contributions = torch.cat(summed_contributions, dim=1)
        # This is to handle multi-class explanations, for binary this is 0 automatically
        # Note: multi-output regression is not supported yet. This will just return largest regressed value's explanations
        highest_index=torch.argmax(summed_contributions.sum(dim=1))
        # This is contribution from each feature
        result=summed_contributions[highest_index]
        self.result=result

        #Explanation and Contribution formats are in ordered format (categoricals first, numericals later)
        #Bring them to original format in user input
        #Combine categoricals and numericals explanations and contributions
        Explanation = [None] * (len(self.categorical_indices) + len(upper_numerical))
        Contribution = np.zeros((len(self.categorical_indices) + len(upper_numerical),))   
        
        # Fill in the categorical samples
        for j, cat_index in enumerate(self.categorical_indices):
            Explanation[cat_index] = valid_categories[j]
            Contribution[cat_index] = result[j].numpy()
        
        
        #INVERSE TRANSFORM PART 1-------------------------------------------------------------------------------------------
        #Inverse transform upper and lower_numericals
        len_num=len(upper_numerical)
        if len_num>0:
            upper_numerical=self.preprocessor.scaler.inverse_transform(upper_numerical.reshape(1,-1))
            lower_numerical=self.preprocessor.scaler.inverse_transform(lower_numerical.reshape(1,-1))
            if len_num>1:
                upper_numerical=np.squeeze(upper_numerical)
                lower_numerical=np.squeeze(lower_numerical)
            upper_iter = iter(upper_numerical)
            lower_iter = iter(lower_numerical)
        
        
        cnt = 0
        for i in range(len(Explanation)):
            if Explanation[i] is None:
                # Numerical feature: fill in its (denormalized) interval
                Explanation[i] = next(lower_iter), next(upper_iter)
                if len(self.categorical_indices) > 0:
                    # j still holds the index of the last categorical contribution
                    Contribution[i] = result[j + cnt + 1].numpy()
                else:
                    Contribution[i] = result[cnt].numpy()
                cnt = cnt + 1

        attribute_names_list = []
        revised_explanations_list = []
        contributions_list = []
        # Process each feature to fill lists

        for idx, attr_name in enumerate(self.attribute_names):
            if isinstance(Explanation[idx], list):  # Categorical
                #INVERSE TRANSFORM PART 2-------------------------------------------------------------------------------------------
                #Inverse transform categoricals
                category_names = [key for key, value in self.label_encoders[attr_name].items() if value in Explanation[idx]]
                revised_explanation = " ,OR, ".join(category_names)
            elif isinstance(Explanation[idx], tuple):  # Numerical
                revised_explanation = f"{Explanation[idx][0].item()} to {Explanation[idx][1].item()}"
            else:
                revised_explanation = "Unknown"  #shouldn't really happen

            # Append to lists
            attribute_names_list.append(attr_name)
            revised_explanations_list.append(revised_explanation)
            contributions_list.append(Contribution[idx] if idx < len(Contribution) else None)
            
            

        # Construct DataFrame
        Explanation_df = pd.DataFrame({
            'Name': attribute_names_list,
            'Category': revised_explanations_list,
            'Contribution': contributions_list
        })
        
        

        
        result = self.preprocessor.inverse_transform_y(self.output)

        return copy.deepcopy(Explanation_df),self.output.clone(),copy.deepcopy(result),copy.deepcopy(Explanation)
    
    def explain(self, x, include_causal_analysis=False):
        """
        Explains the decision for x. If include_causal_analysis is set,
        additionally estimates a causal effect per feature: fixes all features
        but one, sweeps that feature across its other categories (or across
        the unexplored numerical intervals), and reports the average output
        change relative to the current value.
        """
        
        def update_intervals(available_intervals, incoming_interval):
            updated_intervals = []
            for interval in available_intervals:
                if incoming_interval[1] <= interval[0] or incoming_interval[0] >= interval[1]:
                    # The incoming interval does not overlap, keep the interval as is
                    updated_intervals.append(interval)
                else:
                    # There is some overlap, possibly split the interval
                    if incoming_interval[0] > interval[0]:
                        # Add the left part that doesn't overlap
                        updated_intervals.append((interval[0], incoming_interval[0]))
                    if incoming_interval[1] < interval[1]:
                        # Add the right part that doesn't overlap
                        updated_intervals.append((incoming_interval[1], interval[1]))
            return updated_intervals
        
        def sample_from_intervals(available_intervals):
            if not available_intervals:
                return None
            # Choose a random interval
            chosen_interval = random.choice(available_intervals)
            # Sample a random point within this interval
            return random.uniform(chosen_interval[0], chosen_interval[1])
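
        # Example: removing the covered interval (0.2, 0.5) from [(-1.0, 1.0)]
        # leaves [(-1.0, 0.2), (0.5, 1.0)]; sampling then draws uniformly from
        # a randomly chosen remaining interval.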
        
    
                
        
        Explanation_df,output,result,Explanation=self.explain_without_causal_effects(x)
        if include_causal_analysis:
            # Causal analysis
            causal_effect=np.zeros((x.shape[-1],))
            numerical_cnt=0
            for idx, attr_name in enumerate(self.attribute_names):
                if isinstance(Explanation[idx], list):  # Categorical
                    all_category_names = [value for key, value in self.label_encoders[attr_name].items()]
                    sweeped_category_names = [value for key, value in self.label_encoders[attr_name].items() if value in Explanation[idx]]
                    
                    # Sweep until every category has been visited
                    is_category_empty = not (set(all_category_names) - set(sweeped_category_names))

                    cnt = 0
                    while not is_category_empty:
                        new_x=x.clone()
                        next_category=list(set(all_category_names)-set(sweeped_category_names))[0]
                        new_x[0,idx]=float(next_category)
                        Explanation_df_new,output_new,result_new,Explanation_new=self.explain_without_causal_effects(new_x)
                        sweeped_category_names = sweeped_category_names+[value for key, value in self.label_encoders[attr_name].items() if value in Explanation_new[idx]]
    
                        is_category_empty = not (set(all_category_names) - set(sweeped_category_names))
    
                        causal_effect[idx]=causal_effect[idx]+(output-output_new).detach().numpy()[0,0]
                        cnt=cnt+1
                    if cnt>0:
                        causal_effect[idx]=causal_effect[idx]/cnt
                    
                else:
    
                    search_complete = False
                    # Initial available interval: (-100, 100), matching the safe
                    # min/max used when initializing the boundaries
                    available_intervals = [(-100, 100)]

                    # Remove the interval already covered by the current sample's boundaries
                    self.explain_without_causal_effects(x)
                    upper_numerical = self.upper_boundaries[sum(self.embedding_sizes):].detach().numpy()
                    lower_numerical = self.lower_boundaries[sum(self.embedding_sizes):].detach().numpy()
                    incoming_interval = (lower_numerical[numerical_cnt], upper_numerical[numerical_cnt])
                    available_intervals = update_intervals(available_intervals, incoming_interval)
                    cnt=0
                    while not(search_complete):
                        new_sample=sample_from_intervals(available_intervals)
                        new_x=x.clone()
                        new_x[0,idx]=new_sample
                        Explanation_df_new,output_new,result_new,Explanation_new=self.explain_without_causal_effects(new_x)
                        causal_effect[idx]=causal_effect[idx]+(output-output_new).detach().numpy()[0,0]
                        cnt=cnt+1
                        upper_numerical=self.upper_boundaries[sum(self.embedding_sizes):].detach().numpy()
                        lower_numerical=self.lower_boundaries[sum(self.embedding_sizes):].detach().numpy()    
                        incoming_interval = (lower_numerical[numerical_cnt],upper_numerical[numerical_cnt])
                        available_intervals = update_intervals(available_intervals, incoming_interval)
                        if available_intervals == []:
                            search_complete=True
                    if cnt>0:
                        causal_effect[idx]=causal_effect[idx]/cnt
                    numerical_cnt=numerical_cnt+1
    
    
    
            Explanation_df['Causal Effects'] = causal_effect
        return Explanation_df,output,result
    

        
    def sample_from_boundaries(self):
        """
        Assumes the upper and lower boundaries have already been extracted
        (e.g. self.explain has been run on an input).
        Samples a value for each feature within the boundaries stored in the
        class instance: a uniform value for numericals, and one of the valid
        categories for categoricals.
        Returns:
        - A tensor containing sampled values within the given boundaries for each feature.
        """
        # First sample the categoricals. Each category list is non-empty in
        # practice, since the input used to extract the boundaries satisfies them.
        categories_within_bounds = self.categories_within_boundaries()
        sampled_indices = [random.choice(categories) for categories in categories_within_bounds.values()]
        
        # Then the numericals
        samples = []
        for lower, upper in zip(self.lower_boundaries[sum(self.embedding_sizes):], self.upper_boundaries[sum(self.embedding_sizes):]):
            # Sample uniformly between the boundaries, clamped to the (-1, 1)
            # range that normalized numericals live in
            upper = torch.clamp(upper.detach(), max=1.0)
            lower = torch.clamp(lower.detach(), min=-1.0)
            sample = lower + (upper - lower) * torch.rand(1)
            samples.append(sample)
        

        #Combine categoricals and numericals
        # Initialize an empty list to hold the combined samples
        combined_samples = [None] * (len(self.categorical_indices) + len(samples))
        
        # Fill in the categorical samples
        for i, cat_index in enumerate(self.categorical_indices):
            combined_samples[cat_index] = torch.tensor([sampled_indices[i]], dtype=torch.float)
        
        # Fill in the numerical samples
        num_samples_iter = iter(samples)
        for i in range(len(combined_samples)):
            if combined_samples[i] is None:
                combined_samples[i] = next(num_samples_iter)
        
        # Combine into a single tensor
        combined_tensor = torch.cat(combined_samples, dim=-1)
        return combined_tensor.unsqueeze(dim=0)
    
    
    def generate(self):
        """
        Generates a data sample from learned network
        """
        def sample_with_tau(tau,max_bound,min_bound):
            # Sample according to tau, lower and upper bounds
            sampled=torch.zeros((self.nn_input_dim))
            st=0
            # Randomly pick from valid categories
            for embedding in self.embeddings:
                categories_within = []
                
                # Iterate over each embedding vector in the layer
                for i, weight in enumerate(embedding.weight):
                    # Check if the embedding weight falls within the boundaries
                    if torch.all(weight >= min_bound[st:st+embedding.weight.shape[-1]]) and torch.all(weight <= max_bound[st:st+embedding.weight.shape[-1]]):
                        categories_within.append(i)  # Using index i as category identifier
                feature_now=embedding.weight[np.random.choice(categories_within),:]
                cnt=0
                for j in range(st,st+embedding.weight.shape[-1]):
                    if feature_now[cnt]>-tau[0,j]:
                        sampled[j]=1.0
                    elif feature_now[cnt]<=-tau[0,j]:
                        sampled[j]=-1.0
                    cnt=cnt+1
                st=st+embedding.weight.shape[-1]
                        
            #Randomly sample for numericals
            for i in range(st,self.nn_input_dim):
                if -tau[0,i]>max_bound[i]: #In this case you have to pick -1
                    sampled[i]=-1.0
                elif -tau[0,i]<=min_bound[i]: #In this case you have to pick 1
                    sampled[i]=1.0
                else:
                    sampled[i] = (torch.randint(low=0, high=2, size=(1,)) * 2 - 1).float()
            return sampled
        
        def bound_update(tau,max_bound,min_bound,sampled):
            for i in range(self.nn_input_dim):
                if sampled[i]>0: #means input is larger than -tau, so -tau might set a lower bound
                    if -tau[0,i]>min_bound[i]:
                        min_bound[i]=-tau[0,i]
                elif sampled[i]<=0: #means input is smaller than -tau, so -tau might set an upper bound
                    if -tau[0,i]<max_bound[i]:
                        max_bound[i]=-tau[0,i]
            return max_bound,min_bound
        
        # Read first tau
        tau=self.tau_initial

        # Set initial maximum and minimum bounds
        max_bound=torch.zeros((self.nn_input_dim))+100.0
        min_bound=torch.zeros((self.nn_input_dim))-100.0

        
        for i in range(0, self.depth * 3, 3):
            encoding_layer = self.layers[i] #NOT USED HERE, WE ENCODE RANDOMLY MANUALLY
            dropout_layer = self.layers[i + 1]
            linear_layer = self.layers[i + 2]
            #Sample with current tau
            sample_now=sample_with_tau(tau,max_bound,min_bound)
            #Update bounds with new sample
            max_bound,min_bound=bound_update(tau,max_bound,min_bound,sample_now)
            encoded_x = dropout_layer(sample_now.unsqueeze(dim=0))
            if i==0:
                encodings=encoded_x
                taus=-tau
            else:
                encodings=torch.cat((encodings,encoded_x),dim=-1)
                taus=torch.cat((taus,-tau),dim=-1)

            tau = linear_layer(encodings) #not used for last layer
            
        
        self.encodings=encodings
        self.taus=taus
        self.upper_boundaries=torch.clone(max_bound)
        self.lower_boundaries=torch.clone(min_bound)
        
        generated_sample=self.sample_from_boundaries()
        
        self.explain_without_causal_effects(generated_sample)
        generated_sample_original_format=self.preprocessor.inverse_transform_X(generated_sample)
        result=self.preprocessor.inverse_transform_y(self.output)
        generated_sample_original_format['prediction']=result
        
        return generated_sample,generated_sample_original_format,result
    
    def generate_from_same_category(self,x):
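        """
        Generates a new sample from the same decision region as x:
        extracts x's boundaries via an explanation pass, then samples
        uniformly within them.
        """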
        self.explain_without_causal_effects(x)
        generated_sample=self.sample_from_boundaries()
        generated_sample_original_format=self.preprocessor.inverse_transform_X(generated_sample)
        result=self.preprocessor.inverse_transform_y(self.output)
        return generated_sample,generated_sample_original_format,result
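
# Illustrative end-to-end usage (hedged sketch: `preprocessor` and the input
# tensors stand in for whatever preprocessing utilities accompany this model;
# only the LEURN calls below are defined in this file):
#   model = LEURN(preprocessor, depth=4, droprate=0.1)
#   output = model(x_batch)                          # forward pass
#   df, output, result = model.explain(x_row)        # per-sample explanation
#   sample, sample_df, pred = model.generate()       # sample from learned regions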