import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig


class StoryPointIncrementModel(nn.Module):
    """
    A custom model wrapper designed to load and use the weights of a fine-tuned
    Transformer model for regression (story point prediction).
    """

    def __init__(self, model_name="prajjwal1/bert-tiny", num_labels=1, cache_dir=None):
        super().__init__()
        # Load the configuration of a small BERT-like model as a base template.
        # cache_dir is passed through to from_pretrained so the Hugging Face
        # cache can be redirected, avoiding permission errors in restricted
        # environments (it defaults to None, i.e. the standard cache location).
        config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
        # Build the base encoder (up to the pooler) from the config. Weights are
        # randomly initialized here; the fine-tuned weights are expected to be
        # restored later via load_state_dict.
        self.encoder = AutoModel.from_config(config)
        # A simple linear layer for regression (predicting a single story point value).
        self.regressor = nn.Linear(config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        # Pass the tokenized inputs through the Transformer encoder.
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled output (representation of the whole sequence) for regression.
        pooled_output = outputs.pooler_output
        # Pass the pooled output through the regression head.
        logits = self.regressor(pooled_output)
        return logits
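

# Minimal usage sketch. Assumptions: the tokenizer matching model_name
# ("prajjwal1/bert-tiny") is available, and a fine-tuned checkpoint was saved
# earlier with torch.save(model.state_dict(), ...); the checkpoint path
# "story_point_model.pt" below is hypothetical.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = StoryPointIncrementModel()
    # Restore the fine-tuned weights (uncomment once a checkpoint exists).
    # model.load_state_dict(torch.load("story_point_model.pt", map_location="cpu"))
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
    batch = tokenizer(
        ["As a user, I want to export reports as CSV."],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        prediction = model(batch["input_ids"], batch["attention_mask"])
    # The regressor emits a single value per input; interpret it as the
    # predicted story point estimate.
    print(f"Predicted story points: {prediction.item():.2f}")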