|
|
|
import math |
|
|
|
import torch |
|
import torch.utils.checkpoint |
|
import torch.nn.functional as F |
|
from transformers import PreTrainedModel |
|
|
|
from transformers.modeling_outputs import (ModelOutput,) |
|
|
|
|
|
class CounterModel(PreTrainedModel):
    """Toy ``PreTrainedModel``: a config-driven affine step followed by ``Linear(1, 1)``.

    ``weight`` and ``bias`` are read from the config once, at construction
    time. Their types are not visible here (presumably numbers or tensors —
    TODO confirm against the config class). ``linear`` is the only submodule
    this class creates.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the base class already stores the config. This explicit
        # assignment is kept so ``self.config`` is exactly the object passed in.
        self.config = config
        # Copied from the config as-is (not wrapped in nn.Parameter here).
        self.weight = config.weight
        self.bias = config.bias
        # One input feature -> one output feature.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x, **kwargs):
        """Apply ``weight * x + bias``, then pass the result through ``self.linear``.

        Args:
            x: Input value. After the affine step, the last dimension must be 1
                to match the linear layer's ``in_features``.
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            The output of ``self.linear`` applied to the affine-transformed input.
        """
        shifted = self.weight * x + self.bias
        return self.linear(shifted)

    def add(self):
        """Return the sum of the config-supplied ``weight`` and ``bias``."""
        total = self.weight + self.bias
        return total