import torch
import torch.nn as nn
from typing import Dict, Tuple


class KickstarterModel(nn.Module):
    """Kickstarter Project Success Prediction Model"""

    def __init__(
        self,
        desc_embedding_dim=768,
        blurb_embedding_dim=384,
        risk_embedding_dim=384,
        subcategory_embedding_dim=100,
        category_embedding_dim=15,
        country_embedding_dim=100,
        numerical_features_dim=9,
        hidden_dim=512,
        dropout_rate=0.3
    ):
        """
        Initialize the model.

        Args:
            desc_embedding_dim: Description embedding vector dimension
            blurb_embedding_dim: Blurb embedding vector dimension
            risk_embedding_dim: Risk embedding vector dimension
            subcategory_embedding_dim: Subcategory embedding vector dimension
            category_embedding_dim: Category embedding vector dimension
            country_embedding_dim: Country embedding vector dimension
            numerical_features_dim: Numerical features dimension
            hidden_dim: Hidden layer dimension
            dropout_rate: Dropout rate
        """
        super().__init__()

        # Shared building block: Linear -> BatchNorm -> ReLU -> Dropout
        def create_fc_block(input_dim, output_dim):
            return nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.BatchNorm1d(output_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            )

        # One projection branch per feature group
        self.desc_fc = create_fc_block(desc_embedding_dim, hidden_dim)
        self.blurb_fc = create_fc_block(blurb_embedding_dim, hidden_dim // 2)
        self.risk_fc = create_fc_block(risk_embedding_dim, hidden_dim // 2)
        self.subcategory_fc = create_fc_block(subcategory_embedding_dim, hidden_dim // 4)
        self.category_fc = create_fc_block(category_embedding_dim, hidden_dim // 8)
        self.country_fc = create_fc_block(country_embedding_dim, hidden_dim // 8)
        self.numerical_fc = create_fc_block(numerical_features_dim, hidden_dim // 4)

        # Total width of the concatenated branch outputs
        concat_dim = (hidden_dim +
                      hidden_dim // 2 +
                      hidden_dim // 2 +
                      hidden_dim // 4 +
                      hidden_dim // 8 +
                      hidden_dim // 8 +
                      hidden_dim // 4)

        # Fusion layers applied to the concatenated representation
        self.fc1 = create_fc_block(concat_dim, hidden_dim)
        self.fc2 = create_fc_block(hidden_dim, hidden_dim // 2)

        # Single-logit output head
        self.output = nn.Linear(hidden_dim // 2, 1)

        # Expected keys in the input dictionary
        self.input_names = [
            'description_embedding',
            'blurb_embedding',
            'risk_embedding',
            'subcategory_embedding',
            'category_embedding',
            'country_embedding',
            'numerical_features'
        ]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass.

        Args:
            inputs: Dictionary containing all input features

        Returns:
            Prediction probabilities and intermediate feature representations
        """
        # Project each feature group into its own hidden representation
        desc_out = self.desc_fc(inputs['description_embedding'])
        blurb_out = self.blurb_fc(inputs['blurb_embedding'])
        risk_out = self.risk_fc(inputs['risk_embedding'])
        subcategory_out = self.subcategory_fc(inputs['subcategory_embedding'])
        category_out = self.category_fc(inputs['category_embedding'])
        country_out = self.country_fc(inputs['country_embedding'])
        numerical_out = self.numerical_fc(inputs['numerical_features'])

        # Concatenate all branch outputs along the feature dimension
        combined = torch.cat([
            desc_out,
            blurb_out,
            risk_out,
            subcategory_out,
            category_out,
            country_out,
            numerical_out
        ], dim=1)

        # Fuse the combined representation
        fc1_out = self.fc1(combined)
        x = self.fc2(fc1_out)

        # Single logit mapped to a success probability
        logits = self.output(x)
        probs = torch.sigmoid(logits)

        # Intermediate activations, keyed by the feature they were derived from
        intermediate_features = {
            'description_embedding': desc_out,
            'blurb_embedding': blurb_out,
            'risk_embedding': risk_out,
            'subcategory_embedding': subcategory_out,
            'category_embedding': category_out,
            'country_embedding': country_out,
            'numerical_features': numerical_out,
            'combined': combined,
            'fc1': fc1_out
        }

        return probs.squeeze(1), intermediate_features

    def predict(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Prediction function.

        Args:
            inputs: Dictionary containing all input features

        Returns:
            Prediction probability
        """
        self.eval()
        with torch.no_grad():
            probs, _ = self.forward(inputs)
        return probs
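

# Minimal usage sketch (illustrative only, not part of the model definition):
# random inputs sized to the default embedding dimensions above, run through predict().
if __name__ == "__main__":
    model = KickstarterModel()
    batch_size = 4
    dummy_inputs = {
        'description_embedding': torch.randn(batch_size, 768),
        'blurb_embedding': torch.randn(batch_size, 384),
        'risk_embedding': torch.randn(batch_size, 384),
        'subcategory_embedding': torch.randn(batch_size, 100),
        'category_embedding': torch.randn(batch_size, 15),
        'country_embedding': torch.randn(batch_size, 100),
        'numerical_features': torch.randn(batch_size, 9),
    }
    probs = model.predict(dummy_inputs)
    print(probs.shape)  # expected: torch.Size([4])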