minGRU-sentiment2 / modeling_minGRU.py
suayptalha's picture
Create modeling_minGRU.py
c19f25f verified
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import Optional
from .configuration_minGRU import MinGRUConfig
from minGRU_pytorch.minGRU import minGRU
class MinGRUWrapped(nn.Module):
    """Wrap a module so that tensor inputs are moved to its device before the call.

    Args:
        min_gru_model: The module to wrap (in this file, a ``minGRU`` layer).

    The target device is read from the wrapped module's parameters on every
    access, so it stays correct after ``.to()`` / ``.cuda()`` / ``.cpu()`` is
    called on this wrapper or on any parent module. ``.to()`` itself is the
    standard ``nn.Module.to`` (accepts device, dtype, or both; returns self).
    """

    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model

    @property
    def device(self):
        """Device of the wrapped module's first parameter (or buffer); CPU if it has neither."""
        for tensor in self.min_gru_model.parameters():
            return tensor.device
        for tensor in self.min_gru_model.buffers():
            return tensor.device
        return torch.device("cpu")

    def forward(self, *args, **kwargs):
        """Move every tensor positional/keyword argument to ``self.device``, then call the wrapped module.

        Non-tensor arguments are passed through unchanged. Returns whatever the
        wrapped module returns.
        """
        device = self.device  # resolve once per call rather than once per argument
        args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return self.min_gru_model(*args, **kwargs)
class MinGRUPreTrainedModel(PreTrainedModel):
    """Base ``PreTrainedModel`` bound to ``MinGRUConfig`` with a custom weight initializer."""

    config_class = MinGRUConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize the weights of ``module`` (override of ``PreTrainedModel._init_weights``).

        - ``nn.Linear``: weight ~ N(0, config.initializer_range), bias zeroed.
        - ``nn.Embedding``: weight ~ N(0, config.initializer_range), padding row zeroed.
        - ``nn.LayerNorm``: weight filled with 1, bias zeroed (whichever exist).

        Afterwards every parameter reachable from ``module`` is scanned. Any
        tensor containing NaN has its NaNs replaced with 1e-6 in place.
        ``nan_to_num`` also clamps +/-inf to the dtype's finite range in that tensor.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # Either tensor is None for LayerNorm(elementwise_affine=False) or bias=False.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
        for name, param in module.named_parameters():
            if param.is_meta:
                continue  # meta tensors carry no data; .any() on them cannot be evaluated
            if torch.isnan(param).any():
                print(f"NaN detected in parameter {name}. Replacing with a safe number.")
                # In place, so the parameter keeps its storage (and any aliases of it).
                param.data.nan_to_num_(nan=1e-6)
class MinGRUForSequenceClassification(PreTrainedModel):
    """Sequence classifier: token embedding -> minGRU -> mean over dim 1 -> linear head.

    NOTE(review): this class subclasses ``PreTrainedModel`` directly, so the
    ``MinGRUPreTrainedModel._init_weights`` defined in this file is not used
    here. Confirm whether that is intended.
    """

    config_class = MinGRUConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRUConfig):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        raw_min_gru = minGRU(
            dim=config.d_model,
            expansion_factor=config.ff_mult,
        )
        # Wrapped so tensor inputs are moved onto a target device before the minGRU call.
        self.model = MinGRUWrapped(raw_min_gru)
        # Final linear layer for classification
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs,
    ):
        """Classify a batch of token sequences.

        Args:
            input_ids: Token ids. Pooling averages over dim 1, so this is
                presumably (batch, seq_len); confirm against callers.
            labels: Optional class indices. When given, a cross-entropy loss is returned.
            return_dict: Return a ``SequenceClassifierOutput`` when truthy, else a tuple.
                ``None`` falls back to ``config.use_return_dict`` (Hugging Face convention).
            **kwargs: Accepted and ignored. In particular no attention mask is applied,
                so padded positions are included in the mean pooling.

        Returns:
            ``SequenceClassifierOutput(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is falsy.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        embeddings = self.embedding(input_ids)
        hidden_states = self.model(embeddings)
        pooled_output = hidden_states.mean(dim=1)
        logits = self.classifier(pooled_output)  # No need for additional layers here

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Labels may arrive on a different device than the model output.
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from a pretrained checkpoint, then repair selected weights.

        After the standard ``PreTrainedModel.from_pretrained`` load, the embedding
        table and the two minGRU projection weights are re-initialized with
        Xavier-normal init if they contain any NaN or infinite value.

        Returns exactly what ``PreTrainedModel.from_pretrained`` returns. That is
        the model, or ``(model, loading_info)`` when ``output_loading_info=True``.
        """
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = loaded[0] if isinstance(loaded, tuple) else loaded
        names_to_check = {
            "embedding.weight",
            "model.min_gru_model.to_hidden_and_gate.weight",
            "model.min_gru_model.to_out.weight",
        }
        for name, param in model.named_parameters():
            if name not in names_to_check or param.is_meta:
                continue  # meta (e.g. disk-offloaded) params have no data to inspect
            if not torch.isfinite(param).all():
                nn.init.xavier_normal_(param)  # Re-initialize the corrupted weight
                print(f"Initialized parameter {name} manually.")
        return loaded

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """Save the model weights and configuration to a directory.

        Delegates to ``PreTrainedModel.save_pretrained``, which writes the full
        state dict (embedding, minGRU and classifier) under the same key names
        that ``from_pretrained`` expects, together with the config.

        Args:
            save_directory (str): Directory to save the model.
            safe_serialization (bool, optional): True writes safetensors, False a
                PyTorch pickle checkpoint. ``None`` is treated as True. Defaults to True.
            kwargs: Forwarded to ``PreTrainedModel.save_pretrained`` (e.g.
                ``max_shard_size``, ``state_dict``, ``push_to_hub``).
        """
        if safe_serialization is None:
            safe_serialization = True
        return super().save_pretrained(
            save_directory, safe_serialization=safe_serialization, **kwargs
        )