smhavens committed
Commit
d5bca77
1 Parent(s): 9be73ed

Minimized training version

Files changed (1)
  1. train.py +265 -0
train.py ADDED
@@ -0,0 +1,265 @@
+ import gradio as gr
+ import math
+ import spacy
+ from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers import InputExample
+ from sentence_transformers import losses
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
+ from transformers import TrainingArguments, Trainer
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ import numpy as np
+ import evaluate
+ import nltk
+ from nltk.corpus import stopwords
+ import subprocess
+ import sys
+
+ # !pip install https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl
+ # subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl'])
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
+ # nltk.download('stopwords')
+ # nlp = spacy.load("en_core_web_sm")
+ # stops = stopwords.words("english")
+
+ # answer = "Pizza"
+ guesses = []
+ answer = "Pizza"
+
+
+ # Mean Pooling - Take attention mask into account for correct averaging
+ def mean_pooling(model_output, attention_mask):
+     token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+     return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+ # def normalize(comment, lowercase, remove_stopwords):
+ #     if lowercase:
+ #         comment = comment.lower()
+ #     comment = nlp(comment)
+ #     lemmatized = list()
+ #     for word in comment:
+ #         lemma = word.lemma_.strip()
+ #         if lemma:
+ #             if not remove_stopwords or (remove_stopwords and lemma not in stops):
+ #                 lemmatized.append(lemma)
+ #     return " ".join(lemmatized)
+
+
+ def tokenize_function(examples):
+     return tokenizer(examples["text"])
+
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     metric = evaluate.load("accuracy")
+     return metric.compute(predictions=predictions, references=labels)
+
+
+ def training():
+     dataset_id = "ag_news"
+     dataset = load_dataset(dataset_id)
+     # dataset = dataset["train"]
+     # tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+     print(f"- The {dataset_id} dataset has {dataset['train'].num_rows} examples.")
+     print(f"- Each example is a {type(dataset['train'][0])} with a {type(dataset['train'][0]['text'])} as value.")
+     print(f"- Examples look like this: {dataset['train'][0]}")
+
+     # small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+     # small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+
+     # dataset = dataset["train"].map(tokenize_function, batched=True)
+     # dataset.set_format(type="torch", columns=["input_ids", "token_type_ids", "attention_mask", "label"])
+     # dataset.format['type']
+
+     # print(dataset)
+
+     train_examples = []
+     train_data = dataset["train"]
+     # For agility, we only use 1/2 of our available data
+     n_examples = dataset["train"].num_rows // 2
+     # n_remaining = dataset["train"].num_rows - n_examples
+     # dataset_clean = {}
+     # # dataset_0 = []
+     # # dataset_1 = []
+     # # dataset_2 = []
+     # # dataset_3 = []
+     # for i in range(n_examples):
+     #     dataset_clean[i] = {}
+     #     dataset_clean[i]["text"] = normalize(train_data[i]["text"], lowercase=True, remove_stopwords=True)
+     #     dataset_clean[i]["label"] = train_data[i]["label"]
+     #     if train_data[i]["label"] == 0:
+     #         dataset_0.append(dataset_clean[i])
+     #     elif train_data[i]["label"] == 1:
+     #         dataset_1.append(dataset_clean[i])
+     #     elif train_data[i]["label"] == 2:
+     #         dataset_2.append(dataset_clean[i])
+     #     elif train_data[i]["label"] == 3:
+     #         dataset_3.append(dataset_clean[i])
+     # n_0 = len(dataset_0) // 2
+     # n_1 = len(dataset_1) // 2
+     # n_2 = len(dataset_2) // 2
+     # n_3 = len(dataset_3) // 2
+     # print("Label lengths:", len(dataset_0), len(dataset_1), len(dataset_2), len(dataset_3))
+
+     for i in range(n_examples):
+         example = train_data[i]
+         # example_opposite = dataset_clean[-(i)]
+         # print(example["text"])
+         train_examples.append(InputExample(texts=[example['text']], label=example['label']))
+
+     # for i in range(n_0):
+     #     example = dataset_0[i]
+     #     # example_opposite = dataset_0[-(i)]
+     #     # print(example["text"])
+     #     train_examples.append(InputExample(texts=[example['text']], label=0))
+
+     # for i in range(n_1):
+     #     example = dataset_1[i]
+     #     # example_opposite = dataset_1[-(i)]
+     #     # print(example["text"])
+     #     train_examples.append(InputExample(texts=[example['text']], label=1))
+
+     # for i in range(n_2):
+     #     example = dataset_2[i]
+     #     # example_opposite = dataset_2[-(i)]
+     #     # print(example["text"])
+     #     train_examples.append(InputExample(texts=[example['text']], label=2))
+
+     # for i in range(n_3):
+     #     example = dataset_3[i]
+     #     # example_opposite = dataset_3[-(i)]
+     #     # print(example["text"])
+     #     train_examples.append(InputExample(texts=[example['text']], label=3))
+
+     train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=25)
+
+     print("END DATALOADER")
+
+     # print(train_examples)
+
+     embeddings = finetune(train_dataloader)
+
+     return (dataset['train'].num_rows, type(dataset['train'][0]), type(dataset['train'][0]['text']), dataset['train'][0], embeddings)
+
+
+ def finetune(train_dataloader):
+     # model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
+     model_id = "sentence-transformers/all-MiniLM-L6-v2"
+     model = SentenceTransformer(model_id)
+
+     # training_args = TrainingArguments(output_dir="test_trainer")
+
+     # USE THIS LINK
+     # https://huggingface.co/blog/how-to-train-sentence-transformers
+
+     train_loss = losses.BatchHardSoftMarginTripletLoss(model=model)
+
+     print("BEGIN FIT")
+
+     model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=10)
+
+     model.save("ag_news_model")
+
+     model.save_to_hub("smhavens/all-MiniLM-agNews")
+     # accuracy = compute_metrics(eval, metric)
+
+     # training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")
+
+     # trainer = Trainer(
+     #     model=model,
+     #     args=training_args,
+     #     train_dataset=train,
+     #     eval_dataset=eval,
+     #     compute_metrics=compute_metrics,
+     # )
+
+     # trainer.train()
+
+     sentences = ["This is an example sentence", "Each sentence is converted"]
+
+     # model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+     embeddings = model.encode(sentences)
+     print(embeddings)
+
+     # Sentences we want sentence embeddings for
+     sentences = ['This is an example sentence', 'Each sentence is converted']
+
+     # Load model from HuggingFace Hub (the manual mean-pooling example below follows the standard Hub usage snippet, which expects a plain AutoModel and its matching tokenizer)
+     tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+     model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+     # Tokenize sentences
+     encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
+
+     # Compute token embeddings
+     with torch.no_grad():
+         model_output = model(**encoded_input)
+
+     # Perform pooling
+     sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
+
+     # Normalize embeddings
+     sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
+
+     print("Sentence embeddings:")
+     print(sentence_embeddings)
+     return sentence_embeddings
+
+
+
+ def greet(name):
+     return "Hello " + name + "!!"
+
+ def check_answer(guess: str):
+     global guesses
+     global answer
+     guesses.append(guess)
+     output = ""
+     for past_guess in guesses:
+         output += ("- " + past_guess + "\n")
+     output = output[:-1]
+
+     if guess.lower() == answer.lower():
+         return "Correct!", output
+     else:
+         return "Try again!", output
+
+ def main():
+     word1 = "Black"
+     word2 = "White"
+     word3 = "Sun"
+     global answer
+     answer = "Moon"
+     global guesses
+
+     num_rows, data_type, value, example, embeddings = training()
+
+     # prompt = f"{word1} is to {word2} as {word3} is to ____"
+     # with gr.Blocks() as iface:
+     #     gr.Markdown(prompt)
+     #     with gr.Tab("Guess"):
+     #         text_input = gr.Textbox()
+     #         text_output = gr.Textbox()
+     #         text_button = gr.Button("Submit")
+     #         with gr.Accordion("Open for previous guesses"):
+     #             text_guesses = gr.Textbox()
+     #     with gr.Tab("Testing"):
+     #         gr.Markdown(f"""Number of rows in dataset is {num_rows}, with each having type {data_type} and value {value}.
+     #         An example is {example}.
+     #         The Embeddings are {embeddings}.""")
+     #     text_button.click(check_answer, inputs=[text_input], outputs=[text_output, text_guesses])
+     # # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+     # iface.launch()
+
+
+
+
+
+ if __name__ == "__main__":
+     main()