Text-to-Image
Diffusers
Safetensors
File size: 1,753 Bytes
2e73038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c3fa6c
2e73038
4c3fa6c
2e73038
ee4ca4f
2e73038
4c3fa6c
2e73038
ee4ca4f
2e73038
4c3fa6c
2e73038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from torch import nn

from .RecCTCHead import CTCHead
from .RecMv1_enhance import MobileNetV1Enhance
from .RNN import Im2Im, Im2Seq, SequenceEncoder


# Registries that map the ``type`` string found in each section of the model
# config to the class implementing that stage.
backbone_dict = {
    "MobileNetV1Enhance": MobileNetV1Enhance,
}
neck_dict = {
    "SequenceEncoder": SequenceEncoder,
    "Im2Seq": Im2Seq,
    # The literal string "None" (not the None object) selects Im2Im.
    "None": Im2Im,
}
head_dict = {
    "CTCHead": CTCHead,
}


class RecModel(nn.Module):
    """Recognition network assembled as backbone -> neck -> head.

    Each stage class is looked up by its ``type`` name in the module-level
    ``backbone_dict``, ``neck_dict`` and ``head_dict`` registries. The stages
    are chained through their ``out_channels`` attributes: the neck is built
    from ``backbone.out_channels`` and the head from ``neck.out_channels``.
    """

    def __init__(self, config):
        """Build the three stages described by ``config``.

        Args:
            config: Mapping with an ``"in_channels"`` entry plus
                ``"backbone"``, ``"neck"`` and ``"head"`` sub-mappings. Every
                sub-mapping holds a ``"type"`` key naming a registry entry;
                its remaining items are forwarded to that stage's constructor
                as keyword arguments. Neither ``config`` nor its sub-mappings
                are modified, so the same config can be reused.

        Raises:
            AssertionError: If ``"in_channels"`` is missing or a ``type`` is
                not present in the matching registry (when assertions are
                enabled).
            KeyError: If a section or its ``"type"`` key is missing.
        """
        super().__init__()
        assert "in_channels" in config, "in_channels must in model config"

        # Pop "type" from shallow copies rather than from the caller's dicts;
        # popping in place would strip the key from the shared config and make
        # a second RecModel(config) fail with KeyError.
        backbone_cfg = dict(config["backbone"])
        backbone_type = backbone_cfg.pop("type")
        assert backbone_type in backbone_dict, f"backbone.type must in {backbone_dict}"
        self.backbone = backbone_dict[backbone_type](config["in_channels"], **backbone_cfg)

        neck_cfg = dict(config["neck"])
        neck_type = neck_cfg.pop("type")
        assert neck_type in neck_dict, f"neck.type must in {neck_dict}"
        self.neck = neck_dict[neck_type](self.backbone.out_channels, **neck_cfg)

        head_cfg = dict(config["head"])
        head_type = head_cfg.pop("type")
        assert head_type in head_dict, f"head.type must in {head_dict}"
        self.head = head_dict[head_type](self.neck.out_channels, **head_cfg)

        # Human-readable identifier built from the three selected type names.
        self.name = f"RecModel_{backbone_type}_{neck_type}_{head_type}"

    def load_3rd_state_dict(self, _3rd_name, _state):
        """Forward a third-party state dict to the backbone, neck and head.

        Args:
            _3rd_name: Passed unchanged to each stage's
                ``load_3rd_state_dict``.
            _state: Passed unchanged to each stage's ``load_3rd_state_dict``.
        """
        for stage in (self.backbone, self.neck, self.head):
            stage.load_3rd_state_dict(_3rd_name, _state)

    def forward(self, x):
        """Cast ``x`` to float32 and run it through backbone, neck and head.

        Args:
            x: Input tensor; converted to ``torch.float32`` before the
                backbone is applied.

        Returns:
            The output of ``self.head``.
        """
        # Tensor.float() is the same conversion as x.to(torch.float32) and
        # avoids importing torch inside the method on every call.
        x = x.float()
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head(x)
        return x

    def encode(self, x):
        """Run backbone and neck, then the head's ``ctc_encoder`` only.

        Unlike ``forward``, the input is used as given (no float32 cast).

        Args:
            x: Input passed directly to ``self.backbone``.

        Returns:
            The output of ``self.head.ctc_encoder``.
        """
        x = self.backbone(x)
        x = self.neck(x)
        x = self.head.ctc_encoder(x)
        return x