m3hrdadfi committed on
Commit
82bf4de
·
1 Parent(s): 0219ca2

Add extra scripts

Browse files
notes/flax_to_pytorch.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+
5
+ from transformers import AutoTokenizer
6
+ from transformers import FlaxT5ForConditionalGeneration
7
+ from transformers import T5ForConditionalGeneration
8
+
9
# Convert the Flax checkpoint in the parent directory to PyTorch, save the
# converted weights to the current directory, then print the first-step
# logits of both models for the same input so they can be compared by eye.
tokenizer = AutoTokenizer.from_pretrained("../")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained("../")
model_pt = T5ForConditionalGeneration.from_pretrained("../", from_flax=True)
model_pt.save_pretrained("./")

text = "Hello To You"

# Flax inputs: NumPy-encoded text plus one decoder start token per row.
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id

# PyTorch inputs: the decoder ids are built as an integer torch tensor (the
# original built a NumPy array here, which is not a torch input).
e_input_ids_pt = tokenizer(text, return_tensors="pt", padding=True, max_length=128, truncation=True)
d_input_ids_pt = torch.full(
    (e_input_ids_pt.input_ids.shape[0], 1),
    model_pt.config.decoder_start_token_id,
    dtype=torch.long,
)


print(e_input_ids_fx)
print(d_input_ids_fx)

print()

# Run the converted PyTorch model. The original called model_fx.encode/decode
# here, so the printed "PyTorch" logits never exercised model_pt at all.
with torch.no_grad():  # inference only; gradients are not needed for the comparison
    decoder_pt = model_pt(**e_input_ids_pt, decoder_input_ids=d_input_ids_pt)
encoder_pt = decoder_pt.encoder_last_hidden_state
logits_pt = decoder_pt.logits
print(logits_pt)

# Reference run with the original Flax model.
encoder_fx = model_fx.encode(**e_input_ids_fx)
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
logits_fx = decoder_fx.logits
print(logits_fx)
notes/flax_to_tf.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import jax.numpy as jnp
4
+
5
+ from transformers import AutoTokenizer
6
+ from transformers import FlaxT5ForConditionalGeneration
7
+ from transformers import TFT5ForConditionalGeneration
8
+
9
# Load the Flax checkpoint from the parent directory and build the TensorFlow
# model from the PyTorch weights in the current directory (from_pt=True),
# save it back to the current directory, then print the first-step logits of
# both models for the same input so they can be compared by eye.
tokenizer = AutoTokenizer.from_pretrained("../")
model_fx = FlaxT5ForConditionalGeneration.from_pretrained("../")
model_tf = TFT5ForConditionalGeneration.from_pretrained("./", from_pt=True)
model_tf.save_pretrained("./")

text = "Hello To You"

# Flax inputs: NumPy-encoded text plus one decoder start token per row.
e_input_ids_fx = tokenizer(text, return_tensors="np", padding=True, max_length=128, truncation=True)
d_input_ids_fx = jnp.ones((e_input_ids_fx.input_ids.shape[0], 1), dtype="i4") * model_fx.config.decoder_start_token_id

# TensorFlow inputs: TF-encoded text plus one decoder start token per row.
e_input_ids_tf = tokenizer(text, return_tensors="tf", padding=True, max_length=128, truncation=True)
d_input_ids_tf = np.ones((e_input_ids_tf.input_ids.shape[0], 1), dtype="i4") * model_tf.config.decoder_start_token_id


print(e_input_ids_fx)
print(d_input_ids_fx)

print()

# Run the converted TensorFlow model. The original called model_fx.encode/decode
# here, so the printed "TF" logits never exercised model_tf at all.
decoder_tf = model_tf(
    e_input_ids_tf.input_ids,
    attention_mask=e_input_ids_tf.attention_mask,
    decoder_input_ids=d_input_ids_tf,
)
encoder_tf = decoder_tf.encoder_last_hidden_state
logits_tf = decoder_tf.logits
print(logits_tf)

