| from typing import Dict, Optional, Any |
| from transformers import MistralConfig |
|
|
|
|
class GigaCheckConfig(MistralConfig):
    """Configuration for a GigaCheck model built on a Mistral backbone.

    Extends ``MistralConfig`` with a few extra fields. Every keyword argument
    not listed below is forwarded unchanged to ``MistralConfig.__init__``.

    Args:
        with_detr: Flag stored as ``self.with_detr``. Presumably it toggles a
            DETR-style head on the model side (TODO: confirm against the model).
        detr_config: Optional dict stored as-is on ``self.detr_config``. Its
            schema is defined by whatever consumes it, not by this class.
        freeze_backbone: Flag stored as ``self.freeze_backbone``. Presumably it
            tells the model to freeze the Mistral weights (confirm).
        id2label: Mapping from class index to label name. Keys may be ints or
            numeric strings; JSON object keys are always strings, e.g. after a
            save/load round trip. Keys are normalised to ``int``. A non-empty
            mapping defines the label set and takes precedence over
            ``num_labels``. ``label2id`` is rebuilt from it.
        num_labels: Number of classes. Used only when ``id2label`` is missing
            or empty. In that case the label map is built by the inherited
            ``num_labels`` setter.
        max_length: Stored as ``self.max_length``.
        **kwargs: Forwarded to ``MistralConfig``.
    """

    def __init__(
        self,
        with_detr: bool = False,
        detr_config: Optional[Dict[str, Any]] = None,
        freeze_backbone: bool = False,
        id2label: Optional[Dict[int, str]] = None,
        num_labels: int = 2,
        max_length: int = 1024,
        **kwargs,
    ):
        # id2label / num_labels are captured by this signature and never reach
        # the parent. The parent therefore sets up placeholder labels, which
        # are resolved below.
        super().__init__(**kwargs)

        self.with_detr = with_detr
        self.detr_config = detr_config
        self.freeze_backbone = freeze_backbone
        self.max_length = max_length

        if id2label:
            # An explicit label map wins. Assigning `num_labels` after it would
            # let the inherited setter throw this map away whenever its size
            # differs from `num_labels`. That happens, for example, to a
            # 3-class config reloaded from disk: `num_labels` is a property,
            # is not serialised, and falls back to the default of 2.
            self.id2label = {int(k): v for k, v in id2label.items()}
            # Keep the reverse map consistent. Otherwise it keeps the parent's
            # placeholder entries.
            self.label2id = {label: idx for idx, label in self.id2label.items()}
        else:
            # No usable label map. Clear it so the inherited `num_labels`
            # setter rebuilds both id2label and label2id from `num_labels`.
            self.id2label = None
            self.num_labels = num_labels
|
|