import torch
import torch.nn as nn
import pickle
class Deck_Attention(nn.Module):
    def __init__(self, input_size, output_dim, num_heads=8, num_layers=3, output_layers=2, dropout=0.2):
        super(Deck_Attention, self).__init__()
        # Input projection and normalization
        self.hidden_dim = 1024
        self.input_proj = nn.Linear(input_size, self.hidden_dim, bias=False)
        self.input_norm = nn.LayerNorm(self.hidden_dim, bias=False)
        # Learnable [CLS] token and positional embeddings (maximum deck size 45)
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim))
        self.pos_encoding = nn.Embedding(45, self.hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=num_heads,
            dim_feedforward=self.hidden_dim * 4,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.layers = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=False,
        )
        self.transformer_norm = nn.LayerNorm(self.hidden_dim, bias=False)
        # Output projection: residual MLP blocks followed by a final head
        self.output_proj = nn.ModuleList([
            nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
                nn.GELU(),
                nn.LayerNorm(self.hidden_dim, bias=False),
                nn.Dropout(dropout),
            ) for _ in range(output_layers)
        ])
        self.final_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim, bias=False),
            nn.LayerNorm(self.hidden_dim, bias=False),
            nn.GELU(),
            nn.Linear(self.hidden_dim, output_dim, bias=False),
        )
    def forward(self, x, lens=None):
        # Reshape input to [batch, seq_len, features] if needed
        x = x.view(x.size(0), x.size(-2), x.size(-1))
        batch_size = x.size(0)
        # Create padding mask (True marks padded positions); prepend a False
        # entry so the [CLS] token is never masked
        padding_mask = None
        if lens is not None:
            lens = lens.to(x.device)
            padding_mask = torch.arange(45, device=x.device).expand(batch_size, 45) >= lens.unsqueeze(1)
            padding_mask = torch.cat(
                (torch.zeros(padding_mask.shape[0], 1, device=padding_mask.device).bool(), padding_mask),
                dim=1,
            )
        # Initial projection and add position embeddings
        x = self.input_proj(x)
        pos = torch.arange(45, device=x.device).expand(batch_size, 45)
        pos = self.pos_encoding(pos)
        x = x + pos
        # Prepend the [CLS] token, normalize, and run the transformer encoder
        x = torch.cat([self.cls_token.expand(batch_size, -1, -1), x], dim=1)
        x = self.input_norm(x)
        x = self.layers(x, src_key_padding_mask=padding_mask)
        x = self.transformer_norm(x)
        # Use the [CLS] token as the deck representation
        x = x[:, 0, :]
        for layer in self.output_proj:
            x = x + layer(x)
        x = self.final_layer(x)
        return x
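# Minimal usage sketch for Deck_Attention (hedged: the 1330 feature width is an
# assumption matching the default of get_card_embeddings below; the maximum
# deck length of 45 is hard-coded in the class).
def _example_deck_attention():
    model = Deck_Attention(input_size=1330, output_dim=128)
    decks = torch.randn(2, 45, 1330)    # two decks padded to 45 card vectors
    lens = torch.tensor([40, 33])       # true number of cards in each deck
    return model(decks, lens=lens)      # -> deck embeddings of shape [2, 128]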
class Card_Preprocessing(nn.Module):
    def __init__(self, num_layers, input_size, output_size, nonlinearity=nn.GELU, internal_size=1024, dropout=0):
        super(Card_Preprocessing, self).__init__()
        self.internal_size = internal_size
        self.input = nn.Sequential(
            nn.Linear(input_size, internal_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(internal_size, bias=False),
            nn.Dropout(dropout),
        )
        self.hidden_layers = nn.ModuleList()
        self.dropout_rate = dropout
        for i in range(num_layers):
            self.hidden_layers.append(nn.Sequential(
                nn.Linear(internal_size, internal_size, bias=False),
                nonlinearity(),
                nn.LayerNorm(internal_size, bias=False),
                nn.Dropout(dropout),
            ))
        self.output = nn.Sequential(
            nn.Linear(internal_size, output_size, bias=False),
            nonlinearity(),
            nn.LayerNorm(output_size, bias=False),
        )
        # One learnable gate per hidden layer for the gated residual update
        self.gammas = nn.ParameterList([
            torch.nn.Parameter(torch.ones(1, internal_size), requires_grad=True)
            for i in range(num_layers)
        ])

    def forward(self, x):
        x = self.input(x)
        for i, layer in enumerate(self.hidden_layers):
            # Gated residual: interpolate between the identity and the layer output
            gamma = torch.sigmoid(self.gammas[i])
            x = gamma * x + (1 - gamma) * layer(x)
        x = self.output(x)
        return x
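# Minimal usage sketch for Card_Preprocessing (hedged: the 1330 feature width
# mirrors the default embedding_size used by get_card_embeddings below; any
# per-card feature width works).
def _example_card_preprocessing():
    encoder = Card_Preprocessing(num_layers=2, input_size=1330, output_size=256)
    cards = torch.randn(8, 1330)        # a batch of 8 raw card feature vectors
    return encoder(cards)               # -> tensor of shape [8, 256]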
class CrossAttnBlock(nn.Module):
    """
    One deck→pack cross-attention block, Pre-LayerNorm style.
        cards : [B, K, d]  (queries)
        deck  : [B, D, d]  (keys / values)
    returns updated cards [B, K, d]
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln_q = nn.LayerNorm(d_model)
        self.ln_k = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)
        self.xattn = nn.MultiheadAttention(
            d_model, n_heads,
            dropout=dropout, batch_first=True)
        self.ln_ff = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )
        self.dropout_attn = nn.Dropout(dropout)

    def forward(self, cards, deck, mask=None):
        # 1) deck → card cross-attention
        q = self.ln_q(cards)
        k = self.ln_k(deck)
        v = self.ln_v(deck)
        attn_out, _ = self.xattn(q, k, v, key_padding_mask=mask)  # [B, K, d]
        x = cards + self.dropout_attn(attn_out)                   # residual
        # 2) position-wise feed-forward
        y = self.ffn(self.ln_ff(x))
        return x + y
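# Minimal usage sketch for CrossAttnBlock (hedged: the sizes are illustrative;
# d_model must match the card encoder's output width when used inside
# MLP_CrossAttention below).
def _example_cross_attn_block():
    block = CrossAttnBlock(d_model=256, n_heads=4, dropout=0.1)
    cards = torch.randn(2, 14, 256)     # pack cards acting as queries
    deck = torch.randn(2, 45, 256)      # deck cards acting as keys/values
    return block(cards, deck)           # -> updated pack cards, shape [2, 14, 256]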
class MLP_CrossAttention(nn.Module):
    def __init__(self, input_size, num_card_layers, card_output_dim, dropout, **kwargs):
        super(MLP_CrossAttention, self).__init__()
        self.input_size = input_size
        self.card_encoder = Card_Preprocessing(num_card_layers,
                                               input_size=input_size,
                                               internal_size=1024,
                                               output_size=card_output_dim,
                                               dropout=dropout)
        self.attention_layers = nn.ModuleList([
            CrossAttnBlock(card_output_dim, n_heads=4, dropout=dropout)
            for _ in range(10)
        ])
        self.output_layer = nn.Sequential(
            nn.Linear(card_output_dim, card_output_dim * 2),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 2, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 2, card_output_dim * 4),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim * 4, bias=False),
            nn.Dropout(dropout),
            nn.Linear(card_output_dim * 4, card_output_dim),
            nn.ReLU(),
            nn.LayerNorm(card_output_dim, bias=False),
            nn.Linear(card_output_dim, 1),
        )
        # Optionally load pretrained weights from `<path>/network.pt`
        if kwargs.get('path') is not None:
            self.load_state_dict(torch.load(f"{kwargs['path']}/network.pt", map_location='cpu'))
            print(f"Loaded model from {kwargs['path']}/network.pt")
    def forward(self, deck, cards, get_embeddings=False, no_attention=False):
        batch_size, deck_size, card_size = deck.shape
        # Encode every deck card with the shared card encoder
        deck = deck.view(batch_size * deck_size, card_size)
        deck_encoded = self.card_encoder(deck)
        deck_encoded = deck_encoded.view(batch_size, deck_size, -1)
        # Identify padded cards (all-zero feature vectors)
        mask = (cards.sum(dim=-1) != 0)
        cards_encoded = self.card_encoder(cards)
        if not no_attention:
            # Cross-attention: condition each candidate card on the deck
            for layer in self.attention_layers:
                cards_encoded = layer(cards_encoded, deck_encoded)
        if get_embeddings:
            # Return the intermediate representation before the final scoring head
            for layer in self.output_layer[:-3]:
                cards_encoded = layer(cards_encoded)
            return cards_encoded
        # Output layer
        logits = self.output_layer(cards_encoded)
        # Mask out padded cards
        logits = logits.masked_fill(~mask.unsqueeze(-1), float('-inf'))
        return logits.squeeze(-1)
    def get_card_embedding(self, card_embedding):
        # Context-free embedding of a single card from the shared card encoder
        card_embedding = card_embedding.view(1, 1, -1)
        return self.card_encoder(card_embedding).squeeze()
        # Unused alternative: run the full network without cross-attention and
        # return the pre-head embedding instead.
        # empty_deck = torch.zeros((1, 45, self.input_size)).to(card_embedding.device)
        # return self(deck=empty_deck, cards=card_embedding,
        #             get_embeddings=True, no_attention=True).squeeze(0)
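# Minimal usage sketch for MLP_CrossAttention (hedged: all sizes are
# illustrative; path=None skips loading pretrained weights).
def _example_mlp_cross_attention():
    net = MLP_CrossAttention(input_size=1330, num_card_layers=2,
                             card_output_dim=256, dropout=0.1, path=None)
    deck = torch.randn(1, 40, 1330)     # 40 already-picked cards
    pack = torch.randn(1, 14, 1330)     # 14 candidate cards in the pack
    with torch.no_grad():
        scores = net(deck=deck, cards=pack)   # -> [1, 14] pick logits
    return scores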
def get_embedding_dict(path, add_nontransformed=False):
    with open(path, 'rb') as f:
        embedding_dict = pickle.load(f)
    if add_nontransformed:
        # Also index double-faced cards under their front-face name
        embedding_dict_tmp = {}
        for k, v in embedding_dict.items():
            embedding_dict_tmp[k] = v
            if '//' in k:
                embedding_dict_tmp[k.split(' // ')[0]] = v
        embedding_dict = embedding_dict_tmp
    return embedding_dict
def get_card_embeddings(card_names, embedding_dict, embedding_size=1330):
    # Build embeddings for a mixed list of inputs: a card name (str), a list of
    # card names (e.g. a deck), an empty string, or an empty-list placeholder.
    embeddings = []
    for card in card_names:
        if card == '':
            embeddings.append([])
        elif card == []:
            # Empty deck: return a zero embedding of the expected shape
            if type(embedding_size) == tuple:
                channels, height, width = embedding_size
                new_embedding = torch.zeros(1, channels, height, width)
            else:
                new_embedding = torch.zeros(1, embedding_size)
            embeddings.append(new_embedding)
        elif isinstance(card, list):
            if len(card) == 0:
                embeddings.append(None)
                continue
            deck_embedding = []
            for c in card:
                embedding, _ = get_embedding_of_card(c, embedding_dict)
                deck_embedding.append(embedding)
            num_cards = len(deck_embedding)
            deck_embedding = torch.stack(deck_embedding)
            if type(embedding_size) == tuple:
                channels, height, width = embedding_size
                deck_embedding = deck_embedding.view(num_cards, channels, height, width)
            else:
                deck_embedding = deck_embedding.view(num_cards, -1)
            embeddings.append(deck_embedding)
        else:
            embedding, _ = get_embedding_of_card(card, embedding_dict)
            embeddings.append(embedding)
    return embeddings
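# Usage sketch for get_card_embeddings (hedged: the card names are illustrative
# and must exist as keys in the loaded embedding_dict; a string yields a single
# embedding tensor, a list of names a stacked [num_cards, 1330] tensor, and ''
# an empty placeholder).
def _example_get_card_embeddings(embedding_dict):
    names = ['Lightning Bolt', ['Mountain_1', 'Shock'], '']
    return get_card_embeddings(names, embedding_dict, embedding_size=1330)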
def check_for_basics(card_name, embedding_dict):
    # Map numbered basic land variants (e.g. 'Mountain_3') back to the base name
    ints = ['1', '2', '3', '4', '5']
    basics = ['Mountain', 'Forest', 'Swamp', 'Island', 'Plains']
    for b in basics:
        if b in card_name:
            for i in ints:
                if card_name == f'{b}_{i}':
                    return b
    return card_name
def get_embedding_of_card(card_name, embedding_dict):
    try:
        card_name = check_for_basics(card_name, embedding_dict)
        card_name = card_name.replace('_', ' ')
        card_name = card_name.replace("Sol'kanar", "Sol'Kanar")
        # Lookup order: exact name, front face of a double-faced card,
        # then the name without the Alchemy 'A-' prefix.
        candidates = [card_name, card_name.split(' // ')[0], card_name.replace('A-', '')]
        if all(c not in embedding_dict for c in candidates):
            # print(f'Requesting new embedding for {card_name}')
            # attributes, text = get_card_representation(card_name=card_name)
            # text_embedding = embedd_text([text]).squeeze()
            # return torch.Tensor(np.concatenate((attributes, text_embedding), axis=0)), True
            raise Exception(f'Could not find {card_name}')
        for candidate in candidates:
            if candidate in embedding_dict:
                # Second element reports whether a new embedding was generated
                return torch.Tensor(embedding_dict[candidate]), False
    except Exception as e:
        print(f'Could not find {card_name}')
        print(e)
        raise e
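# End-to-end usage sketch (hedged: the pickle path, card names, and the 1330
# feature width are assumptions; substitute values that match your data).
def _example_pick_scores():
    embedding_dict = get_embedding_dict('embeddings.pkl', add_nontransformed=True)
    deck_names = ['Mountain_1'] * 8 + ['Lightning Bolt'] * 2
    pack_names = ['Shock', 'Lava Spike']
    deck = get_card_embeddings([deck_names], embedding_dict)[0].unsqueeze(0)  # [1, D, 1330]
    pack = get_card_embeddings([pack_names], embedding_dict)[0].unsqueeze(0)  # [1, K, 1330]
    net = MLP_CrossAttention(input_size=1330, num_card_layers=2,
                             card_output_dim=256, dropout=0.1, path=None)
    with torch.no_grad():
        return net(deck=deck, cards=pack)   # [1, K] logits, one score per pack card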