nazneen committed on
Commit
d394488
1 Parent(s): 2443328
seal/run_inference.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest import result
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.utils.data import DataLoader, Subset
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+ from datasets import load_dataset
7
+
8
+ import os
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ from utils.inference_utils import InferenceResults, saveResults
13
+
14
+ # Load the requested dataset split, the tokenizer, and the model
15
+
16
def load_session(dataset, model, split):
    """Load the tokenizer, a batched dataloader, and the classifier.

    Args:
        dataset: Dataset name passed to ``load_dataset``.
        model: Model identifier passed to ``from_pretrained`` for both the
            tokenizer and the sequence-classification model.
        split: Dataset split passed to ``load_dataset``.

    Returns:
        A ``(tokenizer, dataloader, model)`` tuple.
    """
    data = load_dataset(dataset, split=split)
    # NOTE: drop_last=True discards the final partial batch, so up to 255
    # trailing examples of the split are never seen by the caller.
    dataloader = DataLoader(data, batch_size=256, drop_last=True)
    # Both loaders need the identifier string. Keep it under its own name
    # instead of rebinding `model` to the loaded module, which previously
    # caused the tokenizer to be "loaded" from a model object.
    tokenizer = AutoTokenizer.from_pretrained(model)
    classifier = AutoModelForSequenceClassification.from_pretrained(model)
    return tokenizer, dataloader, classifier
25
+
26
+ # Add hook to capture hidden layer
27
def get_input(name, model):
    """Create a forward hook that records a module's first positional input.

    Args:
        name: Key under which the captured tensor is stored.
        model: Unused; kept so existing callers keep working.

    Returns:
        A ``(hook, captured)`` tuple. ``hook`` has the forward-hook signature
        ``(module, inputs, output)``; ``captured`` is the dict it writes to.
    """
    captured = {}

    def hook(module, inputs, output):
        # Only the most recent forward pass is kept: assigning to the same
        # key replaces whatever an earlier call stored there.
        captured[name] = inputs[0].detach()

    return hook, captured
34
+
35
def run_inference(dataset='yelp_polarity', model='textattack/albert-base-v2-yelp-polarity', split='test', output_path='./assets/data/inference_results'):
    """Run a classifier over a dataset split and pickle embeddings, losses, labels.

    For every batch the function records the tensor fed into
    ``model.classifier`` (captured through a forward hook), the per-example
    cross-entropy loss, and the labels. The concatenated results are written
    to ``<output_path>/<dataset>.pkl``.

    Args:
        dataset: Dataset name; also used as the output file stem.
        model: Model identifier handed to ``load_session``.
        split: Dataset split to run on.
        output_path: Directory for the pickle file; created if missing.

    Raises:
        ValueError: If the dataloader yields no batches.
    """
    tokenizer, dataloader, model = load_session(dataset, model, split)
    model.eval()
    model.to('cpu')

    # get_input returns (hook, storage dict). Only the callable may be
    # registered; register_forward_hook itself returns a removable handle.
    hook, hidden_layers = get_input('last_layer', model)
    handle = model.classifier.register_forward_hook(hook)

    hidden_list = []
    loss_list = []
    labels = []
    # reduction='none' keeps one loss value per example instead of a mean.
    criterion = nn.CrossEntropyLoss(reduction='none')
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, total=len(dataloader), position=0, leave=True):
                # Character-level truncation of the raw text before tokenizing.
                batch_ex = [ex[:512] for ex in batch['text']]
                inputs = tokenizer(batch_ex, padding=True, return_tensors='pt').to('cpu')
                targets = batch['label']

                outputs = model(**inputs)['logits']
                loss = criterion(outputs, targets)

                hidden_list.append(hidden_layers['last_layer'].cpu())
                loss_list.append(loss.cpu())
                labels.append(targets)
    finally:
        # Do not leave the hook attached to the model after this run.
        handle.remove()

    if not hidden_list:
        raise ValueError(
            'No batches were produced for dataset %r (split %r); nothing to save.'
            % (dataset, split)
        )

    embeddings = torch.vstack(hidden_list)
    losses = torch.hstack(loss_list)
    targets = torch.hstack(labels)
    results = save_results(embeddings, losses, targets)

    # The default output directory may not exist on a fresh checkout.
    os.makedirs(output_path, exist_ok=True)
    saveResults(os.path.join(output_path, dataset + '.pkl'), results)
