Spaces: Cosmos-Predict2-2BASD (status: Runtime error)
File: diffusers_repo/examples/research_projects/anytext/ocr_recog/RecModel.py
| from torch import nn | |
| from .RecCTCHead import CTCHead | |
| from .RecMv1_enhance import MobileNetV1Enhance | |
| from .RNN import Im2Im, Im2Seq, SequenceEncoder | |
# Registries mapping the "type" string of each RecModel config section to the
# class that is instantiated for that stage.
backbone_dict = {
    "MobileNetV1Enhance": MobileNetV1Enhance,
}
neck_dict = {
    "SequenceEncoder": SequenceEncoder,
    "Im2Seq": Im2Seq,
    # The *string* "None" (not the None object) selects Im2Im.
    "None": Im2Im,
}
head_dict = {
    "CTCHead": CTCHead,
}
class RecModel(nn.Module):
    """OCR recognition model built as a backbone -> neck -> head pipeline.

    Each stage class is looked up by name in the module-level registries
    ``backbone_dict``, ``neck_dict`` and ``head_dict``. The backbone is built
    with ``config["in_channels"]``; the neck and head are each built with the
    ``out_channels`` attribute of the stage before them.

    Args:
        config: Mapping with an ``"in_channels"`` entry and the sub-mappings
            ``"backbone"``, ``"neck"`` and ``"head"``. Every sub-mapping needs
            a ``"type"`` key naming a registered class; its remaining keys are
            passed to that class's constructor as keyword arguments.
            ``config`` is left unmodified, so the same dict can be used to
            build several models.

    Raises:
        AssertionError: If ``"in_channels"`` is missing or a ``"type"`` is not
            registered (these checks are skipped under ``python -O``).
        KeyError: If a stage sub-mapping or its ``"type"`` key is missing.
    """

    def __init__(self, config):
        super().__init__()
        assert "in_channels" in config, "in_channels must in model config"

        # Stages are built in order because each one needs the previous
        # stage's out_channels.
        backbone_type, backbone_kwargs = self._split_type(config["backbone"])
        assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
        self.backbone = backbone_dict[backbone_type](config["in_channels"], **backbone_kwargs)

        neck_type, neck_kwargs = self._split_type(config["neck"])
        assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **neck_kwargs)

        head_type, head_kwargs = self._split_type(config["head"])
        assert head_type in head_dict, f"head.type must in {head_dict}"
        self.head = head_dict[head_type](self.neck.out_channels, **head_kwargs)

        # Identifier composed of the three selected stage type names.
        self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"

    @staticmethod
    def _split_type(stage_config):
        """Return ``(type_name, constructor_kwargs)`` for one stage config.

        Works on a shallow copy so the caller's mapping keeps its ``"type"``
        key (popping it in place made the config single-use).

        Raises:
            KeyError: If ``stage_config`` has no ``"type"`` key.
        """
        kwargs = dict(stage_config)
        return kwargs.pop("type"), kwargs

    def load_3rd_state_dict(self, _3rd_name, _state):
        """Forward ``(_3rd_name, _state)`` to each stage's ``load_3rd_state_dict``.

        The backbone, neck and head are called in that order. What the
        arguments mean (presumably the name of an external checkpoint format
        and its weights) is defined by the stage classes — confirm there.
        """
        self.backbone.load_3rd_state_dict(_3rd_name, _state)
        self.neck.load_3rd_state_dict(_3rd_name, _state)
        self.head.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        """Cast ``x`` to float32 and run it through backbone, neck and head.

        Args:
            x: Input tensor; presumably a batch of text-line images — confirm
                the expected shape against the backbone.

        Returns:
            Whatever ``self.head`` returns for the neck's output.
        """
        # Tensor.float() is equivalent to x.to(torch.float32); the cast lets
        # callers pass e.g. half-precision inputs to a float32 model.
        x = x.float()
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

    def encode(self, x):
        """Run backbone and neck, then return ``self.head.ctc_encoder`` output.

        NOTE(review): unlike ``forward``, ``x`` is NOT cast to float32 here,
        so callers must pass a dtype compatible with the model's weights —
        confirm whether this asymmetry is intentional.
        """
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head.ctc_encoder(x)
        return x