# Reference run with the original Flax model.
encoder_fx = model_fx.encode(**e_input_ids_fx)
decoder_fx = model_fx.decode(d_input_ids_fx, encoder_fx)
logits_fx = decoder_fx.logits
print(logits_fx)
notes/generation.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import pandas as pd
4
+ import random
5
+ import re
6
+ import sys
7
+ import time
8
+ from dataclasses import dataclass, field
9
+ from functools import partial
10
+ from pathlib import Path
11
+ from typing import Callable, Optional
12
+
13
+ import jax
14
+ import jax.numpy as jnp
15
+
16
+ from filelock import FileLock
17
+ from flax import jax_utils, traverse_util
18
+ from flax.jax_utils import unreplicate
19
+ from flax.training import train_state
20
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
21
+
22
+ from transformers import FlaxAutoModelForSeq2SeqLM
23
+ from transformers import AutoTokenizer
24
+
25
+ from datasets import Dataset, load_dataset, load_metric
26
+ from tqdm import tqdm
27
+ import pandas as pd
28
+
29
+
30
# Show which accelerator devices JAX can see before anything is loaded.
print(jax.devices())

# Checkpoint location: tokenizer and Flax seq2seq weights share one directory.
MODEL_NAME_OR_PATH = "../"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

# Dataset columns and the text prepended to every source sequence.
prefix = "items: "
text_column = "inputs"
target_column = "targets"

# Fixed tokenized lengths, RNG seed and evaluation batch size.
max_source_length = 256
max_target_length = 1024
seed = 42
eval_batch_size = 32

# Alternative sampling-based configuration, kept for reference:
# generation_kwargs = {
#     "max_length": 1024,
#     "min_length": 128,
#     "no_repeat_ngram_size": 3,
#     "do_sample": True,
#     "top_k": 60,
#     "top_p": 0.95
# }

# Active configuration: beam search.
generation_kwargs = dict(
    max_length=1024,
    min_length=128,
    no_repeat_ngram_size=3,
    early_stopping=True,
    num_beams=5,
    length_penalty=1.5,
)

special_tokens = tokenizer.all_special_tokens
62
# Structural markers found in decoded text and the readable text they become.
tokens_map = {
    "<sep>": "--",
    "<section>": "\n",
}


def skip_special_tokens(text, special_tokens):
    """Return ``text`` with every occurrence of each special token removed.

    Args:
        text: The string to clean.
        special_tokens: Iterable of token strings to delete, applied in order.

    Returns:
        The cleaned string.
    """
    for token in special_tokens:
        text = text.replace(token, "")

    return text


def target_postprocessing(texts, special_tokens):
    """Turn decoded model text into readable text.

    Each text has the markers in ``tokens_map`` replaced by their readable
    form, and then every remaining special token is removed.

    Args:
        texts: A single string, or an iterable of strings.
        special_tokens: Iterable of token strings to strip from each text.

    Returns:
        A list with one post-processed string per input text.
    """
    # Only a bare string is wrapped; lists, tuples and other iterables of
    # strings are processed element by element.
    if isinstance(texts, str):
        texts = [texts]

    new_texts = []
    for text in texts:
        # Apply the marker mapping BEFORE stripping special tokens. If a
        # mapped marker is also listed in ``special_tokens``, stripping first
        # would delete it and the mapping would never fire.
        for k, v in tokens_map.items():
            text = text.replace(k, v)

        text = skip_special_tokens(text, special_tokens)

        new_texts.append(text)

    return new_texts
86
+
87
+
88
# Read the tab-delimited test file and keep only its "test" split.
test_files = {"test": "/home/m3hrdadfi/code/data/test.csv"}
predict_dataset = load_dataset("csv", data_files=test_files, delimiter="\t")["test"]
print(predict_dataset)
# Debug aid: uncomment to restrict the run to the first 10 rows.
# predict_dataset = predict_dataset.select(range(10))
# print(predict_dataset)

