Timo committed
Commit cb94f9b · 1 Parent(s): 9790e4e

weird fixes

Files changed (2):
  1. src/draft_model.py +9 -4
  2. src/helpers.py +283 -0
src/draft_model.py CHANGED
@@ -5,7 +5,6 @@ from typing import List, Dict
 
 from huggingface_hub import hf_hub_download
 
-from src.models.winrate_model import Winrate_Model
 from src.training import train_mlp
 from src.utils import utils
 
@@ -19,6 +18,10 @@ DATA_REPO = "TimoBertram/MTG_Drafting_Dataset/"
 CARD_FILE = "cards_eoe.json"
 ENCODING_FILE = "card_encodings.pt"
 
+from helpers import get_embedding_dict, get_card_embeddings, MLP_CrossAttention
+
+
+
 class DraftModel:
     def __init__(self, model_path: str):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -34,12 +37,12 @@ class DraftModel:
         cfg = open(cfg_path, "r")
         cfg.pop("name", None)
 
-        self.net = train_mlp.MLP_CrossAttention(**cfg).to(self.device)
+        self.net = MLP_CrossAttention(**cfg).to(self.device)
         self.net.load_state_dict(torch.load(Path(model_path) / "network.pt", map_location=self.device))
         self.net.eval()
 
         # ---- embeddings – one-time load ------------------------------------
-        self.embed_dict = utils.get_embedding_dict(
+        self.embed_dict = get_embedding_dict(
             hf_hub_download(repo_id=DATA_REPO, filename=ENCODING_FILE, repo_type="dataset"),
             add_nontransformed=True
         )
@@ -52,9 +55,11 @@ class DraftModel:
         names = [c["name"] for c in pack]
 
         def embed(name):  # helper
-            return utils.get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
+            return get_card_embeddings((name,), embedding_dict=self.embed_dict)[0]
 
         card_t = torch.stack([embed(n) for n in names]).unsqueeze(0).to(self.device)
        deck_t = torch.zeros((1, 45, self.emb_size), device=self.device)
 
         return torch.softmax(self.net(card_t, deck_t), dim=1).squeeze(0).cpu().numpy().tolist()
+
+
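For orientation, a minimal, self-contained sketch of what the switched-over get_embedding_dict(..., add_nontransformed=True) call now gives DraftModel: the pickled name-to-embedding dict is loaded, and every double-faced "Front // Back" key is additionally exposed under its front-face name. The toy card names, tensors, and temp-file path below are illustrative only, and the sketch assumes helpers.py is importable as helpers, matching the new import line.

import pickle
import tempfile

import torch

from helpers import get_embedding_dict  # helper introduced by this commit

# Hypothetical stand-in for card_encodings.pt: a pickled {name: tensor} dict.
toy_dict = {
    "Some Card": torch.zeros(4),
    "Front Face // Back Face": torch.ones(4),
}
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
    pickle.dump(toy_dict, f)
    toy_path = f.name

embed_dict = get_embedding_dict(toy_path, add_nontransformed=True)

# add_nontransformed=True also registers the front face on its own;
# the alias points at the same tensor as the full double-faced key.
assert "Front Face" in embed_dict
assert torch.equal(embed_dict["Front Face"], embed_dict["Front Face // Back Face"])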
src/helpers.py ADDED
@@ -0,0 +1,283 @@
+import torch
+import torch.nn as nn
+import pickle
+
+
+
+class Deck_Attention(nn.Module):
+    def __init__(self, input_size, output_dim, num_heads=8, num_layers=3, output_layers = 2, dropout=0.2):
+        super(Deck_Attention, self).__init__()
+
+        # Input projection and normalization
+        self.hidden_dim = 1024
+        self.input_proj = nn.Linear(input_size, self.hidden_dim, bias = False)
+        self.input_norm = nn.LayerNorm(self.hidden_dim, bias = False)
+        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
+        self.pos_encoding = nn.Embedding(45, self.hidden_dim)
+
+
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model= self.hidden_dim,
+            nhead = num_heads,
+            dim_feedforward= self.hidden_dim * 4,
+            dropout=dropout,
+            activation='gelu',
+            batch_first=True,
+            norm_first=True,
+        )
+        self.layers = nn.TransformerEncoder(encoder_layer,
+                                            num_layers=num_layers,
+                                            enable_nested_tensor=False,
+                                            )
+        self.transformer_norm = nn.LayerNorm(self.hidden_dim, bias = False)
+        # Output projection
+        self.output_proj = nn.ModuleList(
+            [nn.Sequential(
+                nn.Linear(self.hidden_dim, self.hidden_dim, bias = False),
+                nn.GELU(),
+                nn.LayerNorm(self.hidden_dim, bias = False),
+                nn.Dropout(dropout) ) for _ in range(output_layers)])
+
+        self.final_layer = nn.Sequential(
+            nn.Linear(self.hidden_dim, self.hidden_dim, bias = False),
+            nn.LayerNorm(self.hidden_dim, bias = False),
+            nn.GELU(),
+            nn.Linear(self.hidden_dim, output_dim, bias = False))
+
+    def forward(self, x, lens=None):
+        # Reshape input if needed
+        x = x.view(x.size(0), x.size(-2), x.size(-1))
+        batch_size = x.size(0)
+
+
+        # Create padding mask
+        padding_mask = None
+        if lens is not None:
+            lens = lens.to(x.device)
+            padding_mask = torch.arange(45, device=x.device).expand(batch_size, 45) >= lens.unsqueeze(1)
+            padding_mask = torch.cat((torch.zeros(padding_mask.shape[0], 1, device= padding_mask.device).bool(), padding_mask), dim = 1)
+
+        # Initial projection and add position embeddings
+        x = self.input_proj(x)
+
+        pos = torch.arange(45, device=x.device).expand(batch_size, 45)
+        pos = self.pos_encoding(pos)
+        x = x + pos
+        x = torch.cat([self.cls_token.expand(batch_size, -1, -1), x], dim=1)
+        x = self.input_norm(x)
+
+        x = self.layers(x, src_key_padding_mask=padding_mask)
+        x = self.transformer_norm(x)
+
+
+        x = x[:, 0, :]
+        for layer in self.output_proj:
+            x = x + layer(x)
+
+        x = self.final_layer(x)
+        return x
+
+
+class Card_Preprocessing(nn.Module):
+    def __init__(self, num_layers, input_size, output_size, nonlinearity = nn.GELU, internal_size = 1024, dropout = 0):
+        super(Card_Preprocessing,self).__init__()
+        self.internal_size = internal_size
+        self.input = nn.Sequential(
+            nn.Linear(input_size,internal_size, bias = False),
+            nonlinearity(),
+            nn.LayerNorm(internal_size, bias = False),
+            nn.Dropout(dropout),
+        )
+        self.hidden_layers = nn.ModuleList()
+        self.dropout_rate = dropout
+        for i in range(num_layers):
+            self.hidden_layers.append(nn.Sequential(
+                nn.Linear(internal_size,internal_size, bias = False),
+                nonlinearity(),
+                nn.LayerNorm(internal_size, bias = False),
+                nn.Dropout(dropout),
+            ))
+        self.output = nn.Sequential(
+            nn.Linear(internal_size,output_size, bias = False),
+            nonlinearity(),
+            nn.LayerNorm(output_size, bias = False)
+        )
+        self.gammas = nn.ParameterList([torch.nn.Parameter(torch.ones(1, internal_size), requires_grad = True) for i in range(num_layers)])
+
+    def forward(self,x):
+        x = self.input(x)
+        for i,layer in enumerate(self.hidden_layers):
+            gamma = torch.sigmoid(self.gammas[i])
+            x = gamma * x + (1-gamma) * layer(x)
+        x = self.output(x)
+        return x
+
+
+class CrossAttnBlock(nn.Module):
+    """
+    One deck→pack cross-attention block, Pre-LayerNorm style.
+        cards : [B, K, d]   (queries)
+        deck  : [B, D, d]   (keys / values)
+    returns updated cards [B, K, d]
+    """
+    def __init__(self, d_model: int, n_heads: int, dropout: float):
+        super().__init__()
+        self.ln_q = nn.LayerNorm(d_model)
+        self.ln_k = nn.LayerNorm(d_model)
+        self.ln_v = nn.LayerNorm(d_model)
+        self.xattn = nn.MultiheadAttention(
+            d_model, n_heads,
+            dropout=dropout, batch_first=True)
+        self.ln_ff = nn.LayerNorm(d_model)
+        self.ffn = nn.Sequential(
+            nn.Linear(d_model, 4 * d_model),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(4 * d_model, d_model),
+            nn.Dropout(dropout),
+        )
+        self.dropout_attn = nn.Dropout(dropout)
+
+
+    def forward(self, cards, deck, mask = None):
+        # 1) deck → card cross-attention
+        q = self.ln_q(cards)
+        k = self.ln_k(deck)
+        v = self.ln_v(deck)
+        attn_out, _ = self.xattn(q, k, v, key_padding_mask = mask)  # [B, K, d]
+        x = cards + self.dropout_attn(attn_out)                     # residual
+
+        # 2) position-wise feed-forward
+        y = self.ffn(self.ln_ff(x))
+        return x + y
+
+class MLP_CrossAttention(nn.Module):
+    def __init__(self, input_size, num_card_layers, card_output_dim, dropout, **kwargs):
+        super(MLP_CrossAttention, self).__init__()
+        self.input_size = input_size
+
+        self.card_encoder = Card_Preprocessing(num_card_layers,
+                                               input_size = input_size,
+                                               internal_size = 1024,
+                                               output_size = card_output_dim,
+                                               dropout = dropout)
+
+        self.attention_layers = nn.ModuleList([
+            CrossAttnBlock(card_output_dim, n_heads=4, dropout=dropout)
+            for _ in range(10)
+        ])
+
+        self.output_layer = nn.Sequential(
+            nn.Linear(card_output_dim, card_output_dim*2),
+            nn.ReLU(),
+            nn.LayerNorm(card_output_dim*2, bias = False),
+            nn.Dropout(dropout),
+
+            nn.Linear(card_output_dim*2, card_output_dim*4),
+            nn.ReLU(),
+            nn.LayerNorm(card_output_dim*4, bias = False),
+            nn.Dropout(dropout),
+
+
+            nn.Linear(card_output_dim*4, card_output_dim),
+            nn.ReLU(),
+            nn.LayerNorm(card_output_dim, bias = False),
+
+            nn.Linear(card_output_dim, 1),
+        )
+        if kwargs['path'] is not None:
+            self.load_state_dict(torch.load(f"{kwargs['path']}/network.pt", map_location='cpu'))
+            print(f"Loaded model from {kwargs['path']}/network.pt")
+
+    def forward(self, deck, cards, get_embeddings = False, no_attention = False):
+        batch_size, deck_size, card_size = deck.shape
+
+        deck = deck.view(batch_size * deck_size, card_size)
+
+        deck_encoded = self.card_encoder(deck.cuda())
+        deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
+
+
+        # identify padded cards
+        mask = (cards.sum(dim=-1) != 0).cuda()
+        cards_encoded = self.card_encoder(cards.cuda())
+
+        if not no_attention:
+            # Cross-attention
+            for layer in self.attention_layers:
+                cards_encoded = layer(cards_encoded, deck_encoded)
+
+        if get_embeddings:
+            for layer in self.output_layer[:-3]:
+                cards_encoded = layer(cards_encoded)
+            return cards_encoded
+
+        # Output layer
+        logits = self.output_layer(cards_encoded)
+        # Mask out padded cards
+        logits = logits.masked_fill(~mask.unsqueeze(-1), float('-inf'))
+        return logits.squeeze(-1)
+
+    def get_card_embedding(self, card_embedding):
+        card_embedding = card_embedding.view(1,1, -1)
+        empty_deck = torch.zeros((1, 45, self.input_size)).to(card_embedding.device)
+
+        return self.card_encoder(card_embedding).squeeze()
+
+        return self(deck = empty_deck,
+                    cards = card_embedding,
+                    get_embeddings = True,
+                    no_attention = True).squeeze(0)
+
+
+def get_embedding_dict(path, add_nontransformed = False):
+    with open(path, 'rb') as f:
+        embedding_dict = pickle.load(f)
+
+    if add_nontransformed:
+        embedding_dict_tmp = {}
+        for k,v in embedding_dict.items():
+            embedding_dict_tmp[k] = v
+            if '//' in k:
+                embedding_dict_tmp[k.split(' // ')[0]] = v
+        embedding_dict = embedding_dict_tmp
+        return embedding_dict_tmp
+    return embedding_dict
+
+def get_card_embeddings(card_names, embedding_dict, embedding_size = 1330):
+    embeddings = []
+    new_embeddings = {}
+    for card in card_names:
+        if card == '':
+            embeddings.append([])
+        elif card == []:
+            if type(embedding_size) == tuple:
+                channels, height, width = embedding_size
+                new_embedding = torch.zeros(1,channels, height, width)
+            else:
+                new_embedding = torch.zeros(1,embedding_size)
+            embeddings.append(new_embedding)
+
+        elif isinstance(card, list):
+            if len(card) == 0:
+                embeddings.append(None)
+                continue
+            deck_embedding = []
+            for c in card:
+                embedding, got_new = get_embedding_of_card(c, embedding_dict)
+                deck_embedding.append(embedding)
+            try:
+                num_cards = len(deck_embedding)
+                deck_embedding = torch.stack(deck_embedding)
+                if type(embedding_size) == tuple:
+                    channels, height, width = embedding_size
+                    deck_embedding = deck_embedding.view(num_cards,channels, height, width)
+                else:
+                    deck_embedding = deck_embedding.view(num_cards,-1)
+            except Exception as e:
+                raise e
+            embeddings.append(deck_embedding)
+        else:
+            embedding, got_new = get_embedding_of_card(card, embedding_dict)
+            embeddings.append(embedding)
+    return embeddings
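A rough usage sketch for the model class added above, since the file carries no example of its own. The hyperparameter values are hypothetical (in practice they come from the config that DraftModel downloads), the constructor expects a path entry in its kwargs (None skips checkpoint loading), and forward() moves its inputs to .cuda() internally, so a CUDA device is assumed. The sketch sticks to MLP_CrossAttention; get_card_embeddings also calls get_embedding_of_card, which is not defined in this file.

import torch

from helpers import MLP_CrossAttention

# Hypothetical sizes; the real values come from the model's saved config.
net = MLP_CrossAttention(
    input_size=1330,       # same dimensionality as get_card_embeddings' default
    num_card_layers=3,
    card_output_dim=256,
    dropout=0.2,
    path=None,             # no pretrained network.pt to load in this sketch
).cuda().eval()

deck = torch.randn(1, 45, 1330)   # [batch, deck slots, card embedding]
pack = torch.randn(1, 15, 1330)   # [batch, cards in pack, card embedding]

with torch.no_grad():
    logits = net(deck=deck, cards=pack)        # [1, 15]; all-zero card rows are masked to -inf
    pick_probs = torch.softmax(logits, dim=1)  # distribution over the 15 pack cards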