Fix recommender system
0d5add8
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, BertConfig, BertModel
from encoder import SearchActivityEncoder
class EnhancedUserTower(nn.Module):
def __init__(
self,
input_dim=20,
hidden_dim=64,
out_dim=32,
search_emb_dim=32,
num_heads=4,
num_layers=2,
with_search=True,
):
super().__init__()
config = BertConfig(
hidden_size=hidden_dim,
num_hidden_layers=num_layers,
num_attention_heads=num_heads,
intermediate_size=hidden_dim * 4,
)
        self.with_search = with_search
        # With search enabled, the tail of the input encoding carries the search
        # embedding; only the leading user features go through the projection.
        self.user_input_dim = input_dim - search_emb_dim if with_search else input_dim
        self.feature_proj = nn.Linear(self.user_input_dim, hidden_dim)
self.transformer = BertModel(config)
self.pooler = nn.Linear(hidden_dim, out_dim)
self.search_fusion = nn.Sequential(
nn.Linear(out_dim + search_emb_dim, out_dim),
nn.ReLU(),
)
def forward(self, encoding):
if self.with_search:
feat = encoding[:, : self.user_input_dim]
search_emb = encoding[:, self.user_input_dim :]
else:
feat = encoding
x = self.feature_proj(feat)
        # Run the projected features through the transformer as a length-1 sequence
        out = self.transformer(inputs_embeds=x.unsqueeze(1)).last_hidden_state.squeeze(1)
user_emb = self.pooler(out)
if self.with_search:
user_emb = self.search_fusion(torch.cat([user_emb, search_emb], dim=-1))
return F.normalize(user_emb, dim=1)
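# Hedged usage sketch: a standalone shape check for EnhancedUserTower. The
# dimensions below (input_dim=40 with a 32-dim search embedding, i.e. 8 raw
# user features) are illustrative assumptions, not values from any config.
def _demo_user_tower_shapes():
    tower = EnhancedUserTower(input_dim=40, search_emb_dim=32)
    encoding = torch.randn(2, 40)  # batch of 2 users
    emb = tower(encoding)
    assert emb.shape == (2, 32)  # out_dim defaults to 32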
class TextTower(nn.Module):
def __init__(
self,
model_name="huawei-noah/TinyBERT_General_6L_768D",
proj_hidden=256,
out_dim=128,
):
super().__init__()
self.encoder = AutoModel.from_pretrained(model_name)
dim = self.encoder.config.hidden_size
self.proj = nn.Sequential(
nn.Linear(dim, proj_hidden), nn.ReLU(), nn.Linear(proj_hidden, out_dim)
)
def forward(self, input_ids, attention_mask):
# CLS token pooling
outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
cls = outputs.last_hidden_state[:, 0]
return self.proj(cls)
class StructuredTower(nn.Module):
def __init__(self, input_dim, hidden_dim):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
)
def forward(self, x):
return self.mlp(x)
class MultiModalAttentionFusion(nn.Module):
"""
Fuse text, structured, and review embeddings via self-attention.
"""
def __init__(self, dim, fusion_dim, num_heads=4):
super().__init__()
self.mha = nn.MultiheadAttention(
embed_dim=dim, num_heads=num_heads, batch_first=True
)
self.proj = nn.Sequential(
nn.Linear(dim, fusion_dim), nn.ReLU(), nn.Linear(fusion_dim, fusion_dim)
)
def forward(self, text_emb, struct_emb, review_emb):
# stack modalities as sequence length=3
x = torch.stack([text_emb, struct_emb, review_emb], dim=1)
attn_out, _ = self.mha(x, x, x)
pooled = attn_out.mean(dim=1)
return self.proj(pooled)
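# Hedged usage sketch for MultiModalAttentionFusion: the three modality
# embeddings are stacked into a length-3 sequence, attended over, mean-pooled,
# and projected. The dimensions below are illustrative assumptions.
def _demo_fusion_shapes():
    fusion = MultiModalAttentionFusion(dim=64, fusion_dim=64)
    text_emb = torch.randn(2, 64)
    struct_emb = torch.randn(2, 64)
    review_emb = torch.randn(2, 64)
    assert fusion(text_emb, struct_emb, review_emb).shape == (2, 64)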
class PropertyTower(nn.Module):
def __init__(
self,
struct_dim,
text_hidden_dim,
fusion_dim,
out_dim,
num_views=3,
dropout=0.2,
noise_std=0.01,
review_model="huawei-noah/TinyBERT_General_6L_768D",
review_proj=128,
review_out=64,
):
super().__init__()
# shared encoder for text & reviews
self.base_text = AutoModel.from_pretrained(review_model)
# separate projection heads
self.text_proj = nn.Sequential(
nn.Linear(self.base_text.config.hidden_size, text_hidden_dim),
nn.ReLU(),
nn.Linear(text_hidden_dim, fusion_dim),
)
self.review_proj = nn.Sequential(
nn.Linear(self.base_text.config.hidden_size, review_proj),
nn.ReLU(),
nn.Linear(review_proj, fusion_dim),
)
self.structured_tower = StructuredTower(struct_dim, fusion_dim)
self.fusion = MultiModalAttentionFusion(fusion_dim, fusion_dim)
self.dropout = nn.Dropout(dropout)
self.noise_std = noise_std
self.views = nn.ModuleList(
[
nn.Sequential(
nn.Linear(fusion_dim, fusion_dim // 2),
nn.ReLU(),
nn.Linear(fusion_dim // 2, out_dim),
)
for _ in range(num_views)
]
)
self.contrastive = nn.Linear(out_dim, out_dim)
def _encode_reviews(self, review_seq):
ids, mask = review_seq["input_ids"], review_seq["attention_mask"]
B, R, L = ids.size()
        if R == 0:
            # No reviews: return zeros with the expected output shape (B, D),
            # where D is the output dim of the last review_proj layer.
            out_dim = self.review_proj[-1].out_features
            return torch.zeros(B, out_dim, device=ids.device)
# Flatten to (B*R, L)
flat_ids = ids.view(-1, L)
flat_mask = mask.view(-1, L)
# Forward through base text encoder
out = self.base_text(
input_ids=flat_ids, attention_mask=flat_mask
).last_hidden_state[:, 0]
# Project and reshape
proj = self.review_proj(out).view(B, R, -1)
# Average over reviews
return proj.mean(dim=1)
def forward(self, struct_feat, text_seq, review_seq=None):
# text
t = self.base_text(
input_ids=text_seq["input_ids"], attention_mask=text_seq["attention_mask"]
).last_hidden_state[:, 0]
text_emb = self.text_proj(t)
# struct
s = self.structured_tower(struct_feat)
# review
if review_seq is not None:
r = self._encode_reviews(review_seq)
else:
# zero vector fallback
r = torch.zeros_like(text_emb)
fused = self.fusion(text_emb, s, r)
# multi-view
vs = []
for m in self.views:
x = self.dropout(fused)
x = x + torch.randn_like(x) * self.noise_std
vs.append(m(x))
avg = torch.stack(vs).mean(dim=0)
return F.normalize(avg, dim=1), [
F.normalize(self.contrastive(v), dim=1) for v in vs
]
class TwoTowerRec(nn.Module):
    def __init__(
        self,
        search_args: dict,
        user_input_dim: int,
        prop_args: dict,
        search_emb_dim: int = 32,
    ):
        super().__init__()
        self.search_enc = SearchActivityEncoder(**search_args)
        self.search_emb_dim = search_emb_dim
        # The user tower consumes the raw user features concatenated with the
        # search embedding, so its input width must cover both and the search
        # fusion branch must be enabled.
        self.user_tower = EnhancedUserTower(
            input_dim=user_input_dim + search_emb_dim,
            search_emb_dim=search_emb_dim,
            with_search=True,
        )
        self.prop_tower = PropertyTower(**prop_args)
    def forward(self, user_feat, prop_text, prop_struct, review_text=None, search=None):
        if search is not None:
            s_emb = self.search_enc(search)
        else:
            # No search activity: use a zero embedding so the tower input
            # keeps a fixed width.
            s_emb = user_feat.new_zeros(user_feat.size(0), self.search_emb_dim)
        u_emb = self.user_tower(torch.cat([user_feat, s_emb], dim=-1))
        # PropertyTower.forward expects (struct_feat, text_seq, review_seq)
        p_emb, p_views = self.prop_tower(prop_struct, prop_text, review_text)
        return u_emb, p_emb, p_views
def cross_view_contrastive_loss(views, temp=0.07):
loss = 0.0
count = 0
B = views[0].size(0)
for i in range(len(views)):
for j in range(i + 1, len(views)):
l = torch.matmul(views[i], views[j].T) / temp
tgt = torch.arange(B, device=l.device)
loss += (F.cross_entropy(l, tgt) + F.cross_entropy(l.T, tgt)) / 2
count += 1
    return loss / max(count, 1)  # guard against fewer than two views
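# Hedged usage sketch for cross_view_contrastive_loss: three L2-normalized
# views of the same 4-item batch; random values stand in for real embeddings.
def _demo_cross_view_loss():
    views = [F.normalize(torch.randn(4, 32), dim=1) for _ in range(3)]
    loss = cross_view_contrastive_loss(views)
    print(f"cross-view contrastive loss: {loss.item():.4f}")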
def multi_pos_info_nce(
dist_matrix: torch.Tensor,
label_matrix: torch.Tensor,
w_neg: float = 1.5,
w_unl: float = 0.3,
tau: float = 0.07,
eps: float = 1e-8,
) -> torch.Tensor:
"""
dist_matrix: [U, P] distances (>=0)
label_matrix: [U, P] binary (1=positive, 0=negative)
tau: temperature for softmax
w_neg, w_unl: weighting for explicit vs. unlabeled negatives
returns: scalar loss
"""
sim = -dist_matrix
pos_mask = label_matrix == 1
neg_mask = label_matrix == -1
unl_mask = label_matrix == 0
W = torch.zeros_like(sim)
W[pos_mask] = 1.0
W[neg_mask] = w_neg
W[unl_mask] = w_unl
exp_sim = torch.exp(sim / tau) # [U, P]
weighted = W * exp_sim
num = (weighted * pos_mask).sum(dim=1) # [U]
denom = weighted.sum(dim=1) + eps # [U]
valid = pos_mask.sum(dim=1) > 0 # [U]
if valid.sum() == 0:
return torch.tensor(0.0, device=dist_matrix.device)
loss_per_user = -torch.log(num[valid] / denom[valid])
return loss_per_user.mean()
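# Hedged worked example for multi_pos_info_nce: 2 users x 3 properties, each
# user with one positive, one explicit negative, and one unlabeled item. The
# numbers are illustrative placeholders.
def _demo_multi_pos_info_nce():
    dist = torch.tensor([[0.2, 1.5, 0.9], [1.1, 0.3, 0.8]])
    labels = torch.tensor([[1, -1, 0], [-1, 1, 0]])
    loss = multi_pos_info_nce(dist, labels)
    print(f"multi-positive InfoNCE loss: {loss.item():.4f}")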
def float_to_sign(tensor: torch.Tensor, low_thresh: float, high_thresh: float):
    # Map soft scores to ternary labels: 1 above high_thresh, -1 below
    # low_thresh, and 0 (unlabeled) in between.
    result = torch.zeros_like(tensor)
result[tensor > high_thresh] = 1
result[tensor < low_thresh] = -1
return result
def pairwise_positive_ranking_loss(dist_matrix, score_matrix, margin=0.1):
"""
dist_matrix: Tensor [U, P], pairwise distances between user and items
score_matrix: Tensor [U, P], scores or labels (higher = more relevant)
"""
loss = 0.0
num_users = dist_matrix.size(0)
for u in range(num_users):
pos_idx = (score_matrix[u] > 0).nonzero(as_tuple=True)[0]
if pos_idx.numel() < 2:
continue
pos_scores = score_matrix[u, pos_idx]
pos_dists = dist_matrix[u, pos_idx]
for i in range(len(pos_idx)):
for j in range(i + 1, len(pos_idx)):
s_i, s_j = pos_scores[i], pos_scores[j]
d_i, d_j = pos_dists[i], pos_dists[j]
if s_i == s_j:
continue
                # A higher score must mean a smaller distance: hinge on the
                # distance gap between the two positives.
                sign = torch.sign(s_i - s_j)
                loss += torch.relu(sign * (d_i - d_j) + margin)
return loss / num_users
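# Hedged example for pairwise_positive_ranking_loss: one user with three
# positives; the (0.5, 0.8)-scored pair violates the desired distance order,
# so the hinge is non-zero. Values are illustrative placeholders.
def _demo_ranking_loss():
    dist = torch.tensor([[0.2, 0.3, 0.5]])
    scores = torch.tensor([[1.0, 0.5, 0.8]])
    loss = pairwise_positive_ranking_loss(dist, scores)
    print(f"pairwise ranking loss: {loss.item():.4f}")  # expected 0.3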
class SoftContrastiveLoss(torch.nn.Module):
def __init__(
self,
margin: float = 1.0,
temp: float = 0.3,
lambda_ortho: float = 0.1,
low_thresh: float = 0.4,
high_thresh: float = 0.7,
):
super().__init__()
self.margin = margin
self.temp = temp
self.lambda_ortho = lambda_ortho
self.low_thresh = low_thresh
self.high_thresh = high_thresh
    def forward(self, u_emb, p_emb, p_views, t, user_ids, prop_ids):
        # Scatter the batch interactions into dense [U, P] matrices
        uniq_u, inv_u = torch.unique(user_ids, return_inverse=True)
        uniq_p, inv_p = torch.unique(prop_ids, return_inverse=True)
        U, P = uniq_u.size(0), uniq_p.size(0)
        M = torch.zeros(U, P, device=user_ids.device)
        M[inv_u, inv_p] = t
        # Ternary labels (1 / -1 / 0) from the soft scores via the thresholds
        T = torch.zeros(U, P, device=user_ids.device)
        T[inv_u, inv_p] = float_to_sign(t, self.low_thresh, self.high_thresh)
        # Pairwise user-property distances, scattered into [U, P]
        dist = F.pairwise_distance(u_emb, p_emb)  # [batch]
        dist_matrix = torch.zeros(U, P, device=user_ids.device)
        dist_matrix[inv_u, inv_p] = dist
        info_nce_loss = multi_pos_info_nce(dist_matrix, T, tau=self.temp)
        hinge_loss = pairwise_positive_ranking_loss(dist_matrix, M, margin=self.margin)
        # Cross-view contrastive term over the property views
        view_loss = cross_view_contrastive_loss(p_views, temp=self.temp)
        # Orthogonality regularizer between the user and property embeddings
        ortho_loss = torch.mean(torch.abs(torch.matmul(u_emb.T, p_emb)))
        return info_nce_loss + hinge_loss + view_loss + ortho_loss * self.lambda_ortho
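# Minimal smoke test tying the hedged demos together; guarded so importing
# this module stays side-effect free.
if __name__ == "__main__":
    _demo_user_tower_shapes()
    _demo_fusion_shapes()
    _demo_cross_view_loss()
    _demo_multi_pos_info_nce()
    _demo_ranking_loss()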