Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	File size: 6,810 Bytes
			
			| eb4d305 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 | from mmdet.models.detectors import CascadeRCNN
from mmdet.registry import MODELS
import torch
import torch.nn as nn
@MODELS.register_module()
class CustomCascadeWithMeta(CascadeRCNN):
    """Cascade R-CNN extended with image-level metadata prediction heads.

    On top of the regular detection pipeline this detector pools the last
    feature map returned by ``extract_feat`` into one vector per image,
    regresses a "number of data points" scalar from it, appends a squashed
    version of that scalar to the pooled vector, and feeds the result to
    up to four optional metadata heads (chart class, plot region, axes
    info, data series).

    NOTE(review): ``forward_train`` / ``simple_test`` are the hook names of
    the MMDetection 2.x detector interface, while ``mmdet.registry`` is a
    3.x module. Confirm against the installed mmdet version that these
    methods are actually invoked by the training / testing loops.
    """

    # (head attribute, key used in the training loss dict,
    #  field name used on the inference result). Order matters: it fixes
    #  both the module registration order and the loss dict order.
    _META_HEADS = (
        ('chart_cls_head', 'chart_cls_loss', 'chart_type'),
        ('plot_reg_head', 'plot_reg_loss', 'plot_bb'),
        ('axes_info_head', 'axes_info_loss', 'axes_info'),
        ('data_series_head', 'data_series_loss', 'data_series'),
    )

    # Divisor applied to the predicted count before the sigmoid squashing.
    _COUNT_SCALE = 100.0

    def __init__(self,
                 *args,
                 chart_cls_head=None,
                 plot_reg_head=None,
                 axes_info_head=None,
                 data_series_head=None,
                 data_points_count_head=None,
                 coordinate_standardization=None,
                 data_series_config=None,
                 axis_aware_feature=None,
                 **kwargs):
        """Build the base detector and the metadata heads.

        Args:
            *args: Positional arguments forwarded to ``CascadeRCNN``.
            chart_cls_head: Optional config passed to ``MODELS.build``.
            plot_reg_head: Optional config passed to ``MODELS.build``.
            axes_info_head: Optional config passed to ``MODELS.build``.
            data_series_head: Optional config passed to ``MODELS.build``.
            data_points_count_head: Optional config passed to
                ``MODELS.build``. When ``None`` a small two-layer MLP with
                a single output is created instead.
            coordinate_standardization: Stored as-is on the instance.
            data_series_config: Stored as-is on the instance.
            axis_aware_feature: Stored as-is on the instance.
            **kwargs: Keyword arguments forwarded to ``CascadeRCNN``.
        """
        super().__init__(*args, **kwargs)

        # Optional heads are only attached when a config is supplied; the
        # forward paths detect them by attribute presence.
        optional_head_cfgs = (
            ('chart_cls_head', chart_cls_head),
            ('plot_reg_head', plot_reg_head),
            ('axes_info_head', axes_info_head),
            ('data_series_head', data_series_head),
        )
        for attr_name, head_cfg in optional_head_cfgs:
            if head_cfg is not None:
                setattr(self, attr_name, MODELS.build(head_cfg))

        # The count head always exists: either built from config or the
        # default regression MLP below.
        if data_points_count_head is not None:
            self.data_points_count_head = MODELS.build(data_points_count_head)
        else:
            # NOTE(review): the 2048 input width assumes the last feature
            # map has 2048 channels (raw ResNet-50 stage-4 output). If a
            # neck such as FPN sits between backbone and heads the channel
            # count is likely different -- confirm against the config.
            self.data_points_count_head = nn.Sequential(
                nn.Linear(2048, 512),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(512, 1)  # single scalar: the predicted count
            )

        # Configurations are only stored here; nothing in this class reads
        # them back.
        self.coordinate_standardization = coordinate_standardization
        self.data_series_config = data_series_config
        self.axis_aware_feature = axis_aware_feature

    def _predict_counts_and_meta_feat(self, feats):
        """Pool the last feature map and derive count-augmented features.

        Args:
            feats: Sequence of feature maps as returned by
                ``extract_feat``; only the last one is used. Assumed to be
                4-D with spatial axes at dims 2 and 3 -- TODO confirm.

        Returns:
            tuple: ``(pred_counts, enhanced_feat)`` where ``pred_counts``
            is the count head output with its last dim squeezed and
            ``enhanced_feat`` is the pooled feature with one extra column
            holding the squashed count.
        """
        pooled = feats[-1].mean(dim=[2, 3])  # global average pooling
        pred_counts = self.data_points_count_head(pooled).squeeze(-1)

        # Sigmoid keeps the extra feature bounded in (0, 1). For
        # non-negative predictions the value lies in [0.5, 1), so this is
        # a squashing step rather than a true 0-1 normalization.
        squashed = torch.sigmoid(pred_counts / self._COUNT_SCALE)
        enhanced_feat = torch.cat([pooled, squashed.unsqueeze(-1)], dim=-1)
        return pred_counts, enhanced_feat

    def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
        """Compute detection, count and metadata losses for one batch.

        Args:
            img: Batched input passed to ``extract_feat``.
            img_metas: Per-image meta dicts. The ground-truth count is read
                from ``img_meta['img_info']['num_data_points']`` and
                defaults to 0 when either key is missing.
            gt_bboxes: Ground-truth boxes forwarded to the RPN / RoI heads.
            gt_labels: Ground-truth labels forwarded to the RoI head.
            **kwargs: Forwarded to the RoI head; ``proposals`` is read from
                here when the detector has no RPN.

        Returns:
            dict: RPN and RoI losses, ``data_points_count_loss`` and one
            entry per attached metadata head.
        """
        x = self.extract_feat(img)
        losses = dict()

        # RPN forward and loss (or externally supplied proposals).
        if self.with_rpn:
            proposal_cfg = self.train_cfg.get('rpn_proposal',
                                              self.test_cfg.rpn)
            rpn_losses, proposal_list = self.rpn_head.forward_train(
                x,
                img_metas,
                gt_bboxes,
                gt_labels=None,
                ann_weight=None,
                proposal_cfg=proposal_cfg)
            losses.update(rpn_losses)
        else:
            proposal_list = kwargs.get('proposals', None)

        # RoI forward and loss.
        roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
                                                 gt_bboxes, gt_labels,
                                                 **kwargs)
        losses.update(roi_losses)

        pred_counts, enhanced_feat = self._predict_counts_and_meta_feat(x)

        # Ground-truth counts come from the image metas; missing entries
        # silently fall back to 0.
        gt_counts = torch.tensor(
            [
                img_meta.get('img_info', {}).get('num_data_points', 0)
                for img_meta in img_metas
            ],
            dtype=torch.float32,
            device=pred_counts.device)

        # Functional MSE (default mean reduction) avoids building a loss
        # module on every step.
        losses['data_points_count_loss'] = nn.functional.mse_loss(
            pred_counts, gt_counts)

        # NOTE(review): each head's raw output is stored as its loss and no
        # ground truth is passed in. This only works if the heads return a
        # loss tensor from a features-only call -- confirm against the head
        # implementations.
        for head_name, loss_key, _ in self._META_HEADS:
            if hasattr(self, head_name):
                losses[loss_key] = getattr(self, head_name)(enhanced_feat)

        return losses

    def simple_test(self, img, img_metas, **kwargs):
        """Run single-scale inference and attach metadata predictions.

        Args:
            img: Batched input passed to ``extract_feat``.
            img_metas: Per-image meta dicts forwarded to the RPN / RoI
                heads.
            **kwargs: Forwarded to ``roi_head.simple_test_bboxes``.

        Returns:
            list: One ``DetDataSample`` per image carrying ``bboxes``,
            ``labels``, ``predicted_data_points`` and one field per
            attached metadata head.
        """
        # Imported here because the module-level imports do not provide
        # DetDataSample; without it this method fails with NameError.
        from mmdet.structures import DetDataSample

        x = self.extract_feat(img)
        proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
        det_bboxes, det_labels = self.roi_head.simple_test_bboxes(
            x, img_metas, proposal_list, self.test_cfg.rcnn, **kwargs)

        pred_counts, enhanced_feat = self._predict_counts_and_meta_feat(x)

        results = []
        for i, (bboxes, labels) in enumerate(zip(det_bboxes, det_labels)):
            result = DetDataSample()
            result.bboxes = bboxes
            result.labels = labels
            result.predicted_data_points = pred_counts[i].item()

            # Heads are run on a batch-of-one slice so each result only
            # depends on its own image's features.
            per_image_feat = enhanced_feat[i:i + 1]
            for head_name, _, field_name in self._META_HEADS:
                if hasattr(self, head_name):
                    setattr(result, field_name,
                            getattr(self, head_name)(per_image_feat))

            results.append(result)

        return results
