sunwaee commited on
Commit
b438028
·
1 Parent(s): 3b45242

added scripts

Browse files
Files changed (2) hide show
  1. app.py +120 -0
  2. source/pipeline.py +138 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding:utf-8
2
+ """
3
+ Filename: app.py
4
+ Author: @DvdNss
5
+
6
+ Created on 12/18/2021
7
+ """
8
+ import os
9
+
10
+ import gdown as gdown
11
+ import nltk
12
+ import streamlit as st
13
+ from nltk.tokenize import sent_tokenize
14
+
15
+ from source.pipeline import MultiLabelPipeline, inputs_to_dataset
16
+
17
+
18
+ def download_models(ids):
19
+ """
20
+ Download all models.
21
+
22
+ :param ids: name and links of models
23
+ :return:
24
+ """
25
+
26
+ # Download sentence tokenizer
27
+ nltk.download('punkt')
28
+
29
+ # Download model from drive if not stored locally
30
+ for key in ids:
31
+ if not os.path.isfile(f"model/{key}.pt"):
32
+ url = f"https://drive.google.com/uc?id={ids[key]}"
33
+ gdown.download(url=url, output=f"model/{key}.pt")
34
+
35
+
36
+ @st.cache
37
+ def load_labels():
38
+ """
39
+ Load model labels.
40
+
41
+ :return:
42
+ """
43
+
44
+ return [
45
+ "admiration",
46
+ "amusement",
47
+ "anger",
48
+ "annoyance",
49
+ "approval",
50
+ "caring",
51
+ "confusion",
52
+ "curiosity",
53
+ "desire",
54
+ "disappointment",
55
+ "disapproval",
56
+ "disgust",
57
+ "embarrassment",
58
+ "excitement",
59
+ "fear",
60
+ "gratitude",
61
+ "grief",
62
+ "joy",
63
+ "love",
64
+ "nervousness",
65
+ "optimism",
66
+ "pride",
67
+ "realization",
68
+ "relief",
69
+ "remorse",
70
+ "sadness",
71
+ "surprise",
72
+ "neutral"
73
+ ]
74
+
75
+
76
+ @st.cache(allow_output_mutation=True)
77
+ def load_model(model_path):
78
+ """
79
+ Load model and cache it.
80
+
81
+ :param model_path: path to model
82
+ :return:
83
+ """
84
+
85
+ model = MultiLabelPipeline(model_path=model_path)
86
+
87
+ return model
88
+
89
+
90
+ # Page config
91
+ st.set_page_config(layout="centered")
92
+ st.title("Multiclass Emotion Classification")
93
+ st.write("DeepMind Language Perceiver for Multiclass Emotion Classification (Eng). ")
94
+
95
+ # Variables
96
+ ids = {'perceiver-go-emotions': '15m-p0Pwwnh3STi7zXHkKr9HFxliGJikU'}
97
+ labels = load_labels()
98
+
99
+ # Download all models from drive
100
+ download_models(ids)
101
+
102
+ # Display labels
103
+ st.markdown(f"__Labels:__ {', '.join(labels)}")
104
+
105
+ # Model selection
106
+ left, right = st.columns([4, 2])
107
+ inputs = left.text_area('', max_chars=2048, placeholder='Write something here to see what happens! ')
108
+ model_path = right.selectbox('', options=[k for k in ids], index=0, help='Model to use. ')
109
+ split = right.checkbox('Split into sentences')
110
+ model = load_model(model_path=f"model/{model_path}.pt")
111
+ right.write(model.device)
112
+
113
+ if split:
114
+ if not inputs.isspace() and inputs != "":
115
+ with st.spinner('Processing text... This may take a while.'):
116
+ left.write(model(inputs_to_dataset(sent_tokenize(inputs)), batch_size=1))
117
+ else:
118
+ if not inputs.isspace() and inputs != "":
119
+ with st.spinner('Processing text... This may take a while.'):
120
+ left.write(model(inputs_to_dataset([inputs]), batch_size=1))
source/pipeline.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding:utf-8
2
+ """
3
+ Filename: inference.py
4
+ Author: @DvdNss
5
+
6
+ Created on 12/17/2021
7
+ """
8
+ from typing import List
9
+
10
+ import torch
11
+ from datasets import Dataset
12
+ from torch.utils.data import DataLoader
13
+ from tqdm import tqdm
14
+ from transformers import PerceiverTokenizer
15
+
16
+
17
+ def _map_outputs(predictions):
18
+ """
19
+ Map model outputs to classes.
20
+
21
+ :param predictions: model ouptut batch
22
+ :return:
23
+ """
24
+
25
+ labels = [
26
+ "admiration",
27
+ "amusement",
28
+ "anger",
29
+ "annoyance",
30
+ "approval",
31
+ "caring",
32
+ "confusion",
33
+ "curiosity",
34
+ "desire",
35
+ "disappointment",
36
+ "disapproval",
37
+ "disgust",
38
+ "embarrassment",
39
+ "excitement",
40
+ "fear",
41
+ "gratitude",
42
+ "grief",
43
+ "joy",
44
+ "love",
45
+ "nervousness",
46
+ "optimism",
47
+ "pride",
48
+ "realization",
49
+ "relief",
50
+ "remorse",
51
+ "sadness",
52
+ "surprise",
53
+ "neutral"
54
+ ]
55
+ classes = []
56
+ for i, example in enumerate(predictions):
57
+ out_batch = []
58
+ for j, category in enumerate(example):
59
+ out_batch.append(labels[j]) if category > 0.5 else None
60
+ classes.append(out_batch)
61
+ return classes
62
+
63
+
64
+ class MultiLabelPipeline:
65
+ """
66
+ Multi label classification pipeline.
67
+ """
68
+
69
+ def __init__(self, model_path):
70
+ """
71
+ Init MLC pipeline.
72
+
73
+ :param model_path: model to use
74
+ """
75
+
76
+ # Init attributes
77
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
78
+ if self.device == 'cuda':
79
+ self.model = torch.load(model_path).eval().to(self.device)
80
+ else:
81
+ self.model = torch.load(model_path, map_location=torch.device('cpu')).eval().to(self.device)
82
+ self.tokenizer = PerceiverTokenizer.from_pretrained('deepmind/language-perceiver')
83
+
84
+ def __call__(self, dataset, batch_size: int = 4):
85
+ """
86
+ Processing pipeline.
87
+
88
+ :param dataset: dataset
89
+ :return:
90
+ """
91
+
92
+ # Tokenize inputs
93
+ dataset = dataset.map(lambda row: self.tokenizer(row['text'], padding="max_length", truncation=True),
94
+ batched=True, remove_columns=['text'], desc='Tokenizing')
95
+ dataset.set_format('torch', columns=['input_ids', 'attention_mask'])
96
+ dataloader = DataLoader(dataset, batch_size=batch_size)
97
+
98
+ # Define output classes
99
+ classes = []
100
+ mem_logs = []
101
+
102
+ with tqdm(dataloader, unit='batches') as progression:
103
+ for batch in progression:
104
+ progression.set_description('Inference')
105
+ # Forward
106
+ outputs = self.model(inputs=batch['input_ids'].to(self.device),
107
+ attention_mask=batch['attention_mask'].to(self.device), )
108
+
109
+ # Outputs
110
+ predictions = outputs.logits.cpu().detach().numpy()
111
+
112
+ # Map predictions to classes
113
+ batch_classes = _map_outputs(predictions)
114
+
115
+ for row in batch_classes:
116
+ classes.append(row)
117
+
118
+ # Retrieve memory usage
119
+ memory = round(torch.cuda.memory_reserved(self.device) / 1e9, 2)
120
+ mem_logs.append(memory)
121
+
122
+ # Update pbar
123
+ progression.set_postfix(memory=f"{round(sum(mem_logs) / len(mem_logs), 2)}Go")
124
+
125
+ return classes
126
+
127
+
128
+ def inputs_to_dataset(inputs: List[str]):
129
+ """
130
+ Convert a list of strings to a dataset object.
131
+
132
+ :param inputs: list of strings
133
+ :return:
134
+ """
135
+
136
+ inputs = {'text': [input for input in inputs]}
137
+
138
+ return Dataset.from_dict(inputs)