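"""Hugging Face wrapper for the AbLang2 paired antibody language model.

Wraps the local AbLang implementation in a transformers PreTrainedModel
(AbLang2PairedHFModel) so it can be loaded and saved through the usual
from_pretrained / save_pretrained workflow, with weights stored in a
custom model.pt file.
"""
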
import os

import torch
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

# Import configuration
try:
    from .configuration_ablang2paired import AbLang2PairedConfig
except ImportError:
    from configuration_ablang2paired import AbLang2PairedConfig

# Import the AbLang model from local files
try:
    from ablang import AbLang
except ImportError:
    # Fallback: try to import from the current directory
    try:
        from .ablang import AbLang
    except ImportError:
        raise ImportError(
            "Could not find AbLang module. Please ensure ablang.py is present in the repository."
        )


class AbLang2PairedHFModel(PreTrainedModel):
    config_class = AbLang2PairedConfig
    model_type = "ablang2-paired"

    def __init__(self, config: AbLang2PairedConfig):
        super().__init__(config)
        self.model = AbLang(
            vocab_size=config.vocab_size,
            hidden_embed_size=config.hidden_embed_size,
            n_attn_heads=config.n_attn_heads,
            n_encoder_blocks=config.n_encoder_blocks,
            padding_tkn=config.padding_tkn,
            mask_tkn=config.mask_tkn,
            layer_norm_eps=config.layer_norm_eps,
            a_fn=config.a_fn,
            dropout=config.dropout,
        )

    def forward(self, input_ids=None, x=None, attention_mask=None, **kwargs):
        # Handle both Hugging Face format (input_ids) and original format (x)
        if input_ids is not None:
            x = input_ids
        elif x is None:
            raise ValueError("Either input_ids or x must be provided")
        
        # Get the hidden states from the underlying AbLang model
        output = self.model(x, attention_mask)

        # Wrap in the standard transformers output container so downstream
        # code can access `last_hidden_state` as usual
        return BaseModelOutput(last_hidden_state=output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
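        # The pretrained weights for this model are expected in a custom
        # "model.pt" file rather than the standard transformers checkpoint,
        # so the model is built from its config and the weights are loaded
        # manually below.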
        # Load config first
        config = kwargs.get("config")
        if config is None:
            from transformers import AutoConfig
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        
        # Create model with config
        model = cls(config)
        
        # Try to load custom weights
        try:
            from transformers.utils import cached_file
            custom_weights_path = cached_file(
                pretrained_model_name_or_path,
                "model.pt",
                cache_dir=kwargs.get("cache_dir"),
                force_download=kwargs.get("force_download", False),
                resume_download=kwargs.get("resume_download", False),
                proxies=kwargs.get("proxies"),
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                local_files_only=kwargs.get("local_files_only", False),
            )
            
            if custom_weights_path is not None and os.path.exists(custom_weights_path):
                # Load custom weights
                state_dict = torch.load(custom_weights_path, map_location="cpu", weights_only=True)
                model.model.load_state_dict(state_dict)
                print(f"✅ Loaded custom weights from: {custom_weights_path}")
            else:
                print("⚠️ No custom weights found, using initialized model")
                
        except Exception as e:
            print(f"⚠️ Could not load custom weights: {e}")
            print("Using initialized model")
        
        # Move model to appropriate device (GPU if available, otherwise CPU)
        device = kwargs.get("device", None)
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        
        return model

    def save_pretrained(self, save_directory, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        # Save the raw AbLang weights in the custom format expected by from_pretrained
        torch.save(self.model.state_dict(), os.path.join(save_directory, "model.pt"))
        # Save the config
        self.config.save_pretrained(save_directory)
        # Call the parent method for the standard Hugging Face serialization
        super().save_pretrained(save_directory, **kwargs)
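

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model code).
# The repository id below is a placeholder, and the random token ids / mask
# convention are assumptions; in practice both come from the AbLang2
# tokenizer that accompanies the checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = AbLang2PairedHFModel.from_pretrained("your-namespace/ablang2-paired")
    model.eval()

    # Fake batch of already-tokenized paired sequences (batch_size=1, length 32).
    input_ids = torch.randint(0, model.config.vocab_size, (1, 32), device=model.device)
    attention_mask = (input_ids != model.config.padding_tkn).long()

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    # The wrapper exposes the underlying AbLang output as `last_hidden_state`.
    print(out.last_hidden_state.shape)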