# model.py
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer, DistilBertModel
from torchvision.models import efficientnet_b0
from config import AUTHENTICITY_CLASSES, CATEGORIES
class AuctionAuthenticityModel(nn.Module):
    """Multimodal (image + text) classifier with two output heads.

    Fuses EfficientNet-B0 image features (1280-d) with the DistilBERT
    [CLS] text embedding (768-d) through a shared fusion encoder, then
    predicts an authenticity class and a product category from two
    separate linear heads.
    """

    def __init__(self, num_classes=None, device='cpu'):
        """Build the vision/text backbones, fusion encoder, and heads.

        Args:
            num_classes: Number of authenticity classes. Defaults to
                ``len(AUTHENTICITY_CLASSES)`` from config.
            device: Device hint kept for backward compatibility;
                ``forward`` now follows the parameters' actual device.
        """
        # If num_classes not specified, use config
        if num_classes is None:
            num_classes = len(AUTHENTICITY_CLASSES)
        # Category classes (separate head)
        num_categories = len(CATEGORIES)
        super().__init__()
        self.device = device
        # Vision backbone pretrained on ImageNet.
        # NOTE: the legacy `pretrained=True` flag was removed from
        # torchvision; `weights=` is the supported replacement.
        self.vision_model = efficientnet_b0(weights='IMAGENET1K_V1')
        # Drop the classification head so the backbone emits raw features.
        self.vision_model.classifier = nn.Identity()
        vision_out_dim = 1280
        # Text backbone (multilingual DistilBERT).
        self.text_model = DistilBertModel.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        text_out_dim = 768
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            'distilbert-base-multilingual-cased'
        )
        # Fusion encoder (shared) -> then two heads (authenticity + category)
        hidden_dim = 256
        self.fusion_encoder = nn.Sequential(
            nn.Linear(vision_out_dim + text_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        # Heads
        self.auth_head = nn.Linear(128, num_classes)
        self.cat_head = nn.Linear(128, num_categories)
        # store sizes for reference
        self.num_classes = num_classes
        self.num_categories = num_categories

    def forward(self, images, texts):
        """Run both branches and return logits plus softmax probabilities.

        Args:
            images: Float tensor of shape (B, 3, H, W) on the model's device.
            texts: List of B raw strings; tokenized on the fly.

        Returns:
            Dict with 'auth_logits', 'auth_probs', 'cat_logits', 'cat_probs'.
        """
        vision_features = self.vision_model(images)
        # Move the token tensors to wherever the parameters actually live,
        # so the model keeps working after a post-construction `.to(device)`
        # (the `self.device` stored in __init__ can go stale).
        device = next(self.parameters()).device
        tokens = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors='pt'
        ).to(device)
        text_outputs = self.text_model(**tokens)
        # [CLS] token embedding as the sentence representation.
        text_features = text_outputs.last_hidden_state[:, 0, :]
        combined = torch.cat([vision_features, text_features], dim=1)
        shared = self.fusion_encoder(combined)
        auth_logits = self.auth_head(shared)
        cat_logits = self.cat_head(shared)
        # probabilities
        auth_probs = torch.softmax(auth_logits, dim=1)
        cat_probs = torch.softmax(cat_logits, dim=1)
        return {
            'auth_logits': auth_logits,
            'auth_probs': auth_probs,
            'cat_logits': cat_logits,
            'cat_probs': cat_probs,
        }

    def compute_loss(self, outputs, auth_labels=None, cat_labels=None, auth_weight=1.0, cat_weight=1.0):
        """Compute combined loss for two heads. Labels should be LongTensors on same device.
        Returns combined scalar loss and a dict with individual losses.
        """
        losses = {}
        loss = 0.0
        criterion = nn.CrossEntropyLoss()
        if auth_labels is not None:
            l_auth = criterion(outputs['auth_logits'], auth_labels)
            losses['auth_loss'] = l_auth
            loss = loss + auth_weight * l_auth
        if cat_labels is not None:
            # Allow sentinel -1 for unknown/uncertain categories and ignore them.
            # Flatten (N,) or (N, 1) labels to (N,) so the masked selection
            # below yields the 1-D target CrossEntropyLoss requires (the old
            # code passed a (k, 1) target for 2-D label input, which raises).
            flat_labels = cat_labels.view(-1)
            mask = flat_labels >= 0
            if mask.any():
                selected_logits = outputs['cat_logits'][mask]
                selected_labels = flat_labels[mask]
                l_cat = criterion(selected_logits, selected_labels)
                losses['cat_loss'] = l_cat
                loss = loss + cat_weight * l_cat
            else:
                # No valid category labels in batch; place the zero on the
                # logits' device rather than the possibly stale self.device.
                losses['cat_loss'] = torch.tensor(0.0, device=outputs['cat_logits'].device)
        return loss, losses

    def count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
if __name__ == '__main__':
    # Smoke test: build the model, run a dummy forward pass, and report size.
    print("Testowanie modelu...")
    device = torch.device('cpu')
    model = AuctionAuthenticityModel(device=device).to(device)
    print(f"✓ Model stworzony")
    print(f" - Parametrów: {model.count_parameters():,}")
    # Dummy test
    dummy_img = torch.randn(2, 3, 224, 224).to(device)
    dummy_texts = ["Silver spoon antique", "Polish silverware 19th century"]
    with torch.no_grad():
        output = model(dummy_img, dummy_texts)
    # Print shapes
    print("✓ Forward pass:")
    print(f" - auth_logits: {output['auth_logits'].shape}")
    print(f" - auth_probs: {output['auth_probs'].shape}")
    print(f" - cat_logits: {output['cat_logits'].shape}")
    print(f" - cat_probs: {output['cat_probs'].shape}")
    # Show predicted labels and top probabilities
    auth_pred = torch.argmax(output['auth_probs'], dim=1)
    cat_pred = torch.argmax(output['cat_probs'], dim=1)
    for i in range(output['auth_probs'].shape[0]):
        a_idx = int(auth_pred[i].item())
        a_prob = float(output['auth_probs'][i, a_idx].item())
        c_idx = int(cat_pred[i].item())
        c_prob = float(output['cat_probs'][i, c_idx].item())
        # NOTE(review): .get() assumes AUTHENTICITY_CLASSES / CATEGORIES are
        # index->name dicts — confirm against config.
        a_name = AUTHENTICITY_CLASSES.get(a_idx, str(a_idx))
        c_name = CATEGORIES.get(c_idx, str(c_idx))
        print(f"\nSample {i}:")
        print(f" - Authenticity: {a_name} ({a_prob:.3f})")
        print(f" - Category: {c_name} ({c_prob:.3f})")
    # Estimate model size by serializing the state dict to an in-memory
    # buffer: unlike the temp-file approach, nothing is left on disk if a
    # step above raises.
    import io
    print(f"\n📊 Rozmiar modelu:")
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    size_mb = buffer.getbuffer().nbytes / (1024 * 1024)
    print(f" - {size_mb:.1f} MB")