ninagala committed on
Commit 5873e46 · verified · 1 Parent(s): 3ecf49e

Update app.py

Files changed (1)
  1. app.py +129 -1
app.py CHANGED
@@ -9,7 +9,135 @@ import os
  import json
  import math

- # (Previous model code...)
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model: int, max_seq_length: int = 512):
+         super().__init__()
+         position = torch.arange(max_seq_length).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
+         pe = torch.zeros(1, max_seq_length, d_model)
+         pe[0, :, 0::2] = torch.sin(position * div_term)
+         pe[0, :, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         """x: [batch_size, seq_len, d_model]"""
+         return x + self.pe[:, :x.size(1), :]
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, d_ff: int = 2048, dropout: float = 0.1):
+         super().__init__()
+         self.self_attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
+         self.norm1 = nn.LayerNorm(d_model)
+         self.ff = nn.Sequential(
+             nn.Linear(d_model, d_ff),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(d_ff, d_model)
+         )
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, mask=None):
+         attn_output, _ = self.self_attention(x, x, x, attn_mask=mask)
+         x = self.norm1(x + self.dropout(attn_output))
+         ff_output = self.ff(x)
+         x = self.norm2(x + self.dropout(ff_output))
+         return x
+
+ class TransformerDecoder(nn.Module):
+     def __init__(self,
+                  vocab_size: int,
+                  d_model: int = 1024,
+                  n_layers: int = 12,
+                  n_heads: int = 16,
+                  d_ff: int = 4096,
+                  max_seq_length: int = 256,
+                  dropout: float = 0.1):
+         super().__init__()
+
+         self.max_seq_length = max_seq_length
+         self.token_embedding = nn.Embedding(vocab_size, d_model)
+         self.positional_encoding = PositionalEncoding(d_model, max_seq_length)
+         self.dropout = nn.Dropout(dropout)
+
+         self.layers = nn.ModuleList([
+             DecoderBlock(d_model, n_heads, d_ff, dropout)
+             for _ in range(n_layers)
+         ])
+
+         self.final_layer = nn.Linear(d_model, vocab_size)
+         self._init_weights()
+
+     def _init_weights(self):
+         nn.init.normal_(self.token_embedding.weight, mean=0.0, std=0.01)
+
+         for layer in self.layers:
+             nn.init.normal_(layer.self_attention.in_proj_weight, mean=0.0, std=0.01)
+             nn.init.normal_(layer.self_attention.out_proj.weight, mean=0.0, std=0.01)
+
+             for name, param in layer.ff.named_parameters():
+                 if 'weight' in name:
+                     nn.init.normal_(param, mean=0.0, std=0.01)
+                 elif 'bias' in name:
+                     nn.init.zeros_(param)
+
+         nn.init.normal_(self.final_layer.weight, mean=0.0, std=0.01)
+         nn.init.zeros_(self.final_layer.bias)
+
+     def forward(self, x, mask=None):
+         # Create a causal (upper-triangular) mask if none is provided
+         if mask is None:
+             seq_length = x.size(1)
+             mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()
+             mask = mask.to(x.device)
+
+         x = self.token_embedding(x)
+         # PositionalEncoding and MultiheadAttention are both batch-first here,
+         # so x stays as [batch_size, seq_len, d_model]; no transpose is needed.
+         x = self.positional_encoding(x)
+         x = self.dropout(x)
+
+         for layer in self.layers:
+             x = layer(x, mask=mask)
+
+         output = self.final_layer(x)
+         return output
+
+     @classmethod
+     def from_pretrained(cls, model_path: str, device: str = 'cpu'):
+         """Load a pretrained model from a directory"""
+         try:
+             # Load config
+             config_path = os.path.join(model_path, "config.json")
+             if not os.path.exists(config_path):
+                 raise FileNotFoundError(f"Config not found at {config_path}")
+
+             with open(config_path) as f:
+                 config = json.load(f)
+
+             # Create model instance
+             model = cls(
+                 vocab_size=config['vocab_size'],
+                 d_model=config['d_model'],
+                 n_layers=config['n_layers'],
+                 n_heads=config['n_heads'],
+                 d_ff=config['d_ff'],
+                 max_seq_length=config['max_seq_length'],
+                 dropout=config.get('dropout', 0.1)
+             )
+
+             # Load weights
+             weights_path = os.path.join(model_path, "pytorch_model.bin")
+             if not os.path.exists(weights_path):
+                 raise FileNotFoundError(f"Weights not found at {weights_path}")
+
+             state_dict = torch.load(weights_path, map_location=device)
+             model.load_state_dict(state_dict)
+
+             return model.to(device)
+
+         except Exception as e:
+             raise RuntimeError(f"Error loading model from {model_path}: {e}") from e

  def generate_text(prompt, max_length=100, temperature=0.7):
      try:
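
For reference, a minimal usage sketch of the classes added in this commit (not part of the diff). It assumes a checkpoint directory containing the config.json and pytorch_model.bin files that from_pretrained expects; the directory name and token ids below are purely illustrative.

import torch

model_dir = "checkpoints/decoder"                 # hypothetical checkpoint directory
model = TransformerDecoder.from_pretrained(model_dir, device="cpu")
model.eval()

input_ids = torch.tensor([[1, 42, 7, 99]])        # [batch_size=1, seq_len=4], illustrative ids
with torch.no_grad():
    logits = model(input_ids)                     # [1, 4, vocab_size]
    next_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice of the next token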