dvsth committed on
Commit
38c803c
·
1 Parent(s): 5d363cf

Upload model

Browse files
Files changed (2) hide show
  1. LegibilityModel.py +26 -0
  2. config.json +3 -0
LegibilityModel.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import VisionEncoderDecoderModel, PreTrainedModel, AutoConfig
3
+
4
class LegibilityModel(PreTrainedModel):
    """Vision encoder followed by a small MLP head that emits one score per image.

    The encoder is the ``.encoder`` sub-module of a ``VisionEncoderDecoderModel``
    built from the supplied configuration; the decoder half is not used in
    ``forward``. The encoder's last hidden state is averaged over ``dim=1`` and
    passed through a two-layer feed-forward stack ending in a single output
    unit.
    """

    def __init__(self, config=None):
        """Build the encoder and the scoring head.

        Args:
            config: Model configuration. ``PreTrainedModel.from_pretrained``
                (and therefore ``AutoModel`` resolution through ``auto_map``)
                instantiates the class as ``cls(config, ...)``, so the
                constructor has to accept it. When omitted, the
                ``microsoft/trocr-base-handwritten`` configuration is fetched,
                which is what the original zero-argument constructor did.
        """
        if config is None:
            config = AutoConfig.from_pretrained("microsoft/trocr-base-handwritten")
        super().__init__(config)

        # Base model architecture. The attribute name ``model`` is kept so the
        # parameter names in existing checkpoints stay the same.
        self.model = VisionEncoderDecoderModel(config).encoder

        # The head consumes the encoder's hidden states, so its width follows
        # the encoder configuration; 768 is the value the original hard-coded.
        hidden_size = getattr(getattr(config, "encoder", None), "hidden_size", 768)

        # Dropout is a no-op at p=0; change the rate here when training.
        self.stack = nn.Sequential(
            nn.Dropout(0),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, img_batch):
        """Score a batch of images.

        Args:
            img_batch: Input passed straight to the encoder as its first
                positional argument (presumably a batch of pixel values --
                confirm against the image processor used by callers).

        Returns:
            Output of the scoring head, whose last dimension has size 1.
        """
        output = self.model(img_batch)
        # Average the output of the last hidden layer over dim=1.
        output = output.last_hidden_state.mean(dim=1)
        scores = self.stack(output)
        return scores
config.json CHANGED
@@ -4,6 +4,9 @@
4
  "architectures": [
5
  "LegibilityModel"
6
  ],
 
 
 
7
  "decoder": {
8
  "_name_or_path": "",
9
  "activation_dropout": 0.0,
 
4
  "architectures": [
5
  "LegibilityModel"
6
  ],
7
+ "auto_map": {
8
+ "AutoModel": "LegibilityModel.LegibilityModel"
9
+ },
10
  "decoder": {
11
  "_name_or_path": "",
12
  "activation_dropout": 0.0,