# MNIST_Demo_1 / modeling_glm.py
# Last commit: "Update modeling_glm.py" (14333a0) by xiaohua828; file size 620 bytes.
import math
import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import (ModelOutput,)
class CounterModel(PreTrainedModel):
    """Toy ``PreTrainedModel``: affine-transform the input, then apply a 1->1 linear layer.

    The coefficients ``weight`` and ``bias`` are copied from the config object.
    The only layer built here is a single ``torch.nn.Linear(1, 1)``.
    """

    def __init__(self, config):
        """Build the model from *config*.

        Args:
            config: Model configuration. It must expose ``weight`` and ``bias``
                attributes. Presumably these are numbers or tensors that
                broadcast against the forward input — TODO confirm against the
                matching config class.
        """
        super().__init__(config)
        # NOTE(review): the base-class constructor already sets ``self.config``;
        # this re-assignment is kept so the attribute is exactly the object the
        # caller passed in.
        self.config = config
        # Copied as-is from the config; this code does not wrap them in
        # torch.nn.Parameter.
        self.weight, self.bias = config.weight, config.bias
        # One input feature -> one output feature, so the last dimension of the
        # tensor reaching this layer must have size 1.
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x, **kwargs):
        """Compute ``linear(weight * x + bias)``.

        Args:
            x: Input tensor. Presumably shaped ``(..., 1)`` so it fits the 1->1
                linear layer — TODO confirm with callers.
            **kwargs: Accepted and ignored, so callers may pass extra keyword
                arguments without error.

        Returns:
            The raw output of ``self.linear``: a plain tensor, not a
            ``ModelOutput``.
        """
        shifted = self.weight * x + self.bias
        return self.linear(shifted)

    def add(self):
        """Return ``weight + bias`` as copied from the config; no layers are involved."""
        total = self.weight + self.bias
        return total