File size: 1,172 Bytes
33b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch

@torch.no_grad()
def add_feature_on_text(sae, feature_idx, steering_feature, module, input, output):
    """Forward hook: add a slice of ``steering_feature`` to ``output[0]``.

    The hook has the ``(module, input, output)`` tail expected by
    ``register_forward_hook``. The leading arguments are presumably bound with
    ``functools.partial`` or similar by the caller (TODO confirm).

    Which slice is used depends on the hidden size of ``input[0]``:

    - columns ``[:768]`` when it is 768;
    - columns ``[768:]`` otherwise.

    ``steering_feature`` is presumably the features of two text encoders
    concatenated along dim 1 (TODO confirm against callers).

    Args:
        sae: Not used by this hook; kept for a uniform signature.
        feature_idx: Not used by this hook; kept for backward compatibility.
        steering_feature: 2-D tensor. It is sliced along dim 1, unsqueezed at
            dim 0 and broadcast-added to ``output[0]``.
        module: The hooked module (unused).
        input: Positional inputs of the hooked module; only ``input[0]`` is
            inspected, to pick the slice.
        output: Tuple output of the hooked module.

    Returns:
        A tuple of ``output[0] + slice`` followed by ``output[1:]`` unchanged.
    """
    hidden = output[0]
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # Match the activations so an fp32 vector does not promote fp16/bf16
    # hidden states (which downstream half-precision layers would reject).
    delta = delta.to(device=hidden.device, dtype=hidden.dtype)
    # Forward any trailing outputs of the layer instead of dropping them.
    return (hidden + delta.unsqueeze(0), *output[1:])

@torch.no_grad()
def add_feature_on_text_prompt(sae, steering_feature, module, input, output):
    """Forward hook: add a slice of ``steering_feature`` to ``output[0]``.

    The hook has the ``(module, input, output)`` tail expected by
    ``register_forward_hook``. The leading arguments are presumably bound by
    the caller (TODO confirm).

    Which slice is used depends on the hidden size of ``input[0]``:

    - columns ``[:768]`` when it is 768;
    - columns ``[768:]`` otherwise.

    ``steering_feature`` is presumably the features of two text encoders
    concatenated along dim 1 (TODO confirm).

    Args:
        sae: Not used by this hook; kept for a uniform signature.
        steering_feature: 2-D tensor. It is sliced along dim 1, unsqueezed at
            dim 0 and broadcast-added to ``output[0]``.
        module: The hooked module (unused).
        input: Positional inputs of the hooked module; only ``input[0]`` is
            inspected, to pick the slice.
        output: Tuple output of the hooked module.

    Returns:
        A tuple of ``output[0] + slice`` followed by ``output[1:]`` unchanged.
    """
    hidden = output[0]
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # Match the activations so an fp32 vector does not promote fp16/bf16
    # hidden states (which downstream half-precision layers would reject).
    delta = delta.to(device=hidden.device, dtype=hidden.dtype)
    # Forward any trailing outputs of the layer instead of dropping them.
    return (hidden + delta.unsqueeze(0), *output[1:])

@torch.no_grad()
def add_feature_on_text_prompt_flux(sae, steering_feature, module, input, output):
    """Forward hook: add the whole ``steering_feature`` to ``output[0]``.

    Unlike the 768/1280 hooks above, no slicing is done: the full vector is
    unsqueezed at dim 0 and broadcast-added to ``output[0]``. Intended for a
    hooked module that returns a tuple of at least two elements (presumably a
    Flux transformer block — TODO confirm).

    Args:
        sae: Not used by this hook; kept for a uniform signature.
        steering_feature: Tensor broadcast-compatible with ``output[0]``
            after ``unsqueeze(0)``.
        module: The hooked module (unused).
        input: Positional inputs of the hooked module (unused).
        output: Tuple output of the hooked module.

    Returns:
        A tuple of ``output[0] + steering_feature`` followed by every remaining
        element of ``output`` unchanged.
    """
    hidden = output[0]
    # Match the activations so an fp32 vector does not promote bf16/fp16 states.
    delta = steering_feature.to(device=hidden.device, dtype=hidden.dtype)
    # Forward all trailing outputs (the original kept only output[1]).
    return (hidden + delta.unsqueeze(0), *output[1:])

@torch.no_grad()
def minus_feature_on_text_prompt(sae, steering_feature, module, input, output):
    """Forward hook: subtract a slice of ``steering_feature`` from ``output[0]``.

    This mirrors ``add_feature_on_text_prompt`` with the sign flipped. Which
    slice is used depends on the hidden size of ``input[0]``:

    - columns ``[:768]`` when it is 768;
    - columns ``[768:]`` otherwise.

    Args:
        sae: Not used by this hook; kept for a uniform signature.
        steering_feature: 2-D tensor. It is sliced along dim 1, unsqueezed at
            dim 0 and broadcast-subtracted from ``output[0]``.
        module: The hooked module (unused).
        input: Positional inputs of the hooked module; only ``input[0]`` is
            inspected, to pick the slice.
        output: Tuple output of the hooked module.

    Returns:
        A tuple of ``output[0] - slice`` followed by ``output[1:]`` unchanged.
    """
    hidden = output[0]
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # Match the activations so an fp32 vector does not promote fp16/bf16
    # hidden states (which downstream half-precision layers would reject).
    delta = delta.to(device=hidden.device, dtype=hidden.dtype)
    # Forward any trailing outputs of the layer instead of dropping them.
    return (hidden - delta.unsqueeze(0), *output[1:])

@torch.no_grad()
def do_nothing(sae, steering_feature, module, input, output):
    """Pass-through forward hook with the same signature as the steering hooks.

    Returns the module's output elements unchanged. The original returned only
    ``(output[0],)`` and so silently dropped any trailing elements.

    Args:
        sae: Unused.
        steering_feature: Unused.
        module: The hooked module (unused).
        input: Positional inputs of the hooked module (unused).
        output: Tuple output of the hooked module.

    Returns:
        A tuple containing every element of ``output``, in order.
    """
    return (output[0], *output[1:])