BradSegal commited on
Commit
7f17572
1 Parent(s): 1cd6733

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +163 -0
app.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import AutoModel, AutoConfig, AutoTokenizer
7
+ import gradio as gr
8
+
9
+ os.system("gdown https://drive.google.com/uc?id=1whDb0yL_Kqoyx-sIw0sS5xTfb6r_9nlJ")
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+
12
+
13
+ def init_params(module_lst):
14
+ for module in module_lst:
15
+ for param in module.parameters():
16
+ if param.dim() > 1:
17
+ torch.nn.init.xavier_uniform_(param)
18
+ return
19
+
20
+
21
+ class Custom_bert(nn.Module):
22
+ def __init__(self, model_dir):
23
+ super().__init__()
24
+
25
+ # load base model
26
+ config = AutoConfig.from_pretrained(model_dir)
27
+ config.update({"output_hidden_states": True,
28
+ "hidden_dropout_prob": 0.0,
29
+ "layer_norm_eps": 1e-7})
30
+
31
+ self.base = AutoModel.from_pretrained(model_dir, config=config)
32
+
33
+ dim = self.base.encoder.layer[0].output.dense.bias.shape[0]
34
+
35
+ self.dropout = nn.Dropout(p=0.2)
36
+ self.high_dropout = nn.Dropout(p=0.5)
37
+
38
+ # weights for weighted layer average
39
+ n_weights = 24
40
+ weights_init = torch.zeros(n_weights).float()
41
+ weights_init.data[:-1] = -3
42
+ self.layer_weights = torch.nn.Parameter(weights_init)
43
+
44
+ # attention head
45
+ self.attention = nn.Sequential(
46
+ nn.Linear(1024, 1024),
47
+ nn.Tanh(),
48
+ nn.Linear(1024, 1),
49
+ nn.Softmax(dim=1)
50
+ )
51
+ self.cls = nn.Sequential(
52
+ nn.Linear(dim, 1)
53
+ )
54
+ init_params([self.cls, self.attention])
55
+
56
+ def reini_head(self):
57
+ init_params([self.cls, self.attention])
58
+ return
59
+
60
+ def forward(self, input_ids, attention_mask):
61
+ base_output = self.base(input_ids=input_ids,
62
+ attention_mask=attention_mask)
63
+
64
+ # weighted average of all encoder outputs
65
+ cls_outputs = torch.stack(
66
+ [self.dropout(layer) for layer in base_output['hidden_states'][-24:]], dim=0
67
+ )
68
+ cls_output = (
69
+ torch.softmax(self.layer_weights, dim=0).unsqueeze(1).unsqueeze(1).unsqueeze(1) * cls_outputs).sum(
70
+ 0)
71
+
72
+ # multisample dropout
73
+ logits = torch.mean(
74
+ torch.stack(
75
+ [torch.sum(self.attention(self.high_dropout(cls_output)) * cls_output, dim=1) for _ in range(5)],
76
+ dim=0,
77
+ ),
78
+ dim=0,
79
+ )
80
+ return self.cls(logits)
81
+
82
+
83
+ def get_batches(input, tokenizer, batch_size=128, max_length=256, device='cpu'):
84
+ out = tokenizer(input, return_tensors='pt', max_length=max_length, padding='max_length')
85
+ out['input_ids'], out['attention_mask'] = out['input_ids'].to(device), out['attention_mask'].to(device)
86
+ input_id_split = torch.split(out['input_ids'], max_length, dim=1)
87
+ attention_split = torch.split(out['attention_mask'], max_length, dim=1)
88
+
89
+ input_id_batches = []
90
+ attention_batches = []
91
+
92
+ i = 0
93
+ input_length = len(input_id_split)
94
+
95
+ while i * batch_size < input_length:
96
+ if i * batch_size + batch_size <= input_length:
97
+ input_id_batches.append(list(input_id_split[i * batch_size:(i + 1) * batch_size]))
98
+ attention_batches.append(list(attention_split[i * batch_size:(i + 1) * batch_size]))
99
+ else:
100
+ input_id_batches.append(list(input_id_split[i * batch_size:input_length]))
101
+ attention_batches.append(list(attention_split[i * batch_size:input_length]))
102
+ i += 1
103
+
104
+ if input_id_batches[-1][-1].shape[1] < max_length:
105
+ input_id_batches[-1][-1] = F.pad(input_id_batches[-1][-1],
106
+ (1, max_length - input_id_batches[-1][-1].shape[1] - 1),
107
+ value=0)
108
+ attention_batches[-1][-1] = F.pad(attention_batches[-1][-1],
109
+ (1, max_length - attention_batches[-1][-1].shape[1] - 1),
110
+ value=1)
111
+
112
+ input_id_batches = [torch.cat(batch, dim=0) for batch in input_id_batches]
113
+ attention_batches = [torch.cat(batch, dim=0) for batch in attention_batches]
114
+
115
+ return tuple(zip(input_id_batches, attention_batches))
116
+
117
+
118
+ def predict(input, tokenizer, model, batch_size=128, max_length=256, max_val=-4, min_val=3, score=100):
119
+ device = model.base.device
120
+ batches = get_batches(input, tokenizer, batch_size, max_length, device)
121
+
122
+ predictions = []
123
+
124
+ with torch.no_grad():
125
+ for input_ids, attention_mask in batches:
126
+ pred = model(input_ids, attention_mask)
127
+ pred = score * (pred - min_val) / (max_val - min_val)
128
+ predictions.append(pred)
129
+
130
+ predictions = torch.cat(predictions, dim=0)
131
+ mean, std = predictions.mean().cpu().item(), predictions.std().cpu().item()
132
+ mean, std = round(mean, 2), round(std, 2)
133
+ if np.isnan(std):
134
+ return f"The reading difficulty score is {mean}."
135
+ else:
136
+ return f"""The reading difficulty score is {mean} with a standard deviation of {std}.
137
+ \nThe 95% confidence interval of the score is {mean - 2 * std} to {mean + 2 * std}."""
138
+
139
+
140
+ if __name__ == "__main__":
141
+ deberta_loc = "deberta_large_0.pt"
142
+ deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large", model_max_length=256)
143
+
144
+ model = Custom_bert("microsoft/deberta-large")
145
+ model.load_state_dict(torch.load(deberta_loc))
146
+ model.eval().to(device)
147
+
148
+
149
+ description = """
150
+ This tool attempts to estimate how difficult a piece of text is to read by a school child.
151
+ The underlying model has been developed based on expert ranking of text difficulty for students from grade 3 to 12.
152
+ The score has been scaled to range from zero (very easy) to one hundred (very difficult).
153
+ Very long passages will be broken up and reported with the average as well as the standard deviation of the difficulty score.
154
+ """
155
+
156
+ interface = gr.Interface(fn=lambda x: predict(x, deberta_tokenizer, model, batch_size=4),
157
+ inputs=gr.inputs.Textbox(lines = 7, label = "Text:",
158
+ placeholder = "Insert text to be scored here."),
159
+ outputs='text',
160
+ title = "Reading Difficulty Analyser",
161
+ description = description)
162
+ interface.launch()
163
+