Zeinab committed on
Commit
700864a
1 Parent(s): 0482418

Upload final_gradio.py

Browse files
Files changed (1) hide show
  1. final_gradio.py +141 -0
final_gradio.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """final_gradio.ipynb
3
+
4
+ Automatically generated by Colaboratory.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1Laxh069wv-cX4NqcOnNB2AyEDW0gmvbL
8
+ """
9
+
10
# Colab-only setup: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# The notebook installed its dependencies with IPython shell magics
# (`!pip install ...`).  Those are a SyntaxError in a plain .py file, so the
# same installs are run through the current interpreter's pip instead.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio"])
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "transformers==3.0.2"]
)
15
+ import pandas as pd
16
+ import numpy as np
17
+ from sklearn.model_selection import train_test_split
18
+ import torch
19
+ import seaborn as sns
20
+ import transformers
21
+ import json
22
+ from tqdm import tqdm
23
+ from torch.utils.data import Dataset, DataLoader
24
+ from transformers import RobertaModel, RobertaTokenizer
25
+ import logging
26
# Configure the root logger so that only ERROR and above are emitted.
logging.basicConfig(level=logging.ERROR)
from torch import cuda

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
29
+
30
class RobertaClass(torch.nn.Module):
    """RoBERTa encoder followed by a small classification head.

    The head takes the 768-wide vector at position 0 of the encoder output
    and applies a 768 -> 768 linear layer, ReLU, dropout (p=0.3) and a final
    768 -> 5 linear layer, yielding five scores per input row.
    """

    def __init__(self):
        super().__init__()
        # Attribute names determine the state-dict keys, so they are kept
        # exactly as they were.
        self.l1 = RobertaModel.from_pretrained("roberta-base")
        self.pre_classifier = torch.nn.Linear(768, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 5)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return the classifier scores for a batch of encoded inputs.

        The three arguments are passed through unchanged, by keyword, to the
        RoBERTa encoder.
        """
        encoder_outputs = self.l1(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Element 0 of the encoder output is indexed at position 0 of its
        # second axis to get one vector per row for the head.
        first_token = encoder_outputs[0][:, 0]
        features = torch.relu(self.pre_classifier(first_token))
        return self.classifier(self.dropout(features))
47
# Build the classifier and move it onto the selected device.
model = RobertaClass()
model.to(device)

# Load the fine-tuned weights.  `map_location` remaps the stored tensors onto
# `device`, so a checkpoint that was saved from a GPU still loads when the
# script falls back to the CPU (without it torch.load fails on a CUDA-less
# host).
model.load_state_dict(
    torch.load("/content/drive/MyDrive/avicenna_model.pt", map_location=device)
)

# The script only runs inference: put the model in evaluation mode.
model.eval()
54
+
55
+
56
class SyllogismData(Dataset):
    """Dataset that tokenizes the ``Premises`` column of a DataFrame.

    Args:
        dataframe: frame with a ``Premises`` column (text) and a ``label``
            column.
        tokenizer: object exposing ``encode_plus``.
        max_len: length passed to the tokenizer as ``max_length``; sequences
            are padded and truncated to it.
    """

    def __init__(self, dataframe, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.text = dataframe.Premises
        self.targets = self.data.label
        self.max_len = max_len

    def __len__(self):
        """Return the number of rows in the ``Premises`` column."""
        return len(self.text)

    def __getitem__(self, index):
        """Return the encoded tensors and target for the row at ``index``.

        Returns:
            dict with ``ids``, ``mask`` and ``token_type_ids`` (long tensors)
            and ``targets`` (float tensor).
        """
        # Index positionally: a DataLoader asks for 0..len-1, which is not
        # the same as the frame's index labels once rows were filtered or
        # shuffled (label-based lookup raised KeyError / hit the wrong row).
        text = str(self.text.iloc[index])
        # Collapse any run of whitespace into a single space.
        text = " ".join(text.split())

        # Padding and truncation are requested explicitly instead of relying
        # on the deprecated `pad_to_max_length` flag and implicit truncation.
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_token_type_ids=True,
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(
                inputs["token_type_ids"], dtype=torch.long
            ),
            'targets': torch.tensor(
                self.targets.iloc[index], dtype=torch.float
            ),
        }
90
+
91
# Maximum sequence length handed to the tokenizer.
MAX_LEN = 256
# Batch sizes and learning rate kept from the notebook.
TRAIN_BATCH_SIZE = 8
VALID_BATCH_SIZE = 4
LEARNING_RATE = 1e-5

# Tokenizer matching the "roberta-base" encoder.
tokenizer = RobertaTokenizer.from_pretrained(
    'roberta-base',
    truncation=True,
    do_lower_case=True,
)
96
+
97
def avicenna(input):
    """Classify one premise text and return the predicted class index.

    Args:
        input: the premise text to classify.  The parameter name shadows the
            builtin but is kept so existing callers are unaffected.

    Returns:
        int: index of the highest-scoring model output for this text.
    """
    # SyllogismData needs a `label` column; 10 is only a placeholder and is
    # never used for the prediction.
    frame = pd.DataFrame([[input, 10]], columns=['Premises', 'label'])
    loader = DataLoader(
        SyllogismData(frame, tokenizer, MAX_LEN),
        batch_size=VALID_BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    model.eval()
    with torch.no_grad():
        # The frame has exactly one row, so there is exactly one batch.
        batch = next(iter(loader))
        ids = batch['ids'].to(device, dtype=torch.long)
        mask = batch['mask'].to(device, dtype=torch.long)
        token_type_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        scores = model(ids, mask, token_type_ids)
        # Arg-max per row (dim=1) rather than squeeze() + dim=0, which is
        # only correct when the batch happens to hold a single row.
        _, predicted = torch.max(scores, dim=1)
    return predicted[0].item()
125
+
126
import gradio as gr

# Premise pairs offered in the dropdown.
EXAMPLE_PREMISES = [
    "All humans are mortal. Socrates is a human.",
    "Avicenna wrote the famous book the Canon of Medicine. The Canon of Medicine has influenced modern medicine",
    "Police found signs of active bleeding before death around the corpse. A large volume of blood released from a body may indicate that the individual has died of exsanguination.",
    "During the first trimester of pregnancy, the body undergoes hormonal fluctuations. Hormonal changes regularly are followed by extreme tiredness.",
    "Influenza spreads the virus to the lungs. Garlic is useful to cure all infections.",
    "With a single currency, there will no longer be a cost involved in changing currencies. It was no longer cost-effective for the government to convert metals into coins.",
    "Pain and tension around your head and neck are known as tension headaches. Tension headaches are dull pain, tightness, or pressure around your head and neck.",
    "Heavy rain can cause flooding. Many different health conditions can cause heavy breathing.",
    "Eating foods that are in the Mediterranean diet helps with healthy weight loss and metabolism. Ana knows how to eat healthy and lose weight with the Mediterranean diet.",
    # Disabled examples (the notebook kept them under a "#false" marker):
    # "Buildings should not be constructed on organic soils. All buildings should be constructed on stiff underlying soil with enough strength.",
    # "The launching of satellites while still contributing to national prestige is a significant economic activity. For each type of economic activity, there is some threshold for the launch cost per payload mass beyond which the economic activity will not be sustainable.",
]

# Dropdown of example premises in, predicted relation out.
iface = gr.Interface(
    fn=avicenna,
    title="Syllogistic NLI",
    description="Select pair of sentences and see if the Avicenna-trained model can gauge the relation correctly",
    inputs=gr.inputs.Dropdown(EXAMPLE_PREMISES, label="Select an example"),
    outputs=gr.outputs.Textbox(label="Syllogistic relation"),
)

# The launch call previously ended in "))", an unbalanced parenthesis that
# made the whole file unparseable.
iface.launch(enable_queue=True)