BradSegal committed on
Commit
903cda8
1 Parent(s): b274276

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -93
app.py DELETED
@@ -1,93 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
- import torch.nn.functional as F
5
- from components.model import Custom_bert
6
- from transformers import AutoTokenizer
7
- import gradio as gr
8
-
9
# Download a file from Google Drive with the `gdown` command-line tool at
# import time (presumably the checkpoint that the `__main__` section loads --
# TODO confirm the downloaded filename).
_download_status = os.system("gdown https://drive.google.com/uc?id=1whDb0yL_Kqoyx-sIw0sS5xTfb6r_9nlJ")
if _download_status != 0:
    # Report the failure instead of discarding it, but do not abort: a copy
    # fetched by an earlier run may still be on disk.
    print(f"Warning: gdown exited with status {_download_status}; the download may have failed.")

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
-
13
def get_batches(input, tokenizer, batch_size=128, max_length=256, device='cpu'):
    """Tokenize ``input`` and cut it into model-sized batches.

    The tokenizer output is split along the sequence dimension into chunks of
    ``max_length`` tokens; a final chunk that comes out shorter is padded back
    up to ``max_length``. Consecutive chunks are then stacked along dim 0 into
    batches holding at most ``batch_size`` chunks each.

    Args:
        input: Text passed straight to ``tokenizer``.
        tokenizer: Callable accepting ``(input, return_tensors, max_length,
            padding)`` and returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` tensors.
        batch_size: Maximum number of chunks per batch; must be at least 1.
        max_length: Number of tokens per chunk.
        device: Device the token tensors are moved to before splitting.

    Returns:
        A tuple of ``(input_ids, attention_mask)`` pairs, one per batch.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1 (the original loop
            would never terminate in that case).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    encoded = tokenizer(input, return_tensors='pt', max_length=max_length, padding='max_length')
    id_chunks = list(torch.split(encoded['input_ids'].to(device), max_length, dim=1))
    mask_chunks = list(torch.split(encoded['attention_mask'].to(device), max_length, dim=1))

    # Only the last chunk can be shorter than max_length. It receives one pad
    # position on the left and the remainder on the right.
    # NOTE(review): the padded positions get attention value 1, i.e. they are
    # attended to, which differs from the usual tokenizer convention of 0 for
    # padding -- kept as-is; confirm this matches how the model was trained.
    if id_chunks[-1].shape[1] < max_length:
        id_chunks[-1] = F.pad(id_chunks[-1],
                              (1, max_length - id_chunks[-1].shape[1] - 1),
                              value=0)
        mask_chunks[-1] = F.pad(mask_chunks[-1],
                                (1, max_length - mask_chunks[-1].shape[1] - 1),
                                value=1)

    # Slices clamp at the end of the list, so the trailing partial batch needs
    # no special case.
    batches = []
    for start in range(0, len(id_chunks), batch_size):
        stop = start + batch_size
        batches.append((torch.cat(id_chunks[start:stop], dim=0),
                        torch.cat(mask_chunks[start:stop], dim=0)))

    return tuple(batches)
46
-
47
-
48
def predict(input, tokenizer, model, batch_size=128, max_length=256, max_val=-4, min_val=3, score=100):
    """Score ``input`` with ``model`` and describe the result in a sentence.

    The text is tokenized and chunked by ``get_batches`` on the device of
    ``model.base``. Each batch is run through the model without gradients and
    its raw output is rescaled linearly so that a raw value of ``min_val``
    maps to 0 and a raw value of ``max_val`` maps to ``score`` (with the
    defaults, raw 3 -> 0 and raw -4 -> 100).

    Args:
        input: Text to score.
        tokenizer: Tokenizer forwarded to ``get_batches``.
        model: Callable ``model(input_ids, attention_mask)`` exposing
            ``model.base.device``.
        batch_size: Maximum number of chunks per forward pass.
        max_length: Number of tokens per chunk.
        max_val: Raw model output that maps to ``score``.
        min_val: Raw model output that maps to 0.
        score: Upper end of the reported scale.

    Returns:
        A message with the mean score rounded to 2 decimals. When the standard
        deviation is defined (not NaN) the message also reports it together
        with the ``mean +/- 2 * std`` interval.
    """
    device = model.base.device
    batches = get_batches(input, tokenizer, batch_size, max_length, device)

    with torch.no_grad():
        scaled = [
            score * (model(input_ids, attention_mask) - min_val) / (max_val - min_val)
            for input_ids, attention_mask in batches
        ]

    predictions = torch.cat(scaled, dim=0)
    mean = round(predictions.mean().cpu().item(), 2)
    std = round(predictions.std().cpu().item(), 2)

    # torch's default std of a single value is NaN, so there is no spread to report.
    if np.isnan(std):
        return f"The reading difficulty score is {mean}."

    # Round the bounds as well: subtracting two rounded floats can reintroduce
    # binary noise such as 43.31999999999999 in the displayed message.
    lower = round(mean - 2 * std, 2)
    upper = round(mean + 2 * std, 2)
    return (
        f"The reading difficulty score is {mean} with a standard deviation of {std}.\n"
        f"\nThe 95% confidence interval of the score is {lower} to {upper}."
    )
68
-
69
-
70
if __name__ == "__main__":
    # Script entry point: load the tokenizer and the model weights, then serve
    # `predict` through a Gradio text-in / text-out interface.
    deberta_loc = "deberta_large_0.pt"
    deberta_tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-large", model_max_length=256)

    model = Custom_bert("microsoft/deberta-large")
    # map_location remaps the stored tensors onto the device chosen at module
    # level, so a checkpoint saved from a GPU still loads on a CPU-only host.
    model.load_state_dict(torch.load(deberta_loc, map_location=device))
    model.eval().to(device)

    description = """
This tool attempts to estimate how difficult a piece of text is to read by a school child.
The underlying model has been developed based on expert ranking of text difficulty for students from grade 3 to 12.
The score has been scaled to range from zero (very easy) to one hundred (very difficult).
Very long passages will be broken up and reported with the average as well as the standard deviation of the difficulty score.
"""

    # batch_size=4 overrides predict's default of 128 chunks per forward pass.
    # NOTE(review): `gr.inputs.Textbox` is Gradio's legacy input namespace;
    # newer releases expose `gr.Textbox` instead -- confirm the pinned version
    # before changing it.
    interface = gr.Interface(fn=lambda x: predict(x, deberta_tokenizer, model, batch_size=4),
                             inputs=gr.inputs.Textbox(lines=7, label="Text:",
                                                      placeholder="Insert text to be scored here."),
                             outputs='text',
                             title="Reading Difficulty Analyser",
                             description=description)
    interface.launch()
93
-