import torch


@torch.no_grad()
def add_feature_on_text(sae, feature_idx, steering_feature, module, input, output):
    # Forward hook that adds a steering feature to a text-encoder output.
    # `sae` and `feature_idx` are unused here; presumably kept so all steering
    # hooks share a compatible signature.
    # The steering feature is split at channel 768: hidden states of width 768
    # receive the first 768 channels, wider hidden states receive the remainder.
    # The trailing comma returns a 1-tuple so the hook's return value keeps the
    # module output's tuple shape.
    if input[0].size(-1) == 768:
        return (output[0] + steering_feature[:, :768].unsqueeze(0)),
    else:
        return (output[0] + steering_feature[:, 768:].unsqueeze(0)),


@torch.no_grad()
def add_feature_on_text_prompt(sae, steering_feature, module, input, output):
    # Same steering hook as add_feature_on_text, without the feature index:
    # add the matching slice of the steering feature to the module output.
    if input[0].size(-1) == 768:
        return (output[0] + steering_feature[:, :768].unsqueeze(0)),
    else:
        return (output[0] + steering_feature[:, 768:].unsqueeze(0)),


@torch.no_grad()
def add_feature_on_text_prompt_flux(sae, steering_feature, module, input, output):
    # Flux variant: the hooked module returns a pair, so steer the first
    # element and pass the second through unchanged.
    return (output[0] + steering_feature.unsqueeze(0)), output[1]


@torch.no_grad()
def minus_feature_on_text_prompt(sae, steering_feature, module, input, output):
    # Negative steering: subtract the matching slice of the steering feature
    # from the module output instead of adding it.
    if input[0].size(-1) == 768:
        return (output[0] - steering_feature[:, :768].unsqueeze(0)),
    else:
        return (output[0] - steering_feature[:, 768:].unsqueeze(0)),


@torch.no_grad()
def do_nothing(sae, steering_feature, module, input, output):
    # No-op hook: return output[0] unchanged as a 1-tuple (no steering applied).
    return (output[0]),
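
# Usage sketch (an assumption, not part of the original file): these functions
# are written as forward hooks with their leading arguments pre-bound, e.g. via
# functools.partial, and then registered on a text-encoder module. A forward
# hook that returns a value replaces the module's output. Names such as
# `text_encoder_module`, `sae`, and `steering_feature` below are illustrative.
#
#     import functools
#
#     hook = functools.partial(add_feature_on_text_prompt, sae, steering_feature)
#     handle = text_encoder_module.register_forward_hook(hook)
#     try:
#         ...  # run the generation pipeline; the prompt embedding is steered
#     finally:
#         handle.remove()  # detach the hook once generation is done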