BradSegal's picture
Upload app.py
7f17572
raw
history blame
6.6 kB
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
import gradio as gr
# Fetch the fine-tuned checkpoint with the gdown CLI (presumably the
# deberta_large_0.pt file loaded under __main__ -- confirm). The shell exit
# status is not checked, so a failed download only surfaces later at load time.
os.system("gdown https://drive.google.com/uc?id=1whDb0yL_Kqoyx-sIw0sS5xTfb6r_9nlJ")
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def init_params(module_lst):
    """Xavier-uniform initialise every weight matrix in the given modules.

    Parameters with a single dimension (biases, norm scales) are skipped and
    keep their current values.

    Args:
        module_lst: iterable of ``nn.Module`` objects whose parameters are
            re-initialised in place.

    Returns:
        None.
    """
    matrices = (
        tensor
        for mod in module_lst
        for tensor in mod.parameters()
        if tensor.dim() > 1  # only weight matrices / higher-rank kernels
    )
    for tensor in matrices:
        torch.nn.init.xavier_uniform_(tensor)
    return
class Custom_bert(nn.Module):
    """Transformer encoder with a single-output regression head.

    Pipeline: hidden states of the last ``num_hidden_layers`` encoder layers are
    mixed with a learned softmax weighting, pooled over tokens with a small
    attention network, and projected to one value by a linear layer.

    Args:
        model_dir: model id or local path passed to
            ``AutoConfig.from_pretrained`` / ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_dir):
        super().__init__()
        # load base model; every layer's hidden states are needed for the
        # weighted layer average in forward().
        config = AutoConfig.from_pretrained(model_dir)
        config.update({"output_hidden_states": True,
                       "hidden_dropout_prob": 0.0,
                       "layer_norm_eps": 1e-7})
        self.base = AutoModel.from_pretrained(model_dir, config=config)
        # Encoder width, read off the first layer's output projection.
        dim = self.base.encoder.layer[0].output.dense.bias.shape[0]
        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)
        # One learnable weight per encoder layer. Previously hard-coded to 24,
        # which only matched 24-layer encoders such as deberta-large.
        n_weights = config.num_hidden_layers
        weights_init = torch.zeros(n_weights).float()
        # Start with the mix biased towards the last layer (others at -3).
        weights_init.data[:-1] = -3
        self.layer_weights = torch.nn.Parameter(weights_init)
        # attention head: one score per token, softmax over the token axis.
        # Sized from the encoder width instead of a hard-coded 1024.
        self.attention = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Tanh(),
            nn.Linear(dim, 1),
            nn.Softmax(dim=1),
        )
        self.cls = nn.Sequential(
            nn.Linear(dim, 1)
        )
        init_params([self.cls, self.attention])

    def reini_head(self):
        """Re-initialise the attention and output heads (encoder untouched)."""
        init_params([self.cls, self.attention])
        return

    def forward(self, input_ids, attention_mask):
        """Return one regression output per sequence, shape ``(batch, 1)``."""
        base_output = self.base(input_ids=input_ids,
                                attention_mask=attention_mask)
        n_layers = self.layer_weights.shape[0]
        # Weighted average of the last n_layers encoder outputs. HF hidden
        # states are (batch, seq, dim); the stack adds a leading layer axis.
        layer_stack = torch.stack(
            [self.dropout(layer) for layer in base_output['hidden_states'][-n_layers:]], dim=0
        )
        mix = torch.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        cls_output = (mix * layer_stack).sum(0)
        # Multisample dropout: average 5 attention-pooled vectors computed under
        # different dropout masks. In eval mode dropout is the identity, so all
        # 5 samples are identical.
        # NOTE(review): the token softmax in self.attention does not use
        # attention_mask, so padding positions also receive pooling weight.
        logits = torch.mean(
            torch.stack(
                [torch.sum(self.attention(self.high_dropout(cls_output)) * cls_output, dim=1) for _ in range(5)],
                dim=0,
            ),
            dim=0,
        )
        return self.cls(logits)
def get_batches(input, tokenizer, batch_size=128, max_length=256, device='cpu'):
    """Tokenize ``input`` and cut it into fixed-length chunks grouped into batches.

    The tokenizer pads short text up to ``max_length``. No truncation argument is
    passed, so longer text is expected to come back longer than ``max_length``;
    its token sequence is split along the token axis into consecutive
    ``max_length`` windows. Only the final window can be short. It is
    right-padded with the tokenizer's pad id and a zero attention mask, so the
    model ignores the filler. Previously the final window was shifted by one
    left pad token and the filler was given mask value 1, i.e. attended to.

    NOTE(review): windows after the first carry no [CLS]/[SEP] tokens, and the
    logic assumes ``input`` is a single string. With several strings, rows of
    different texts would be mixed into the same batches. TODO confirm.

    Args:
        input: text to score.
        tokenizer: HF-style tokenizer; ``pad_token_id`` is used when present
            (falls back to 0).
        batch_size: maximum number of windows per batch.
        max_length: window length in tokens.
        device: device the id / mask tensors are moved to.

    Returns:
        Tuple of ``(input_ids, attention_mask)`` pairs, each tensor of shape
        ``(windows_in_batch, max_length)``.
    """
    out = tokenizer(input, return_tensors='pt', max_length=max_length, padding='max_length')
    input_ids = out['input_ids'].to(device)
    attention_mask = out['attention_mask'].to(device)

    id_chunks = list(torch.split(input_ids, max_length, dim=1))
    mask_chunks = list(torch.split(attention_mask, max_length, dim=1))

    shortfall = max_length - id_chunks[-1].shape[1]
    if shortfall > 0:
        pad_id = getattr(tokenizer, 'pad_token_id', None)
        if pad_id is None:
            pad_id = 0
        # Right-pad only: real tokens keep their positions, filler is masked out.
        id_chunks[-1] = F.pad(id_chunks[-1], (0, shortfall), value=pad_id)
        mask_chunks[-1] = F.pad(mask_chunks[-1], (0, shortfall), value=0)

    batches = []
    for start in range(0, len(id_chunks), batch_size):
        stop = start + batch_size
        batches.append((torch.cat(id_chunks[start:stop], dim=0),
                        torch.cat(mask_chunks[start:stop], dim=0)))
    return tuple(batches)
def predict(input, tokenizer, model, batch_size=128, max_length=256, max_val=-4, min_val=3, score=100):
    """Score ``input`` for reading difficulty and return a readable summary.

    The text is split into ``max_length``-token windows by ``get_batches``, and
    each window is scored by ``model``. The raw output is rescaled linearly so
    that a raw value of ``min_val`` maps to 0 and a raw value of ``max_val``
    maps to ``score``. Note that the defaults are "inverted" (``max_val=-4`` is
    below ``min_val=3``), so a higher raw model output gives a lower difficulty
    score.

    Args:
        input: text to score.
        tokenizer: tokenizer forwarded to ``get_batches``.
        model: module with a ``base`` encoder attribute; called as
            ``model(input_ids, attention_mask)``.
        batch_size, max_length: forwarded to ``get_batches``.
        max_val, min_val: raw model values mapped to ``score`` and 0.
        score: upper end of the output scale.

    Returns:
        The mean score (rounded to 2 decimals). When there are at least two
        windows, the message also includes their standard deviation and the
        range mean +/- 2 std.
    """
    device = model.base.device
    batches = get_batches(input, tokenizer, batch_size, max_length, device)
    predictions = []
    with torch.no_grad():
        for input_ids, attention_mask in batches:
            pred = model(input_ids, attention_mask)
            pred = score * (pred - min_val) / (max_val - min_val)
            predictions.append(pred)
    predictions = torch.cat(predictions, dim=0)
    mean = round(predictions.mean().item(), 2)
    # The sample std of a single window is undefined: torch returns nan and
    # warns. Skip the call in that case and keep the nan sentinel.
    std = round(predictions.std().item(), 2) if predictions.numel() > 1 else float("nan")
    if np.isnan(std):
        return f"The reading difficulty score is {mean}."
    # Round the bounds as well: arithmetic on already-rounded floats can print
    # artefacts such as 41.120000000000005.
    # NOTE(review): mean +/- 2 std describes the spread of window scores and is
    # not strictly a confidence interval of the mean.
    lower = round(mean - 2 * std, 2)
    upper = round(mean + 2 * std, 2)
    return f"""The reading difficulty score is {mean} with a standard deviation of {std}.
