# Provenance (from the file-hosting page): uploaded by yiksiu,
# commit f118c1e ("upload submission files", verified), 901 bytes.
"""Test submission for ARC using identity featurizer."""
import torch
import torch.nn as nn
import pyvene as pv
from CausalAbstraction.neural.featurizers import Featurizer
class IdentityFeaturizerModule(torch.nn.Module):
    """Forward featurizer that returns its input unchanged.

    The module owns a single trainable scalar, ``dummy_param``, initialised to
    zero. The forward pass multiplies it by 0, so it never changes the output
    value, but it does link the output to the module's parameters in the
    autograd graph. NOTE(review): presumably this exists so that training code
    which collects/optimises ``parameters()``, or expects the featurized output
    to require grad, still works for a parameter-free identity map — confirm
    against the caller.
    """

    def __init__(self):
        super().__init__()
        # requires_grad=True is already the Parameter default; kept explicit to
        # make the "trainable dummy" intent obvious.
        self.dummy_param = torch.nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x):
        """Return ``(x, None)``.

        Args:
            x: input tensor; this module does not constrain its shape.

        Returns:
            A ``(features, None)`` tuple. ``features`` equals ``x`` value-wise
            and is connected to ``dummy_param`` in the autograd graph. The
            second slot is presumably the "error" term that the inverse
            featurizer receives; here it is always ``None``.

        NOTE(review): the addition uses PyTorch type promotion, so an integer
        ``x`` would come back as a floating-point tensor.
        """
        # Adding 0 * param keeps the value identical while creating a gradient
        # path to dummy_param. The gradient it receives is therefore always 0.
        return x + 0 * self.dummy_param.sum(), None
class IdentityInverseFeaturizerModule(torch.nn.Module):
    """Inverse featurizer that hands features back untouched.

    It has no parameters or buffers of its own.
    """

    def __init__(self):
        # Only the base Module bookkeeping is required.
        super().__init__()

    def forward(self, f, error):
        """Return ``f`` itself (the very same object, not a copy).

        Args:
            f: features to map back.
            error: ignored; presumably accepted so the signature matches the
                featurizer / inverse-featurizer calling convention.
        """
        del error  # deliberately unused by the identity inverse
        return f
class IdentityFeaturizer(Featurizer):
    """Featurizer built from the identity forward and inverse modules.

    Args:
        id: identifier passed through to ``Featurizer`` as the ``id`` keyword;
            defaults to ``"identity"``. The name shadows the ``id`` builtin,
            but it is part of the public keyword interface and is kept as is.
    """

    def __init__(self, id="identity"):
        # Positional order is (featurizer, inverse_featurizer). Python evaluates
        # the arguments left to right, so the forward module is built first.
        super().__init__(
            IdentityFeaturizerModule(),
            IdentityInverseFeaturizerModule(),
            id=id,
        )