adder_32bit / model.py
fateextra's picture
Upload 3 files
a3e24f1 verified
Raw
History Blame Contribute Delete
3.22 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
import safetensors.torch
class AdderLayer(nn.Module):
    """Linear projection, ReLU, then pairwise L2 normalisation.

    The ``o_dim`` output features are treated as consecutive pairs
    ``(0, 1), (2, 3), ...`` and every pair is rescaled to unit Euclidean
    length. ``o_dim`` therefore has to be even; an odd value makes the
    reshape in :meth:`forward` fail.

    Args:
        i_dim: Size of the last dimension of the input.
        o_dim: Size of the last dimension of the output (must be even).
    """

    def __init__(self, i_dim: int, o_dim: int):
        super().__init__()
        self.linear = nn.Linear(i_dim, o_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``(..., i_dim)``.

        Returns:
            Tensor of shape ``(..., o_dim)`` in which each consecutive pair
            of features has unit L2 norm, or is all zeros when ReLU zeroed
            both members of the pair.
        """
        activated = F.relu(self.linear(x))
        pairs = activated.reshape(*activated.shape[:-1], -1, 2)
        # A plain ``pairs / pairs.norm(...)`` yields NaN for a pair that ReLU
        # zeroed completely, and that NaN then spreads through every later
        # layer. F.normalize clamps the denominator (eps), so such a pair
        # stays at zero while non-degenerate pairs are divided by their norm
        # exactly as before.
        unit_pairs = F.normalize(pairs, p=2.0, dim=-1)
        return unit_pairs.reshape(activated.shape)
class AdderEncoder(nn.Module):
    """Turn integers into a two-slot-per-bit floating point code.

    For bit ``i`` (least significant bit first) the output holds the pair
    ``(1 - bit, bit)`` at positions ``2 * i`` and ``2 * i + 1``.

    Args:
        bits: Number of low-order bits to encode.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Encode ``x``.

        Args:
            x: Integer tensor of any shape.

        Returns:
            Tensor of shape ``x.shape + (2 * bits,)`` on ``x``'s device, in
            the default floating point dtype.
        """
        # Shift every element by 0..bits-1 at once via broadcasting.
        shift_amounts = torch.arange(self.bits, device=x.device)
        bit_values = (x.unsqueeze(-1) >> shift_amounts) & 1
        # Pair each bit with its complement, then interleave the pairs along
        # the last axis: (..., bits, 2) -> (..., 2 * bits).
        code = torch.stack((1 - bit_values, bit_values), dim=-1).flatten(-2)
        return code.to(torch.get_default_dtype())

    def extra_repr(self) -> str:
        """Show the bit width in the module's ``repr``."""
        return f'bits={self.bits}'
class AdderDecoder(nn.Module):
    """Turn a two-slot-per-bit code back into integers.

    Bit ``i`` (least significant bit first) is read as set when the value at
    position ``2 * i + 1`` is strictly greater than the value at position
    ``2 * i``; a tie decodes to 0.

    Args:
        bits: Number of bits to decode.
    """

    def __init__(self, bits: int = 32):
        super().__init__()
        self.bits = bits

    def forward(self, x: torch.Tensor) -> torch.LongTensor:
        """Decode ``x``.

        Args:
            x: Tensor whose last dimension holds at least ``2 * bits``
                values.

        Returns:
            ``torch.long`` tensor of shape ``x.shape[:-1]`` on ``x``'s
            device.
        """
        positions = range(self.bits)
        # One boolean mask per bit position: True where the "one" slot wins.
        bit_is_set = [x[..., 2 * k + 1] > x[..., 2 * k] for k in positions]

        decoded = torch.zeros(x.shape[:-1], device=x.device, dtype=torch.long)
        for k, flag in zip(positions, bit_is_set):
            # Move the flag to its place value and merge it into the result.
            decoded = decoded | (flag.to(torch.long) << k)
        return decoded
class Adder(nn.Module):
    """Stack of ``AdderLayer`` modules that adds two unsigned integers.

    Both operands are encoded bit by bit, concatenated with an encoded
    carry-in of 0, passed through the layers, and decoded into a sum plus a
    one-bit carry-out.

    Args:
        layer_dims: Feature sizes; consecutive entries define one layer each.
            The forward pass feeds ``4 * bits + 2`` features into the first
            layer and splits the last layer's output into ``2 * bits`` sum
            features and 2 carry features, so ``layer_dims[0]`` and
            ``layer_dims[-1]`` have to match those sizes.
        bits: Bit width of the operands and of the result.
    """

    def __init__(self, layer_dims: list[int], bits: int = 32):
        super().__init__()
        self.layer_dims = layer_dims
        self.bits = bits
        self.encoder = AdderEncoder(bits)
        self.encoder_c = AdderEncoder(1)
        self.decoder = AdderDecoder(bits)
        self.decoder_c = AdderDecoder(1)
        self.layers = nn.ModuleList(
            AdderLayer(i_dim, o_dim)
            for i_dim, o_dim in zip(layer_dims[:-1], layer_dims[1:])
        )

    @property
    def config(self) -> dict:
        """Constructor arguments as a dict (the format ``from_pretrained`` reads)."""
        return {
            'layer_dims': self.layer_dims,
            'bits': self.bits
        }

    @classmethod
    def from_pretrained(cls, filepath: str):
        """Build a model from a directory and load its weights.

        Args:
            filepath: Directory containing ``config.json`` (keyword arguments
                for the constructor) and ``model.safetensors`` (state dict).

        Returns:
            The loaded model with ``requires_grad`` disabled on all
            parameters.
        """
        config_path = os.path.join(filepath, 'config.json')
        # Explicit encoding so the result does not depend on the locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        model = cls(**config)
        state_dict = safetensors.torch.load_file(
            os.path.join(filepath, 'model.safetensors'))
        model.load_state_dict(state_dict)
        model.requires_grad_(False)
        return model

    def forward(self, a: torch.LongTensor, b: torch.LongTensor) -> torch.LongTensor:
        """Add ``a`` and ``b`` element-wise.

        Args:
            a: Integer tensor with every element in ``[0, 2 ** bits)``.
            b: Integer tensor with every element in ``[0, 2 ** bits)``;
                must be broadcastable with ``a``.

        Returns:
            ``torch.long`` tensor with the broadcast shape of ``a`` and ``b``.

        Raises:
            ValueError: If an operand has an element outside
                ``[0, 2 ** bits)``, or if the decoded carry-out is non-zero
                for any element.
        """
        upper = 2 ** self.bits
        for name, operand in (('a', a), ('b', b)):
            # A chained ``0 <= t < upper`` calls bool() on a tensor and fails
            # for anything but a single element, so combine the two masks
            # explicitly. Raising (instead of assert) keeps the check alive
            # under ``python -O``.
            if not ((operand >= 0) & (operand < upper)).all():
                raise ValueError(
                    f"Operand {name} must satisfy 0 <= {name} < 2 ** {self.bits}")

        a, b = torch.broadcast_tensors(a, b)
        # The carry-in needs the same leading shape as the operands so that
        # the three encodings can be concatenated for batched inputs too.
        carry_in = torch.zeros_like(a)
        x = torch.cat(
            [self.encoder(a), self.encoder(b), self.encoder_c(carry_in)],
            dim=-1)

        for layer in self.layers:
            x = layer(x)

        sum_code, carry_code = x.split([2 * self.bits, 2], dim=-1)
        total = self.decoder(sum_code)
        carry_out = self.decoder_c(carry_code)
        if (carry_out > 0).any():
            raise ValueError("Carry out is not 0")
        return total
def main():
    """Load the model stored next to this file and print 3123 + 5929."""
    checkpoint_dir = os.path.dirname(__file__)
    model = Adder.from_pretrained(checkpoint_dir)
    lhs = torch.tensor(3123)
    rhs = torch.tensor(5929)
    print(model(lhs, rhs))


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()