\nThe 95% confidence interval of the score is {lower} to {upper}."""
if __name__ == "__main__":
    # Build the DeBERTa-large regressor, load the downloaded fine-tuned weights,
    # and serve predict() through a Gradio text interface.
    deberta_loc = "deberta_large_0.pt"
    deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large", model_max_length=256)
    model = Custom_bert("microsoft/deberta-large")
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
    # host (torch.load would otherwise try to restore CUDA tensors and fail);
    # the model is moved to `device` right after.
    model.load_state_dict(torch.load(deberta_loc, map_location="cpu"))
    model.eval().to(device)
    description = """
This tool attempts to estimate how difficult a piece of text is to read by a school child.
The underlying model has been developed based on expert ranking of text difficulty for students from grade 3 to 12.
The score has been scaled to range from zero (very easy) to one hundred (very difficult).
Very long passages will be broken up and reported with the average as well as the standard deviation of the difficulty score.
"""
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4; use the
    # top-level component when it exists and fall back for older releases.
    textbox_cls = getattr(gr, "Textbox", None) or gr.inputs.Textbox
    interface = gr.Interface(fn=lambda x: predict(x, deberta_tokenizer, model, batch_size=4),
                             inputs=textbox_cls(lines=7, label="Text:",
                                                placeholder="Insert text to be scored here."),
                             outputs='text',
                             title="Reading Difficulty Analyser",
                             description=description)
    interface.launch()