|
"""Test submission for ARC using an identity featurizer."""
|
|
|
import torch |
|
import torch.nn as nn |
|
import pyvene as pv |
|
from CausalAbstraction.neural.featurizers import Featurizer |
|
|
|
|
|
class IdentityFeaturizerModule(nn.Module):
    """Forward featurizer that reproduces its input.

    ``forward`` returns a pair ``(features, error)``:

    * ``features`` is a new tensor equal in value to the input. It is
      computed as ``x`` plus a zero-weighted copy of ``dummy_param``, so it
      has ``x``'s shape and dtype while the value of ``dummy_param`` stays
      finite.
    * ``error`` is always ``None``.
    """

    def __init__(self):
        super().__init__()
        # A single trainable scalar that never affects the output values.
        # Presumably it exists so the module exposes a non-empty parameter
        # set and the output is attached to autograd — confirm against how
        # Featurizer / the training loop consume it.
        self.dummy_param = nn.Parameter(torch.zeros(1), requires_grad=True)

    def forward(self, x):
        """Return ``(x + 0 * dummy_param.sum(), None)``."""
        # A 0-dim term: it broadcasts over x without changing x's dtype.
        zero_term = self.dummy_param.sum() * 0
        features = x + zero_term
        return features, None
|
|
|
|
|
class IdentityInverseFeaturizerModule(nn.Module):
    """Inverse featurizer that hands the features back untouched.

    ``forward`` takes ``(f, error)``, mirroring the pair produced by the
    forward featurizer. It ignores ``error`` and returns ``f`` itself
    (the same object, not a copy). The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, f, error):
        """Return ``f`` unchanged; ``error`` is accepted but unused."""
        reconstructed = f
        return reconstructed
|
|
|
|
|
class IdentityFeaturizer(Featurizer):
    """Featurizer whose forward and inverse maps are both pass-throughs.

    The forward module is ``IdentityFeaturizerModule`` and the inverse
    module is ``IdentityInverseFeaturizerModule``.

    Args:
        id: Identifier forwarded to ``Featurizer`` as the ``id`` keyword;
            defaults to ``"identity"``. The name shadows the builtin but is
            kept for keyword compatibility with existing callers.
    """

    def __init__(self, id="identity"):
        # Arguments are evaluated left to right, so the forward module is
        # still constructed before the inverse one.
        super().__init__(
            IdentityFeaturizerModule(),
            IdentityInverseFeaturizerModule(),
            id=id,
        )