71
+
72
+
73
+
74
def save_results(embeddings, losses, labels):
    """Bundle inference tensors into an ``InferenceResults`` record.

    The embeddings are cloned before being stored; ``losses`` and ``labels``
    are stored as passed in.
    """
    embeddings_copy = torch.clone(embeddings)
    return InferenceResults(embeddings=embeddings_copy, losses=losses, labels=labels)
81
+
82
+
seal/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .style_hacks import *
seal/utils/inference_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from dataclasses import dataclass
3
+ import torch
4
+
5
@dataclass
class InferenceResults:
    """Record of what an inference run produced for a model.

    Attributes:
        embeddings: (num_examples x num_dimensions) tensor of last-layer
            embeddings.
        losses: (num_examples x 1) tensor of losses.
        outputs: optional (num_examples x num_classes) tensor of output
            logits; ``None`` when not recorded.
        labels: optional (num_examples x 1) tensor of labels; ``None`` when
            not recorded.
    """

    # Required for every run.
    embeddings: torch.Tensor
    losses: torch.Tensor
    # Optional extras; default to None when the caller does not supply them.
    outputs: torch.Tensor = None
    labels: torch.Tensor = None
21
+
22
def saveResults(fname, results):
    """Pickle ``results`` to the file at ``fname``, replacing existing content."""
    out_file = open(fname, 'wb+')
    with out_file:
        pickle.dump(results, out_file)
seal/utils/style_hacks.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ placeholder for all streamlit style hacks
3
+ """
4
+ import streamlit as st
5
+
6
+
7
# Stylesheet injected by init_style(). Kept at module level so the CSS can be
# inspected without a running Streamlit session. The <style> tag starts at
# column 0 so Markdown treats it as an HTML block rather than indented code.
STYLE_CSS = """
<style>
/* Side Bar */
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
    width: 250px;
}
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
    width: 250px;
}
[data-testid="stSidebar"]{
    flex-basis: unset;
}
.css-1outpf7 {
    background-color:rgb(254 244 219);
    width:10rem;
    padding:10px 10px 10px 10px;
}

/* Main Panel*/
.css-18e3th9 {
    padding:10px 10px 10px -200px;
}
.css-1ubw6au:last-child{
    background-color:lightblue;
}

/* Model Panels : element-container */
.element-container{
    border-style:none
}

/* Radio Button Direction*/
div.row-widget.stRadio > div{flex-direction:row;}

/* Expander Boz*/
.streamlit-expander {
    border-width: 0px;
    border-bottom: 1px solid #A29C9B;
    border-radius: 10px;
}

.streamlit-expanderHeader {
    font-style: italic;
    font-weight :600;
    font-size:16px;
    padding-top:0px;
    padding-left: 0px;
    color:#A29C9B
}

/* Section Headers */
.sectionHeader {
    font-size:10px;
}
[data-testid="stMarkdownContainer"]{
    font-family: sans-serif;
    font-weight: 500;
    font-size: 1.5rem !important;
    color: rgb(250, 250, 250);
    padding: 1.25rem 0px 1rem;
    margin: 0px;
    line-height: 1.4;
}

/* text input*/
.st-e5 {
    background-color:lightblue;
}
/*line special*/
.line-one{
    border-width: 0px;
    border-bottom: 1px solid #A29C9B;
    border-radius: 50px;
}

</style>
"""


def init_style():
    """Inject the app's custom CSS into the current Streamlit page.

    Returns:
        Whatever ``st.markdown`` returns for the injected ``<style>`` block.
    """
    # unsafe_allow_html is required; otherwise the <style> tag is escaped
    # and rendered as text instead of being applied.
    return st.markdown(STYLE_CSS, unsafe_allow_html=True)