# ablang2 / modeling_ablang2paired.py
# hemantn's picture
# Remove problematic imports and add utility files to config.json auto_map
# ae9a535
import torch
import os
from torch import nn
from transformers import PreTrainedModel
# Import configuration
try:
from .configuration_ablang2paired import AbLang2PairedConfig
except ImportError:
from configuration_ablang2paired import AbLang2PairedConfig
# Import the AbLang model from local files
try:
from ablang import AbLang
except ImportError:
# Fallback: try to import from the current directory
try:
from .ablang import AbLang
except ImportError:
raise ImportError(
"Could not find AbLang module. Please ensure ablang.py is present in the repository."
)
class AbLang2PairedHFModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``AbLang`` network.

    The wrapped network is held in ``self.model``. Its weights are stored
    separately from the Hugging Face checkpoint, as a plain ``state_dict``
    in a file named ``model.pt`` next to the config.
    """

    config_class = AbLang2PairedConfig
    model_type = "ablang2-paired"

    class ModelOutput:
        """Minimal output container exposing ``last_hidden_state``."""

        def __init__(self, last_hidden_state):
            self.last_hidden_state = last_hidden_state

    def __init__(self, config: AbLang2PairedConfig):
        """Build the underlying ``AbLang`` network from ``config``."""
        super().__init__(config)
        self.model = AbLang(
            vocab_size=config.vocab_size,
            hidden_embed_size=config.hidden_embed_size,
            n_attn_heads=config.n_attn_heads,
            n_encoder_blocks=config.n_encoder_blocks,
            padding_tkn=config.padding_tkn,
            mask_tkn=config.mask_tkn,
            layer_norm_eps=config.layer_norm_eps,
            a_fn=config.a_fn,
            dropout=config.dropout,
        )

    def forward(self, input_ids=None, x=None, attention_mask=None, **kwargs):
        """Run the wrapped network and wrap its result.

        Args:
            input_ids: Input in Hugging Face naming; takes precedence over
                ``x`` when both are given.
            x: Input in the original naming, used when ``input_ids`` is None.
            attention_mask: Forwarded unchanged to the wrapped network.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A ``ModelOutput`` whose ``last_hidden_state`` attribute holds
            whatever the wrapped network returned.

        Raises:
            ValueError: If neither ``input_ids`` nor ``x`` is provided.
        """
        # Handle both Hugging Face format (input_ids) and original format (x).
        if input_ids is not None:
            x = input_ids
        elif x is None:
            raise ValueError("Either input_ids or x must be provided")

        # NOTE(review): attention_mask is passed as the SECOND POSITIONAL
        # argument; confirm AbLang.forward really expects a mask there.
        output = self.model(x, attention_mask)

        # The output class is defined once on the model class instead of
        # being re-created on every forward call.
        return self.ModelOutput(output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Instantiate the model and load ``model.pt`` weights if available.

        Args:
            pretrained_model_name_or_path: Repo id or local path passed to
                the config and weight lookups.
            *model_args: Accepted for signature compatibility and ignored.
            **kwargs: Recognised keys are ``config`` (a ready config object),
                ``device`` (target device), and the hub options
                ``cache_dir``, ``force_download``, ``resume_download``,
                ``proxies``, ``token``, ``revision``, ``local_files_only``.

        Returns:
            The model, moved to ``device`` (CUDA when available and no device
            was requested) and switched to evaluation mode. If the weights
            cannot be found or loaded, a warning is printed and the freshly
            initialised model is returned instead.
        """
        # Forward only the hub options the caller actually supplied, so the
        # library defaults apply otherwise and deprecated options such as
        # ``resume_download`` are not passed needlessly.
        hub_kwargs = {
            key: kwargs[key]
            for key in (
                "cache_dir",
                "force_download",
                "resume_download",
                "proxies",
                "token",
                "revision",
                "local_files_only",
            )
            if key in kwargs
        }

        # Load config first.
        config = kwargs.get("config")
        if config is None:
            from transformers import AutoConfig

            # Use the same hub options as the weight lookup below, so config
            # and weights come from the same revision/cache and private
            # repos can be accessed with the supplied token.
            config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                **hub_kwargs,
            )

        # Create model with config.
        model = cls(config)

        # Loading the custom weights is deliberately best-effort: any failure
        # is reported and the initialised model is used instead.
        try:
            from transformers.utils import cached_file

            custom_weights_path = cached_file(
                pretrained_model_name_or_path,
                "model.pt",
                **hub_kwargs,
            )
            if custom_weights_path is not None and os.path.exists(custom_weights_path):
                state_dict = torch.load(
                    custom_weights_path,
                    map_location="cpu",
                    weights_only=True,
                )
                model.model.load_state_dict(state_dict)
                print(f"✅ Loaded custom weights from: {custom_weights_path}")
            else:
                print("⚠️ No custom weights found, using initialized model")
        except Exception as e:
            print(f"⚠️ Could not load custom weights: {e}")
            print("Using initialized model")

        # Move model to appropriate device (GPU if available, otherwise CPU).
        device = kwargs.get("device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)

        # Like the stock ``PreTrainedModel.from_pretrained``, hand back the
        # model in evaluation mode so dropout is inactive at inference time;
        # callers that fine-tune should call ``model.train()``.
        model.eval()
        return model

    def save_pretrained(self, save_directory, **kwargs):
        """Save the ``AbLang`` weights, the config and the standard HF files.

        Args:
            save_directory: Target directory; created if it does not exist.
            **kwargs: Forwarded to ``PreTrainedModel.save_pretrained``.
        """
        os.makedirs(save_directory, exist_ok=True)
        # Save custom weights (state dict of the wrapped network only).
        torch.save(self.model.state_dict(), os.path.join(save_directory, "model.pt"))
        # Save config.
        self.config.save_pretrained(save_directory)
        # Call parent method for any additional saving.
        super().save_pretrained(save_directory, **kwargs)