File size: 1,172 Bytes
33b542e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import torch
@torch.no_grad()
def add_feature_on_text(sae, feature_idx, steering_feature, module, input, output):
    """Add a slice of ``steering_feature`` to the first element of ``output``.

    The trailing ``(module, input, output)`` parameters appear to follow the
    PyTorch forward-hook signature; ``sae``, ``feature_idx`` and
    ``steering_feature`` are presumably pre-bound by the caller (e.g. with
    ``functools.partial``) -- TODO confirm against the call site.

    Args:
        sae: Not used by this function.
        feature_idx: Not used by this function.
        steering_feature: Tensor sliced as ``[:, :768]`` or ``[:, 768:]``; its
            last dimension is assumed to be the concatenation of a 768-wide
            part and a second part -- TODO confirm layout with the caller.
        module: Not used by this function.
        input: Inputs of the hooked module; only ``input[0].size(-1)`` is read,
            to decide which slice of ``steering_feature`` to add.
        output: Outputs of the hooked module, indexed as a sequence whose first
            element is the tensor being steered.

    Returns:
        A tuple ``(output[0] + slice.unsqueeze(0), *output[1:])``. Extra
        elements of ``output`` are passed through unchanged instead of being
        dropped.
    """
    hidden_states = output[0]
    # The hooked layer's input width selects which part of the concatenated
    # steering vector applies: the first 768 columns, or everything after.
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # unsqueeze(0) adds a leading dim so the slice broadcasts over output[0]'s
    # first dimension. Keep output[1:] so callers that index past element 0
    # still receive those values.
    return (hidden_states + delta.unsqueeze(0), *output[1:])
@torch.no_grad()
def add_feature_on_text_prompt(sae, steering_feature, module, input, output):
    """Add a slice of ``steering_feature`` to the first element of ``output``.

    The trailing ``(module, input, output)`` parameters appear to follow the
    PyTorch forward-hook signature; ``sae`` and ``steering_feature`` are
    presumably pre-bound by the caller (e.g. with ``functools.partial``) --
    TODO confirm against the call site.

    Args:
        sae: Not used by this function.
        steering_feature: Tensor sliced as ``[:, :768]`` or ``[:, 768:]``; its
            last dimension is assumed to be the concatenation of a 768-wide
            part and a second part -- TODO confirm layout with the caller.
        module: Not used by this function.
        input: Inputs of the hooked module; only ``input[0].size(-1)`` is read,
            to decide which slice of ``steering_feature`` to add.
        output: Outputs of the hooked module, indexed as a sequence whose first
            element is the tensor being steered.

    Returns:
        A tuple ``(output[0] + slice.unsqueeze(0), *output[1:])``. Extra
        elements of ``output`` are passed through unchanged instead of being
        dropped.
    """
    hidden_states = output[0]
    # The hooked layer's input width selects which part of the concatenated
    # steering vector applies: the first 768 columns, or everything after.
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # unsqueeze(0) adds a leading dim so the slice broadcasts over output[0]'s
    # first dimension. Keep output[1:] so callers that index past element 0
    # still receive those values.
    return (hidden_states + delta.unsqueeze(0), *output[1:])
@torch.no_grad()
def add_feature_on_text_prompt_flux(sae, steering_feature, module, input, output):
    """Add ``steering_feature`` to ``output[0]`` and forward ``output[1]``.

    Unlike the text-prompt hooks, no slicing is done: the whole
    ``steering_feature`` is added after gaining a leading dimension.

    Args:
        sae: Not used by this function.
        steering_feature: Tensor added to ``output[0]`` after ``unsqueeze(0)``.
        module: Not used by this function.
        input: Not used by this function.
        output: Sequence with at least two elements; only the first two are
            read.

    Returns:
        The 2-tuple ``(output[0] + steering_feature.unsqueeze(0), output[1])``.
    """
    steered = output[0] + steering_feature.unsqueeze(dim=0)
    passthrough = output[1]
    return steered, passthrough
@torch.no_grad()
def minus_feature_on_text_prompt(sae, steering_feature, module, input, output):
    """Subtract a slice of ``steering_feature`` from the first element of ``output``.

    This mirrors ``add_feature_on_text_prompt`` but with subtraction. The
    trailing ``(module, input, output)`` parameters appear to follow the
    PyTorch forward-hook signature; ``sae`` and ``steering_feature`` are
    presumably pre-bound by the caller -- TODO confirm against the call site.

    Args:
        sae: Not used by this function.
        steering_feature: Tensor sliced as ``[:, :768]`` or ``[:, 768:]``; its
            last dimension is assumed to be the concatenation of a 768-wide
            part and a second part -- TODO confirm layout with the caller.
        module: Not used by this function.
        input: Inputs of the hooked module; only ``input[0].size(-1)`` is read,
            to decide which slice of ``steering_feature`` to subtract.
        output: Outputs of the hooked module, indexed as a sequence whose first
            element is the tensor being steered.

    Returns:
        A tuple ``(output[0] - slice.unsqueeze(0), *output[1:])``. Extra
        elements of ``output`` are passed through unchanged instead of being
        dropped.
    """
    hidden_states = output[0]
    # The hooked layer's input width selects which part of the concatenated
    # steering vector applies: the first 768 columns, or everything after.
    if input[0].size(-1) == 768:
        delta = steering_feature[:, :768]
    else:
        delta = steering_feature[:, 768:]
    # Keep output[1:] so callers that index past element 0 still receive them.
    return (hidden_states - delta.unsqueeze(0), *output[1:])
@torch.no_grad()
def do_nothing(sae, steering_feature, module, input, output):
    """Pass-through hook: return the module's outputs without modification.

    It has the same signature as the steering hooks, presumably so callers can
    swap it in as a no-op baseline -- TODO confirm against the call site.

    Args:
        sae: Not used by this function.
        steering_feature: Not used by this function.
        module: Not used by this function.
        input: Not used by this function.
        output: Outputs of the hooked module, indexed as a sequence.

    Returns:
        A tuple holding every element of ``output`` unchanged. The earlier
        version returned only ``(output[0],)``, so it silently dropped any
        additional outputs and was not a true no-op.
    """
    return (output[0], *output[1:])
|