# Remember the raw column names so they can be dropped after tokenization.
column_names = predict_dataset.column_names
print(column_names)
94
+
95
+
96
# Setting padding="max_length" as we need fixed length inputs for jitted functions
def preprocess_function(examples):
    """Tokenize a batch of examples into fixed-length inputs and labels.

    Reads the ``text_column`` and ``target_column`` entries of ``examples``,
    prepends ``prefix`` to every input text, and tokenizes inputs and targets
    as NumPy arrays padded/truncated to ``max_source_length`` and
    ``max_target_length`` respectively.

    Returns:
        The tokenizer output for the inputs, with an added ``"labels"`` entry
        holding the target token ids.
    """
    # Options shared by the source and the target tokenizer calls.
    fixed_length_options = dict(padding="max_length", truncation=True, return_tensors="np")

    raw_inputs = examples[text_column]
    targets = examples[target_column]
    prefixed_inputs = [prefix + inp for inp in raw_inputs]

    model_inputs = tokenizer(prefixed_inputs, max_length=max_source_length, **fixed_length_options)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=max_target_length, **fixed_length_options)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
122
+
123
# Tokenize the whole prediction set in batches and drop the raw text columns,
# leaving only the entries produced by preprocess_function.
map_options = {
    "batched": True,
    "num_proc": None,
    "remove_columns": column_names,
    "desc": "Running tokenizer on prediction dataset",
}
predict_dataset = predict_dataset.map(preprocess_function, **map_options)
130
+
131
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of size `batch_size` from `dataset`, sharded over all local devices.

    Examples are visited in a random permutation drawn from `rng` when `shuffle`
    is `True`, otherwise in dataset order. The trailing
    `len(dataset) % batch_size` examples are never yielded.
    """
    steps_per_epoch = len(dataset) // batch_size

    order = jax.random.permutation(rng, len(dataset)) if shuffle else jnp.arange(len(dataset))

    # Skip incomplete batch, then lay the indices out as one row per step.
    order = order[: steps_per_epoch * batch_size]
    order = order.reshape((steps_per_epoch, batch_size))

    for row in order:
        columns = dataset[row]
        arrays = {name: jnp.array(values) for name, values in columns.items()}
        yield shard(arrays)
153
+
154
# Seeded root key. Two sub-keys are split off in turn: one named for dropout
# and one handed to the data loader.
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng, num=2)
rng, input_rng = jax.random.split(rng, num=2)
157
+
158
def generate_step(batch):
    """Generate output token ids for one batch.

    Calls ``model.generate`` on the batch's ``"input_ids"`` and
    ``"attention_mask"`` with the module-level ``generation_kwargs`` and
    returns the ``sequences`` field of the result.
    """
    generated = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **generation_kwargs,
    )
    return generated.sequences
161
+
162
# Parallel (per-device) version of generate_step.
p_generate_step = jax.pmap(generate_step, "batch")

pred_generations, pred_labels, pred_inputs = [], [], []
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
pred_steps = len(predict_dataset) // eval_batch_size

for batch in tqdm(pred_loader, total=pred_steps, desc="Predicting...", position=2, leave=False):
    # Model forward
    generated_ids = p_generate_step(batch)
    inputs = batch["input_ids"]
    labels = batch["labels"]

    # Flatten the leading axes so each row is one example, then copy to host.
    flat_generated = generated_ids.reshape(-1, generation_kwargs["max_length"])
    flat_labels = labels.reshape(-1, labels.shape[-1])
    flat_inputs = inputs.reshape(-1, inputs.shape[-1])

    pred_generations.extend(jax.device_get(flat_generated))
    pred_labels.extend(jax.device_get(flat_labels))
    pred_inputs.extend(jax.device_get(flat_inputs))

# Decode token ids back to text. Targets and generations keep their special
# tokens during decoding so target_postprocessing can handle them.
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)

decoded_labels = tokenizer.batch_decode(pred_labels, skip_special_tokens=False)
true_recipe = target_postprocessing(decoded_labels, special_tokens)

decoded_generations = tokenizer.batch_decode(pred_generations, skip_special_tokens=False)
generated_recipe = target_postprocessing(decoded_generations, special_tokens)

# One row per example: source text, reference text, generated text.
test_output = pd.DataFrame.from_dict(
    {
        "inputs": inputs,
        "true_recipe": true_recipe,
        "generated_recipe": generated_recipe,
    }
)
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")