import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoTokenizer
import gradio as gr

# Fetch the fine-tuned checkpoint from Google Drive via the `gdown` CLI.
# NOTE(review): this runs at import time and re-downloads on every start; the
# saved filename is chosen by gdown -- presumably "deberta_large_0.pt", which
# is what the __main__ block loads. Confirm.
os.system("gdown https://drive.google.com/uc?id=1whDb0yL_Kqoyx-sIw0sS5xTfb6r_9nlJ")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def init_params(module_lst):
    """Xavier-uniform initialise every parameter with more than one dimension.

    Args:
        module_lst: iterable of ``nn.Module`` objects whose parameters are
            re-initialised in place. 1-D parameters (e.g. biases) are left
            untouched.
    """
    for module in module_lst:
        for param in module.parameters():
            if param.dim() > 1:
                torch.nn.init.xavier_uniform_(param)


class Custom_bert(nn.Module):
    """Regression head on top of a pretrained transformer backbone.

    The model combines three stages:

    1. A learned softmax-weighted average of the backbone's encoder-layer
       hidden states.
    2. Attention pooling over the sequence positions.
    3. A linear layer producing one raw score per input sequence.

    Args:
        model_dir: Hugging Face model id or local directory of the backbone.
    """

    def __init__(self, model_dir):
        super().__init__()

        # Backbone, configured to return the hidden states of every layer.
        config = AutoConfig.from_pretrained(model_dir)
        config.update({"output_hidden_states": True,
                       "hidden_dropout_prob": 0.0,
                       "layer_norm_eps": 1e-7})
        self.base = AutoModel.from_pretrained(model_dir, config=config)

        # Hidden size of the backbone, read from the first layer's output projection.
        dim = self.base.encoder.layer[0].output.dense.bias.shape[0]

        self.dropout = nn.Dropout(p=0.2)
        self.high_dropout = nn.Dropout(p=0.5)

        # One learnable weight per encoder layer. Initialising all but the last
        # to -3 makes the softmax start out concentrated on the final layer.
        # (24 for deberta-large, matching the saved checkpoint.)
        self.n_layers = config.num_hidden_layers
        weights_init = torch.zeros(self.n_layers).float()
        weights_init.data[:-1] = -3
        self.layer_weights = torch.nn.Parameter(weights_init)

        # Attention head: one softmax weight per sequence position (dim=1).
        self.attention = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Tanh(),
            nn.Linear(dim, 1),
            nn.Softmax(dim=1),
        )
        self.cls = nn.Sequential(
            nn.Linear(dim, 1),
        )
        init_params([self.cls, self.attention])

    def reini_head(self):
        """Re-initialise the attention-pooling and output heads in place."""
        init_params([self.cls, self.attention])

    def forward(self, input_ids, attention_mask):
        """Return one raw score per sequence.

        Args:
            input_ids: token ids passed straight to the backbone.
            attention_mask: attention mask passed straight to the backbone.

        Returns:
            Tensor of raw scores, one row per sequence.
        """
        base_output = self.base(input_ids=input_ids, attention_mask=attention_mask)

        # Stack the last n_layers hidden states (the encoder-layer outputs) on
        # a new leading axis.
        # Assumes each hidden state is (batch, seq, dim) -- HF convention.
        layer_outputs = torch.stack(
            [self.dropout(layer) for layer in base_output['hidden_states'][-self.n_layers:]],
            dim=0,
        )
        layer_weights = torch.softmax(self.layer_weights, dim=0).view(-1, 1, 1, 1)
        cls_output = (layer_weights * layer_outputs).sum(0)

        # Multisample dropout: average 5 attention-pooled vectors. Each one
        # computes its attention weights from a differently-dropped copy.
        # NOTE(review): the pooling softmax does not use attention_mask, so
        # padding positions also receive weight (kept as trained).
        logits = torch.mean(
            torch.stack(
                [torch.sum(self.attention(self.high_dropout(cls_output)) * cls_output, dim=1)
                 for _ in range(5)],
                dim=0,
            ),
            dim=0,
        )
        return self.cls(logits)


