File size: 275 Bytes
1a6b654
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import torch.nn as nn

class CodeGenerator(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained causal language model.

    The wrapped model is loaded with
    ``transformers.AutoModelForCausalLM.from_pretrained`` and stored on
    ``self.model``.
    """

    def __init__(self, model_name: str) -> None:
        """Load the pretrained model identified by *model_name*.

        Args:
            model_name: Identifier forwarded unchanged to
                ``AutoModelForCausalLM.from_pretrained``.
        """
        super().__init__()
        # Imported here because the module never imports ``transformers`` at
        # file level; without this the constructor fails with NameError.
        from transformers import AutoModelForCausalLM

        self.model = AutoModelForCausalLM.from_pretrained(model_name)

    def forward(self, input_ids):
        """Run the wrapped model and return the first element of its output.

        Args:
            input_ids: Passed positionally to the wrapped model.
                Presumably a tensor of token ids -- TODO confirm with callers.

        Returns:
            Element ``0`` of whatever the wrapped model returns.
            NOTE(review): for a causal LM called without labels this is
            presumably the logits -- confirm against the model's output type.
        """
        outputs = self.model(input_ids)
        return outputs[0]