# MTG_Drafting_AI/src/helpers.py
import pickle

import torch
import torch.nn as nn
class Deck_Attention(nn.Module):
    """Transformer encoder that summarizes a deck of up to 45 cards into a single vector.
    Card features are projected to hidden_dim, a learned CLS token is prepended, and the
    CLS output is refined by residual MLP blocks before the final projection."""
    def __init__(self, input_size, output_dim, num_heads=8, num_layers=3, output_layers=2, dropout=0.2):
        super().__init__()
        # Input projection and normalization
        self.hidden_dim = 1024
        self.input_proj = nn.Linear(input_size, self.hidden_dim, bias=False)
        self.input_norm = nn.LayerNorm(self.hidden_dim, bias=False)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
        self.pos_encoding = nn.Embedding(45, self.hidden_dim)  # one position per deck slot
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=num_heads,
            dim_feedforward=self.hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.layers = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )
        self.transformer_norm = nn.LayerNorm(self.hidden_dim, bias=False)
        # Output projection
        self.output_proj = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
                nn.GELU(),
                nn.LayerNorm(self.hidden_dim, bias=False),
                nn.Dropout(dropout),
            ) for _ in range(output_layers)
        ])
        self.final_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
            nn.LayerNorm(self.hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(self.hidden_dim, output_dim, bias=False),
        )
    def forward(self, x, lens=None):
        # Collapse any extra leading dimensions to [batch, deck_size, card_features]
        x = x.view(x.size(0), x.size(-2), x.size(-1))
        batch_size = x.size(0)
        # Build the key-padding mask: True marks padded deck slots beyond each deck's length
        padding_mask = None
        if lens is not None:
            lens = lens.to(x.device)
            padding_mask = torch.arange(45, device=x.device).expand(batch_size, 45) >= lens.unsqueeze(1)
            # Prepend False for the CLS token so it is never masked out
            cls_mask = torch.zeros(batch_size, 1, device=x.device, dtype=torch.bool)
            padding_mask = torch.cat((cls_mask, padding_mask), dim=1)
# Initial projection and add position embeddings
x = self.input_proj(x)
pos = torch.arange(45, device=x.device).expand(batch_size, 45)
pos = self.pos_encoding(pos)
x = x + pos
x = torch.cat([self.cls_token.expand(batch_size, -1, -1), x], dim=1)
x = self.input_norm(x)
x = self.layers(x, src_key_padding_mask=padding_mask)
x = self.transformer_norm(x)
        # Take the CLS token and refine it with the residual MLP blocks
        x = x[:, 0, :]
        for layer in self.output_proj:
            x = x + layer(x)
x = self.final_layer(x)
return x
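
# Illustrative usage sketch (not part of the original pipeline): shows the tensor shapes
# Deck_Attention expects. input_size=1330 follows the default card feature size used in
# get_card_embeddings below; output_dim=256 and the batch values are assumptions.
def _demo_deck_attention():
    model = Deck_Attention(input_size=1330, output_dim=256)
    decks = torch.randn(4, 45, 1330)          # 4 decks, zero-padded to 45 card slots
    lens = torch.tensor([23, 40, 15, 45])     # number of real (non-padded) cards per deck
    return model(decks, lens)                 # -> tensor of shape [4, 256]
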
class Card_Preprocessing(nn.Module):
    """MLP card encoder with learned, sigmoid-gated residual connections between hidden layers."""
    def __init__(self, num_layers, input_size, output_size, nonlinearity=nn.GELU, internal_size=1024, dropout=0):
        super().__init__()
        self.internal_size = internal_size
        self.input = nn.Sequential(
            nn.Linear(input_size, internal_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(internal_size, bias=False),
            nn.Dropout(dropout),
        )
        self.hidden_layers = nn.ModuleList()
        self.dropout_rate = dropout
        for i in range(num_layers):
            self.hidden_layers.append(nn.Sequential(
                nn.Linear(internal_size, internal_size, bias=False),
                nonlinearity(),
                nn.LayerNorm(internal_size, bias=False),
                nn.Dropout(dropout),
            ))
        self.output = nn.Sequential(
            nn.Linear(internal_size, output_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(output_size, bias=False),
        )
        # Per-layer gates controlling how much of each hidden layer's output is mixed in
        self.gammas = nn.ParameterList([
            nn.Parameter(torch.ones(1, internal_size), requires_grad=True)
            for _ in range(num_layers)
        ])
    def forward(self, x):
        x = self.input(x)
        for i, layer in enumerate(self.hidden_layers):
            # Learned gate in (0, 1): gamma -> 1 keeps the input, gamma -> 0 favors the layer output
            gamma = torch.sigmoid(self.gammas[i])
            x = gamma * x + (1 - gamma) * layer(x)
        x = self.output(x)
        return x
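
# Illustrative usage sketch (sizes are assumptions): encode one 15-card pack from 1330 raw
# card features down to a 128-dimensional embedding with the gated-residual MLP encoder.
def _demo_card_preprocessing():
    encoder = Card_Preprocessing(num_layers=2, input_size=1330, output_size=128)
    pack = torch.randn(15, 1330)              # 15 cards, 1330 features each
    return encoder(pack)                      # -> tensor of shape [15, 128]
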
class CrossAttnBlock(nn.Module):
"""
One deck→pack cross-attention block, Pre-LayerNorm style.
cards : [B, K, d] (queries)
deck : [B, D, d] (keys / values)
returns updated cards [B, K, d]
"""
def __init__(self, d_model: int, n_heads: int, dropout: float):
super().__init__()
self.ln_q = nn.LayerNorm(d_model)
self.ln_k = nn.LayerNorm(d_model)
self.ln_v = nn.LayerNorm(d_model)
self.xattn = nn.MultiheadAttention(
d_model, n_heads,
dropout=dropout, batch_first=True)
self.ln_ff = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(4 * d_model, d_model),
nn.Dropout(dropout),
)
self.dropout_attn = nn.Dropout(dropout)
    def forward(self, cards, deck, mask=None):
        # 1) deck → card cross-attention
        q = self.ln_q(cards)
        k = self.ln_k(deck)
        v = self.ln_v(deck)
        attn_out, _ = self.xattn(q, k, v, key_padding_mask=mask)  # [B, K, d]
x = cards + self.dropout_attn(attn_out) # residual
# 2) position-wise feed-forward
y = self.ffn(self.ln_ff(x))
return x + y
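
# Illustrative usage sketch (sizes are assumptions): 15 pack cards (queries) attend to
# 45 deck cards (keys/values) in a 128-dimensional space; d_model must be divisible by n_heads.
def _demo_cross_attn_block():
    block = CrossAttnBlock(d_model=128, n_heads=4, dropout=0.1)
    cards = torch.randn(2, 15, 128)           # [B, K, d] pack cards
    deck = torch.randn(2, 45, 128)            # [B, D, d] deck cards
    return block(cards, deck)                 # -> updated pack cards, shape [2, 15, 128]
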
class MLP_CrossAttention(nn.Module):
    """Pick predictor: a shared Card_Preprocessing encoder embeds both the current deck and
    the pack cards, stacked cross-attention blocks let each pack card attend to the deck,
    and an MLP head scores every card in the pack."""
    def __init__(self, input_size, num_card_layers, card_output_dim, dropout, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.card_encoder = Card_Preprocessing(num_card_layers,
                                               input_size=input_size,
                                               internal_size=1024,
                                               output_size=card_output_dim,
                                               dropout=dropout)
        self.attention_layers = nn.ModuleList([
            CrossAttnBlock(card_output_dim, n_heads=4, dropout=dropout)
            for _ in range(10)
        ])
        self.output_layer = nn.Sequential(
            nn.Linear(card_output_dim, card_output_dim * 2),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 2, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 2, card_output_dim * 4),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 4, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 4, card_output_dim),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim, bias=False),
            nn.Linear(card_output_dim, 1),
        )
        # Optionally restore pretrained weights; use .get() so 'path' may be omitted entirely
        if kwargs.get('path') is not None:
            self.load_state_dict(torch.load(f"{kwargs['path']}/network.pt", map_location='cpu'))
            print(f"Loaded model from {kwargs['path']}/network.pt")
    def forward(self, deck, cards, get_embeddings=False, no_attention=False):
        batch_size, deck_size, card_size = deck.shape
        # Encode every deck card with the shared card encoder
        deck = deck.view(batch_size * deck_size, card_size)
        deck_encoded = self.card_encoder(deck)
        deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
        # mask is True for real pack cards, False for zero-padded slots
        mask = (cards.sum(dim=-1) != 0)
        cards_encoded = self.card_encoder(cards)
        if not no_attention:
            # Let the pack cards attend to the deck
            for layer in self.attention_layers:
                cards_encoded = layer(cards_encoded, deck_encoded)
        if get_embeddings:
            # Stop before the final ReLU/LayerNorm/score layers to expose the card embedding
            for layer in self.output_layer[:-3]:
                cards_encoded = layer(cards_encoded)
            return cards_encoded
        # Score each pack card and mask out padded slots
        logits = self.output_layer(cards_encoded)
        logits = logits.masked_fill(~mask.unsqueeze(-1), float('-inf'))
        return logits.squeeze(-1)
    def get_card_embedding(self, card_embedding):
        # Embed a single card with the shared card encoder only
        card_embedding = card_embedding.view(1, 1, -1)
        return self.card_encoder(card_embedding).squeeze()
        # Alternative (currently unused): run the full model against an empty deck with
        # attention disabled to obtain the pre-score embedding instead:
        # empty_deck = torch.zeros((1, 45, self.input_size), device=card_embedding.device)
        # return self(deck=empty_deck, cards=card_embedding,
        #             get_embeddings=True, no_attention=True).squeeze(0)
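
# Illustrative usage sketch (sizes are assumptions; path=None skips loading pretrained weights):
# score the 15 cards of a pack against the cards already picked for the deck.
def _demo_mlp_cross_attention():
    model = MLP_CrossAttention(input_size=1330, num_card_layers=2,
                               card_output_dim=128, dropout=0.1, path=None)
    deck = torch.randn(2, 45, 1330)           # already-picked cards, zero-padded to 45 slots
    pack = torch.randn(2, 15, 1330)           # 15 candidate cards to score
    return model(deck, pack)                  # -> per-card logits, shape [2, 15]
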
def get_embedding_dict(path, add_nontransformed=False):
    """Load a pickled {card name: embedding} dict. With add_nontransformed=True, double-faced
    cards ('Front // Back') are additionally registered under their front-face name."""
    with open(path, 'rb') as f:
        embedding_dict = pickle.load(f)
    if add_nontransformed:
        embedding_dict_tmp = dict(embedding_dict)
        for k, v in embedding_dict.items():
            if '//' in k:
                embedding_dict_tmp[k.split(' // ')[0]] = v
        return embedding_dict_tmp
    return embedding_dict
def get_card_embeddings(card_names, embedding_dict, embedding_size=1330):
    """Look up embeddings for a list of entries; each entry may be a single card name,
    a list of card names (a deck or pack), or empty."""
    embeddings = []
    for card in card_names:
        if card == '':
            # Empty name: keep a placeholder entry
            embeddings.append([])
        elif card == []:
            # Empty list: use an all-zeros embedding of the expected shape
            if isinstance(embedding_size, tuple):
                channels, height, width = embedding_size
                embeddings.append(torch.zeros(1, channels, height, width))
            else:
                embeddings.append(torch.zeros(1, embedding_size))
        elif isinstance(card, list):
            # A deck or pack: embed each card and stack into one tensor
            deck_embedding = [get_embedding_of_card(c, embedding_dict)[0] for c in card]
            num_cards = len(deck_embedding)
            deck_embedding = torch.stack(deck_embedding)
            if isinstance(embedding_size, tuple):
                channels, height, width = embedding_size
                deck_embedding = deck_embedding.view(num_cards, channels, height, width)
            else:
                deck_embedding = deck_embedding.view(num_cards, -1)
            embeddings.append(deck_embedding)
        else:
            # A single card name
            embedding, _ = get_embedding_of_card(card, embedding_dict)
            embeddings.append(embedding)
    return embeddings
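
# Illustrative usage sketch: the card names and random 1330-dim vectors below are made up
# for demonstration; in practice embedding_dict comes from get_embedding_dict(path).
def _demo_get_card_embeddings():
    embedding_dict = {'Lightning Bolt': torch.randn(1330).tolist(),
                      'Llanowar Elves': torch.randn(1330).tolist()}
    entries = ['Lightning Bolt', ['Lightning Bolt', 'Llanowar Elves'], []]
    return get_card_embeddings(entries, embedding_dict)
    # -> [tensor of shape [1330], stacked tensor of shape [2, 1330], zero tensor of shape [1, 1330]]
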
def check_for_basics(card_name, embedding_dict):
    # Basic lands may carry an art-index suffix (e.g. 'Plains_3'); map them back to the plain name.
    # embedding_dict is unused here but kept for call-site compatibility.
    basics = ['Mountain', 'Forest', 'Swamp', 'Island', 'Plains']
    for b in basics:
        if b in card_name:
            for i in ['1', '2', '3', '4', '5']:
                if card_name == f'{b}_{i}':
                    return b
    return card_name
def get_embedding_of_card(card_name, embedding_dict):
    try:
        card_name = check_for_basics(card_name, embedding_dict)
        card_name = card_name.replace('_', ' ')
        card_name = card_name.replace("Sol'kanar", "Sol'Kanar")
        # Try the exact name, then the front face of a double-faced card,
        # then the name without the Alchemy 'A-' prefix.
        candidates = [card_name, card_name.split(' // ')[0], card_name.replace('A-', '')]
        for candidate in candidates:
            if candidate in embedding_dict:
                return torch.Tensor(embedding_dict[candidate]), False
        # print(f'Requesting new embedding for {card_name}')
        # attributes, text = get_card_representation(card_name = card_name)
        # text_embedding = embedd_text([text]).squeeze()
        # return torch.Tensor(np.concatenate((attributes, text_embedding), axis = 0)), True
        raise KeyError(f'Could not find {card_name}')
    except Exception as e:
        print(f'Could not find {card_name}')
        print(e)
        raise