def get_batches(input, tokenizer, batch_size=128, max_length=256, device='cpu'):
    """Tokenize ``input`` and cut it into fixed-length chunks grouped into batches.

    The text is tokenized without truncation, so a passage longer than
    ``max_length`` tokens is split along the sequence axis into consecutive
    chunks of ``max_length`` tokens. Only the last chunk can be shorter. It is
    padded up to ``max_length``:

    - one position on the left and the rest on the right;
    - the tokenizer's pad id in ``input_ids``;
    - ``0`` in the attention mask, matching the tokenizer's own padding.

    Args:
        input: text to score (a single string in this app).
        tokenizer: Hugging Face tokenizer.
        batch_size: maximum number of chunks per batch.
        max_length: chunk length in tokens.
        device: device the tensors are moved to.

    Returns:
        Tuple of ``(input_ids, attention_mask)`` pairs, one pair per batch.
    """
    encoded = tokenizer(input, return_tensors='pt', max_length=max_length, padding='max_length')
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    id_chunks = list(torch.split(input_ids, max_length, dim=1))
    mask_chunks = list(torch.split(attention_mask, max_length, dim=1))

    tail_len = id_chunks[-1].shape[1]
    if tail_len < max_length:
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0
        pad = (1, max_length - tail_len - 1)
        id_chunks[-1] = F.pad(id_chunks[-1], pad, value=pad_id)
        # Padding positions must be masked out (0), as the tokenizer does for
        # its own padding; the previous value of 1 made the model attend to them.
        mask_chunks[-1] = F.pad(mask_chunks[-1], pad, value=0)

    batches = []
    for start in range(0, len(id_chunks), batch_size):
        stop = start + batch_size
        batches.append((torch.cat(id_chunks[start:stop], dim=0),
                        torch.cat(mask_chunks[start:stop], dim=0)))
    return tuple(batches)


def predict(input, tokenizer, model, batch_size=128, max_length=256, max_val=-4, min_val=3, score=100):
    """Score the reading difficulty of ``input`` and return a human-readable summary.

    Each raw chunk prediction ``p`` is linearly rescaled as
    ``score * (p - min_val) / (max_val - min_val)``. A raw value of ``min_val``
    maps to 0 and a raw value of ``max_val`` maps to ``score``. With the
    defaults (``min_val=3``, ``max_val=-4``), a higher raw output therefore
    yields a lower difficulty score.

    Args:
        input: text to score.
        tokenizer: tokenizer passed to ``get_batches``.
        model: ``Custom_bert`` instance; its backbone's device is used.
        batch_size: chunks per forward pass.
        max_length: chunk length in tokens.
        max_val: raw value mapped to ``score``.
        min_val: raw value mapped to 0.
        score: upper end of the output scale.

    Returns:
        A sentence with the mean score. When there is more than one chunk, it
        also reports the standard deviation and ``mean ± 2·std``, all rounded
        to 2 decimals.
    """
    device = model.base.device
    batches = get_batches(input, tokenizer, batch_size, max_length, device)

    with torch.no_grad():
        predictions = torch.cat([model(ids, mask) for ids, mask in batches], dim=0)
    predictions = score * (predictions - min_val) / (max_val - min_val)

    mean = round(predictions.mean().cpu().item(), 2)
    # With a single chunk the (unbiased) std is NaN.
    std = round(predictions.std().cpu().item(), 2)
    if np.isnan(std):
        return f"The reading difficulty score is {mean}."
    # Round the bounds too: arithmetic on the rounded values can otherwise
    # print float artifacts such as 9.879999999999999.
    lower = round(mean - 2 * std, 2)
    upper = round(mean + 2 * std, 2)
    return f"""The reading difficulty score is {mean} with a standard deviation of {std}. \nThe 95% confidence interval of the score is {lower} to {upper}."""


if __name__ == "__main__":
    deberta_loc = "deberta_large_0.pt"
    deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large", model_max_length=256)
    model = Custom_bert("microsoft/deberta-large")
    # NOTE(review): torch.load unpickles the file, and this checkpoint is
    # downloaded from the internet; consider weights_only=True (torch >= 1.13).
    model.load_state_dict(torch.load(deberta_loc, map_location=torch.device(device)))
    model.eval().to(device)

    description = """ This tool attempts to estimate how difficult a piece of text is to read by a school child. The underlying model has been developed based on expert ranking of text difficulty for students from grade 3 to 12. The score has been scaled to range from zero (very easy) to one hundred (very difficult). Very long passages will be broken up and reported with the average as well as the standard deviation of the difficulty score. """

    # NOTE(review): gr.inputs.* was deprecated in Gradio 3 and removed in
    # Gradio 4 (use gr.Textbox there); kept for the currently pinned version.
    interface = gr.Interface(fn=lambda x: predict(x, deberta_tokenizer, model, batch_size=4),
                             inputs=gr.inputs.Textbox(lines = 7, label = "Text:", placeholder = "Insert text to be scored here."),
                             outputs='text',
                             title = "Reading Difficulty Analyser",
                             description = description)
    interface.launch()