DimaKoshman committed
Commit 028951c
1 Parent(s): 88f8b47

trained model for a bit

MakingGraphsAccessible.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
app.py CHANGED
@@ -1,44 +1,56 @@
  import gradio
- import transformers
- import types
-
-
- checkpoint_path = "checkpoint"
- examples_path = "examples"
-
- MODEL = types.SimpleNamespace()
- MODEL.donut_processor = transformers.DonutProcessor.from_pretrained(checkpoint_path)
- MODEL.encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(checkpoint_path)
- MODEL.tokenizer = MODEL.donut_processor.tokenizer
-
-
- def generate_token_strings(images, skip_special_tokens=True):
-     decoder_output = MODEL.encoder_decoder.generate(
-         images,
-         max_length=MODEL.encoder_decoder.config.decoder.max_length,
-         eos_token_id=MODEL.tokenizer.eos_token_id,
-         return_dict_in_generate=True,
-     )
-     return MODEL.tokenizer.batch_decode(
-         decoder_output.sequences, skip_special_tokens=skip_special_tokens
-     )
-
- def predict_string(image):
-     image = MODEL.donut_processor(
-         image, random_padding=False, return_tensors="pt"
-     ).pixel_values
-     string = generate_token_strings(image)[0]
-     return string
-
-
- interface = gradio.Interface(
-     title = "Making graphs accessible",
-     description = "Generate textual representation of a graph\n"
-     "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
-     fn=predict_string,
-     inputs="image",
-     outputs="text",
-     examples=examples_path,
- )
-
- interface.launch()
+ import pandas as pd
+ from matplotlib import pyplot as plt
+
+ from config import CONFIG
+ from machine_learning.transformers.MakingGraphsAccessible.data import (
+     get_extra_tokens,
+     BenetechOutput,
+     ChartType,
+ )
+ from model import predict_string, build_model
+
+
+ def gradio_visualize_prediction(string):
+     string = string.removeprefix(get_extra_tokens().benetech_prompt)
+
+     if not BenetechOutput.does_string_match_expected_pattern(string):
+         return
+
+     benetech_output = BenetechOutput.from_string(string)
+     x = benetech_output.x_data[: len(benetech_output.y_data)]
+     y = benetech_output.y_data[: len(benetech_output.x_data)]
+
+     df = pd.DataFrame(dict(x=x, y=y))
+
+     plt_plot = {
+         ChartType.line: plt.plot,
+         ChartType.scatter: plt.scatter,
+         ChartType.horizontal_bar: plt.barh,
+         ChartType.vertical_bar: plt.bar,
+         ChartType.dot: plt.scatter,
+     }
+
+     plt_plot[benetech_output.chart_type](x, y)
+     plt.xticks(rotation=30)
+     plt.savefig("plot.png")
+
+     ...
+
+
+ def main():
+     config = CONFIG
+     config.pretrained_model_name = "checkpoint"
+     model = build_model(config)
+
+     interface = gradio.Interface(
+         title="Making graphs accessible",
+         description="Generate textual representation of a graph\n"
+         "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
+         fn=lambda image: predict_string(image, model),
+         inputs="image",
+         outputs="text",
+         examples="examples",
+     )
+
+     interface.launch()

checkpoint/added_tokens.json CHANGED
@@ -1,18 +1,19 @@
  {
    "1": 57537,
    "</benetech_prompt>": 57526,
-   "<;>": 57536,
+   "<;>": 57529,
    "<benetech_prompt>": 57525,
-   "<categorical>": 57532,
-   "<dot>": 57527,
-   "<horizontal_bar>": 57528,
-   "<line>": 57530,
-   "<numerical>": 57533,
+   "<categorical>": 57535,
+   "<dot>": 57530,
+   "<horizontal_bar>": 57531,
+   "<line>": 57533,
+   "<numerical>": 57536,
    "<s_iitcdip>": 57523,
    "<s_synthdog>": 57524,
-   "<scatter>": 57531,
+   "<scatter>": 57534,
    "<sep/>": 57522,
-   "<vertical_bar>": 57529,
-   "<x_start>": 57534,
-   "<y_start>": 57535
+   "<vertical_bar>": 57532,
+   "<x_start>": 57527,
+   "<y_start>": 57528,
+   "ދ": 57538
  }

checkpoint/config.json CHANGED
@@ -4,7 +4,7 @@
    "architectures": [
      "VisionEncoderDecoderModel"
    ],
-   "bos_token_id": 57525,
+   "bos_token_id": 3,
    "decoder": {
      "_name_or_path": "",
      "activation_dropout": 0.0,
@@ -51,7 +51,7 @@
      "LABEL_1": 1
    },
    "length_penalty": 1.0,
-   "max_length": 512,
+   "max_length": 20,
    "max_position_embeddings": 1536,
    "min_length": 0,
    "model_type": "mbart",
@@ -88,9 +88,9 @@
    "typical_p": 1.0,
    "use_bfloat16": false,
    "use_cache": true,
-   "vocab_size": 57538
+   "vocab_size": 57539
    },
-   "decoder_start_token_id": 57525,
+   "decoder_start_token_id": 3,
    "encoder": {
      "_name_or_path": "",
      "add_cross_attention": false,
@@ -187,7 +187,7 @@
    "use_bfloat16": false,
    "window_size": 10
    },
-   "eos_token_id": 57526,
+   "eos_token_id": 3,
    "is_encoder_decoder": true,
    "model_type": "vision-encoder-decoder",
    "pad_token_id": 1,

checkpoint/generation_config.json CHANGED
@@ -1,10 +1,9 @@
  {
    "_from_model_config": true,
-   "bos_token_id": 57525,
-   "decoder_start_token_id": 57525,
-   "eos_token_id": 57526,
+   "bos_token_id": 3,
+   "decoder_start_token_id": 3,
+   "eos_token_id": 3,
    "forced_eos_token_id": 2,
-   "max_length": 512,
    "pad_token_id": 1,
    "transformers_version": "4.26.1"
  }

checkpoint/pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:61dbb0fa6d53b3b8ee0bd4f168eacddc6001485b2e6ceb00781f05188cb57645
- size 809225433
+ oid sha256:7c9a42d9810580ea7d19acdfe533e97b4be48693c18c82c2b5f337eb879921ff
+ size 809236249

checkpoint/special_tokens_map.json CHANGED
@@ -5,7 +5,7 @@
    ],
    "bos_token": "<s>",
    "cls_token": "<s>",
-   "eos_token": "</benetech_prompt>",
+   "eos_token": "<unk>",
    "mask_token": {
      "content": "<mask>",
      "lstrip": true,

checkpoint/tokenizer.json CHANGED
The diff for this file is too large to render. See raw diff
 
checkpoint/tokenizer_config.json CHANGED
@@ -11,7 +11,7 @@
      "single_word": false
    },
    "model_max_length": 1000000000000000019884624838656,
-   "name_or_path": "tmp.cp",
+   "name_or_path": "naver-clova-ix/donut-base",
    "pad_token": "<pad>",
    "processor_class": "DonutProcessor",
    "sep_token": "</s>",

config.py ADDED
@@ -0,0 +1,21 @@
+ class CONFIG:
+     debug = False
+     accelerator = "cpu" if debug else "gpu"
+     devices = "auto" if accelerator == "cpu" else [1]
+     batch_size = 2 if debug else 1
+     limit_train_batches = 2 if debug else None
+     limit_val_batches = 2 if debug else 100
+     learning_rate = 3e-5
+     val_fraction = 0.1
+     seed = 42
+     train_val_indices_path = "data/train_val_indices.pickle"
+     float_scientific_notation_string_precision = 5
+     pretrained_model_name = "naver-clova-ix/donut-base"
+     image_width = 720
+     image_height = 512
+     unknown_tokens_for_tokenizer_path = "data/unknown_tokens_for_tokenizer.pickle"
+     decoder_sequence_max_length = 512
+     num_workers = 4
+     training_directory = "training"
+     save_top_k_checkpoints = 3
+     wandb_project_name = "MakingGraphsAccessible"

data.py ADDED
@@ -0,0 +1,521 @@
+ import dataclasses
+ import enum
+ import functools
+ import json
+ import os
+ import re
+ import types
+ from typing import Callable
+
+ import einops
+ import imageio
+ import numpy as np
+ import torch.utils.data
+ import torchvision
+ import tqdm
+
+ from config import CONFIG
+ from utils import load_pickle_or_build_object_and_save
+
+
+ class Source(enum.Enum):
+     generated = "generated"
+     extracted = "extracted"
+
+
+ class ChartType(enum.Enum):
+     dot = "dot"
+     horizontal_bar = "horizontal_bar"
+     vertical_bar = "vertical_bar"
+     line = "line"
+     scatter = "scatter"
+
+
+ @dataclasses.dataclass
+ class PlotBoundingBox:
+     height: int
+     width: int
+     x0: int
+     y0: int
+
+     def get_bounds(self):
+         xs = [self.x0, self.x0 + self.width, self.x0 + self.width, self.x0, self.x0]
+         ys = [self.y0, self.y0, self.y0 + self.height, self.y0 + self.height, self.y0]
+         return xs, ys
+
+
+ @dataclasses.dataclass
+ class DataPoint:
+     x: float or str
+     y: float or str
+
+
+ class TextRole(enum.Enum):
+     axis_title = "axis_title"
+     chart_title = "chart_title"
+     legend_label = "legend_label"
+     tick_grouping = "tick_grouping"
+     tick_label = "tick_label"
+     other = "other"
+
+
+ @dataclasses.dataclass
+ class Polygon:
+     x0: int
+     x1: int
+     x2: int
+     x3: int
+     y0: int
+     y1: int
+     y2: int
+     y3: int
+
+     def get_bounds(self):
+         xs = [
+             self.x0,
+             self.x1,
+             self.x2,
+             self.x3,
+             self.x0,
+         ]
+         ys = [
+             self.y0,
+             self.y1,
+             self.y2,
+             self.y3,
+             self.y0,
+         ]
+         return xs, ys
+
+
+ @dataclasses.dataclass
+ class Text:
+     id: int
+     polygon: Polygon
+     role: TextRole
+     text: str
+
+     def __post_init__(self):
+         self.polygon = Polygon(**self.polygon)
+         self.role = TextRole(self.role)
+
+
+ class ValuesType(enum.Enum):
+     categorical = "categorical"
+     numerical = "numerical"
+
+
+ @dataclasses.dataclass
+ class Tick:
+     id: int
+     x: int
+     y: int
+
+
+ class TickType(enum.Enum):
+     markers = "markers"
+     separators = "separators"
+
+
+ @dataclasses.dataclass
+ class Axis:
+     values_type: ValuesType
+     tick_type: TickType
+     ticks: list[Tick]
+
+     def __post_init__(self):
+         self.values_type = ValuesType(self.values_type)
+         self.tick_type = TickType(self.tick_type)
+         self.ticks = [
+             Tick(id=kw["id"], x=kw["tick_pt"]["x"], y=kw["tick_pt"]["y"])
+             for kw in self.ticks
+         ]
+
+     def get_bounds(self):
+         min_x = min(tick.x for tick in self.ticks)
+         max_x = max(tick.x for tick in self.ticks)
+         min_y = min(tick.y for tick in self.ticks)
+         max_y = max(tick.y for tick in self.ticks)
+         xs = [min_x, max_x, max_x, min_x, min_x]
+         ys = [min_y, min_y, max_y, max_y, min_y]
+         return xs, ys
+
+
+ def convert_dashes_to_underscores_in_key_names(dictionary):
+     return {k.replace("-", "_"): v for k, v in dictionary.items()}
+
+
+ @dataclasses.dataclass
+ class Axes:
+     x_axis: Axis
+     y_axis: Axis
+
+     def __post_init__(self):
+         self.x_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.x_axis))
+         self.y_axis = Axis(**convert_dashes_to_underscores_in_key_names(self.y_axis))
+
+
+ def preprocess_numerical_value(value):
+     value = float(value)
+     value = 0 if np.isnan(value) else value
+     return value
+
+
+ def preprocess_value(value, value_type: ValuesType):
+     if value_type == ValuesType.numerical:
+         return preprocess_numerical_value(value)
+     else:
+         return str(value)
+
+
+ @dataclasses.dataclass
+ class Annotation:
+     source: Source
+     chart_type: ChartType
+     plot_bb: PlotBoundingBox
+     text: list[Text]
+     axes: Axes
+     data_series: list[DataPoint]
+
+     def __post_init__(self):
+         self.source = Source(self.source)
+         self.chart_type = ChartType(self.chart_type)
+         self.plot_bb = PlotBoundingBox(**self.plot_bb)
+         self.text = [Text(**kw) for kw in self.text]
+         self.axes = Axes(**convert_dashes_to_underscores_in_key_names(self.axes))
+         self.data_series = [DataPoint(**kw) for kw in self.data_series]
+
+         for i in range(len(self.data_series)):
+             self.data_series[i].x = preprocess_value(
+                 self.data_series[i].x, self.axes.x_axis.values_type
+             )
+             self.data_series[i].y = preprocess_value(
+                 self.data_series[i].y, self.axes.y_axis.values_type
+             )
+
+     @staticmethod
+     def from_dict_with_dashes(kwargs):
+         return Annotation(**convert_dashes_to_underscores_in_key_names(kwargs))
+
+     @staticmethod
+     def from_image_index(image_index: int):
+         image_id = load_train_image_ids()[image_index]
+         return Annotation.from_dict_with_dashes(load_image_annotation(image_id))
+
+     def get_text_by_role(self, text_role: TextRole) -> list[Text]:
+         return [t for t in self.text if t.role == text_role]
+
+
+ @dataclasses.dataclass
+ class AnnotatedImage:
+     id: str
+     image: np.ndarray
+     annotation: Annotation
+
+     @staticmethod
+     def from_image_id(image_id: str):
+         return AnnotatedImage(
+             id=image_id,
+             image=load_image(image_id),
+             annotation=Annotation.from_dict_with_dashes(
+                 load_image_annotation(image_id)
+             ),
+         )
+
+     @staticmethod
+     def from_image_index(image_index: int):
+         return AnnotatedImage.from_image_id(load_train_image_ids()[image_index])
+
+
+ def generate_annotated_images():
+     for image_id in tqdm.autonotebook.tqdm(
+         load_train_image_ids(), "Iterating over annotated images"
+     ):
+         yield AnnotatedImage.from_image_id(image_id)
+
+
+ @functools.cache
+ def load_train_image_ids() -> list[str]:
+     train_image_ids = [i.replace(".jpg", "") for i in os.listdir("data/train/images")]
+     return train_image_ids[: 1000 if CONFIG.debug else None]
+
+
+ @functools.cache
+ def load_test_image_ids() -> list[str]:
+     return [i.replace(".jpg", "") for i in os.listdir("data/test/images")]
+
+
+ @functools.cache
+ def load_image_annotation(image_id: str) -> dict:
+     return json.load(open(f"data/train/annotations/{image_id}.json"))
+
+
+ def load_image(image_id: str) -> np.ndarray:
+     return imageio.v3.imread(open(f"data/train/images/{image_id}.jpg", "rb"))
+
+
+ @dataclasses.dataclass
+ class DataItem:
+     image: torch.FloatTensor
+     target_string: str
+     data_index: int
+
+     def __post_init__(self):
+         shape = einops.parse_shape(self.image, "channel height width")
+         assert shape["channel"] == 3, "Image is expected to have 3 channels."
+
+
+ def split_train_indices_by_source():
+     extracted_image_indices = []
+     generated_image_indices = []
+     for i, annotated_image in enumerate(generate_annotated_images()):
+         if annotated_image.annotation.source == Source.extracted:
+             extracted_image_indices.append(i)
+         else:
+             generated_image_indices.append(i)
+     return extracted_image_indices, generated_image_indices
+
+
+ def get_train_val_split_indices(val_fraction=0.1, seed=42):
+     np.random.seed(seed)
+     val_size = int(len(load_train_image_ids()) * val_fraction)
+
+     extracted_image_indices, generated_image_indices = split_train_indices_by_source()
+     extracted_image_indices = np.random.permutation(extracted_image_indices)
+     generated_image_indices = np.random.permutation(generated_image_indices)
+
+     val_indices = extracted_image_indices[:val_size]
+     n_generated_images_in_val = val_size - len(val_indices)
+     val_indices = np.concatenate(
+         [val_indices, generated_image_indices[:n_generated_images_in_val]]
+     )
+
+     train_indices = generated_image_indices[n_generated_images_in_val:]
+
+     assert len(set(train_indices) | set(val_indices)) == len(load_train_image_ids())
+     assert len(val_indices) == val_size
+     assert len(set(train_indices) & set(val_indices)) == 0
+
+     return train_indices, val_indices
+
+
+ def to_token_str(value: str or enum.Enum):
+     string = value.name if isinstance(value, enum.Enum) else value
+     if re.fullmatch("<.*>", string):
+         return string
+     else:
+         return f"<{string}>"
+
+
+ @functools.cache
+ def get_extra_tokens() -> types.SimpleNamespace:
+     token_ns = types.SimpleNamespace()
+
+     token_ns.benetech_prompt = to_token_str("benetech_prompt")
+     token_ns.benetech_prompt_end = to_token_str("/benetech_prompt")
+     token_ns.x_start = to_token_str("x_start")
+     token_ns.y_start = to_token_str("y_start")
+     token_ns.value_separator = to_token_str(";")
+
+     for chart_type in ChartType:
+         setattr(token_ns, chart_type.name, to_token_str(chart_type))
+
+     for values_type in ValuesType:
+         setattr(token_ns, values_type.name, to_token_str(values_type))
+
+     return token_ns
+
+
+ def convert_number_to_scientific_string(value: int or float) -> str:
+     return f"{value:.{CONFIG.float_scientific_notation_string_precision}e}"
+
+
+ def convert_axis_data_to_string(
+     axis_data: list[str or float], values_type: ValuesType
+ ) -> str:
+     formatted_axis_data = []
+     for value in axis_data:
+         if values_type == ValuesType.numerical:
+             value = convert_number_to_scientific_string(value)
+         formatted_axis_data.append(value)
+     return get_extra_tokens().value_separator.join(formatted_axis_data)
+
+
+ def convert_string_to_axis_data(string, values_type: ValuesType):
+     data = string.split(get_extra_tokens().value_separator)
+     if values_type == ValuesType.numerical:
+         data = [float(i.replace(" ", "")) for i in data]
+     return data
+
+
+ @dataclasses.dataclass
+ class BenetechOutput:
+     chart_type: ChartType
+     x_values_type: ValuesType
+     y_values_type: ValuesType
+     x_data: list[str or float]
+     y_data: list[str or float]
+
+     def __post_init__(self):
+         self.chart_type = ChartType(self.chart_type)
+         self.x_values_type = ValuesType(self.x_values_type)
+         self.y_values_type = ValuesType(self.y_values_type)
+         assert isinstance(self.x_data, list)
+         assert isinstance(self.y_data, list)
+
+     def get_main_characteristics(self):
+         return (
+             self.chart_type,
+             self.x_values_type,
+             self.y_values_type,
+             len(self.x_data),
+             len(self.y_data),
+         )
+
+     @staticmethod
+     def from_annotation(annotation: Annotation):
+         return BenetechOutput(
+             chart_type=annotation.chart_type,
+             x_values_type=annotation.axes.x_axis.values_type,
+             y_values_type=annotation.axes.y_axis.values_type,
+             x_data=[dp.x for dp in annotation.data_series],
+             y_data=[dp.y for dp in annotation.data_series],
+         )
+
+     def to_string(self):
+         return self.format_strings(
+             chart_type=self.chart_type,
+             x_values_type=self.x_values_type,
+             y_values_type=self.y_values_type,
+             x_data=convert_axis_data_to_string(self.x_data, self.x_values_type),
+             y_data=convert_axis_data_to_string(self.y_data, self.y_values_type),
+         )
+
+     @staticmethod
+     def format_strings(*, chart_type, x_values_type, y_values_type, x_data, y_data):
+         chart_type = to_token_str(chart_type)
+         x_values_type = to_token_str(x_values_type)
+         y_values_type = to_token_str(y_values_type)
+         token_ns = get_extra_tokens()
+         return (
+             f"{token_ns.benetech_prompt}{chart_type}"
+             f"{token_ns.x_start}{x_values_type}{x_data}"
+             f"{token_ns.y_start}{y_values_type}{y_data}"
+             f"{token_ns.benetech_prompt_end}"
+         )
+
+     @staticmethod
+     def get_string_pattern():
+         field_names = [field.name for field in dataclasses.fields(BenetechOutput)]
+         pattern = BenetechOutput.format_strings(
+             **{field_name: f"(?P<{field_name}>.*?)" for field_name in field_names}
+         )
+         return pattern
+
+     @staticmethod
+     def does_string_match_expected_pattern(string):
+         try:
+             BenetechOutput.from_string(string)
+             return True
+         except:
+             return False
+
+     @staticmethod
+     def from_string(string):
+         fullmatch = re.fullmatch(BenetechOutput.get_string_pattern(), string)
+         benetech_kwargs = fullmatch.groupdict()
+         benetech_kwargs["chart_type"] = ChartType(benetech_kwargs["chart_type"])
+         benetech_kwargs["x_values_type"] = ValuesType(benetech_kwargs["x_values_type"])
+         benetech_kwargs["y_values_type"] = ValuesType(benetech_kwargs["y_values_type"])
+         benetech_kwargs["x_data"] = convert_string_to_axis_data(
+             benetech_kwargs["x_data"], benetech_kwargs["x_values_type"]
+         )
+         benetech_kwargs["y_data"] = convert_string_to_axis_data(
+             benetech_kwargs["y_data"], benetech_kwargs["y_values_type"]
+         )
+         return BenetechOutput(**benetech_kwargs)
+
+
+ def get_annotation_ground_truth_str(annotation: Annotation):
+     benetech_output = BenetechOutput(
+         chart_type=annotation.chart_type,
+         x_values_type=annotation.axes.x_axis.values_type,
+         x_data=[dp.x for dp in annotation.data_series],
+         y_values_type=annotation.axes.y_axis.values_type,
+         y_data=[dp.y for dp in annotation.data_series],
+     )
+     return benetech_output.to_string()
+
+
+ def get_annotation_ground_truth_str_from_image_index(image_index: int) -> str:
+     return get_annotation_ground_truth_str(Annotation.from_image_index(image_index))
+
+
+ class Dataset(torch.utils.data.Dataset):
+     def __init__(self, indices: list[int]):
+         super().__init__()
+         self.indices = indices
+         self.to_tensor = torchvision.transforms.ToTensor()
+
+     def __len__(self):
+         return len(self.indices)
+
+     def __getitem__(self, idx: int) -> DataItem:
+         data_index = self.indices[idx]
+
+         annotated_image = AnnotatedImage.from_image_index(data_index)
+
+         image = annotated_image.image
+         image = self.to_tensor(image)
+
+         target_string = get_annotation_ground_truth_str(annotated_image.annotation)
+
+         return DataItem(image=image, target_string=target_string, data_index=data_index)
+
+
+ def get_train_val_datasets():
+     train_indices, val_indices = load_pickle_or_build_object_and_save(
+         CONFIG.train_val_indices_path,
+         lambda: get_train_val_split_indices(CONFIG.val_fraction, CONFIG.seed),
+     )
+     return Dataset(train_indices), Dataset(val_indices)
+
+
+ def get_train_dataset():
+     return get_train_val_datasets()[0]
+
+
+ def get_val_dataset():
+     return get_train_val_datasets()[1]
+
+
+ @dataclasses.dataclass
+ class Batch:
+     images: torch.FloatTensor
+     labels: torch.IntTensor
+     data_indices: list[int]
+
+     def __post_init__(self):
+         if CONFIG.debug:
+             images_shape = einops.parse_shape(self.images, "batch channel height width")
+             labels_shape = einops.parse_shape(self.labels, "batch label")
+             assert images_shape["batch"] == labels_shape["batch"]
+             assert len(self.data_indices) == images_shape["batch"]
+
+
+ class Split(enum.Enum):
+     train = "train"
+     val = "val"
+
+
+ BatchCollateFunction = Callable[[list[DataItem], Split], Batch]
+
+
+ def build_dataloader(split: Split, batch_collate_function: BatchCollateFunction):
+     return torch.utils.data.DataLoader(
+         get_train_dataset() if split == Split.train else get_val_dataset(),
+         batch_size=CONFIG.batch_size,
+         shuffle=split == Split.train,
+         num_workers=CONFIG.num_workers,
+         collate_fn=functools.partial(batch_collate_function, split=split),
+     )

metrics.py ADDED
@@ -0,0 +1,54 @@
+ import numpy as np
+ import rapidfuzz
+ import sklearn
+
+ from data import ValuesType, BenetechOutput, Annotation
+
+
+ def normalized_rmse(expected: list[float], predicted: list[float]) -> float:
+     return (1 - sklearn.metrics.r2_score(expected, predicted)) ** 0.5
+
+
+ def normalized_levenshtein_distance(expected: list[str], predicted: list[str]) -> float:
+     total_distance = 0
+     for e, p in zip(expected, predicted):
+         total_distance += rapidfuzz.distance.Levenshtein.distance(e, p)
+     total_length = np.sum([len(e) for e in expected])
+     return total_distance / total_length
+
+
+ def sigmoid(x):
+     return 1 / (1 + np.exp(-x))
+
+
+ def positive_loss_to_score(x):
+     return 2 * sigmoid(-x)
+
+
+ def score_axis_values(values_type, expected, predicted):
+     if values_type == ValuesType.numerical:
+         loss = normalized_rmse(expected, predicted)
+     else:
+         loss = normalized_levenshtein_distance(expected, predicted)
+     return positive_loss_to_score(loss)
+
+
+ def benetech_score(expected: BenetechOutput, predicted: BenetechOutput) -> float:
+     if expected.get_main_characteristics() != predicted.get_main_characteristics():
+         return 0
+     x_score = score_axis_values(
+         expected.x_values_type, expected.x_data, predicted.x_data
+     )
+     y_score = score_axis_values(
+         expected.y_values_type, expected.y_data, predicted.y_data
+     )
+     return (x_score + y_score) / 2
+
+
+ def benetech_score_string_prediction(expected_data_index: int, predicted_string: str):
+     if not BenetechOutput.does_string_match_expected_pattern(predicted_string):
+         return 0
+     expected_annotation = Annotation.from_image_index(expected_data_index)
+     expected_output = BenetechOutput.from_annotation(expected_annotation)
+     predicted_output = BenetechOutput.from_string(predicted_string)
+     return benetech_score(expected_output, predicted_output)

model.py ADDED
@@ -0,0 +1,193 @@
+ import collections
+ import dataclasses
+ import types
+
+ import pytorch_lightning as pl
+ import torch.utils.data
+ import transformers
+
+ from data import (
+     generate_annotated_images,
+     get_annotation_ground_truth_str,
+     DataItem,
+     get_extra_tokens,
+     Batch,
+     Split,
+     BatchCollateFunction,
+ )
+ from utils import load_pickle_or_build_object_and_save
+
+
+ @dataclasses.dataclass
+ class Model:
+     processor: transformers.models.donut.processing_donut.DonutProcessor
+     tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
+     encoder_decoder: transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder.VisionEncoderDecoderModel
+     batch_collate_function: BatchCollateFunction
+     config: types.SimpleNamespace
+
+
+ def add_unknown_tokens_to_tokenizer(
+     tokenizer, encoder_decoder, unknown_tokens: list[str]
+ ):
+     tokenizer.add_tokens(unknown_tokens)
+     encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
+
+
+ def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
+     unknown_tokens_counter = collections.Counter()
+
+     for annotated_image in generate_annotated_images():
+         ground_truth = get_annotation_ground_truth_str(annotated_image.annotation)
+
+         input_ids = tokenizer(ground_truth).input_ids
+         tokens = tokenizer.tokenize(ground_truth, add_special_tokens=True)
+
+         for token_id, token in zip(input_ids, tokens, strict=True):
+             if token_id == tokenizer.unk_token_id:
+                 unknown_tokens_counter.update([token])
+
+     return unknown_tokens_counter
+
+
+ def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
+     tokenizer, token_ids
+ ):
+     token_ids[token_ids == tokenizer.pad_token_id] = -100
+     return token_ids
+
+
+ @dataclasses.dataclass
+ class BatchCollateFunction:
+     processor: transformers.models.donut.processing_donut.DonutProcessor
+     tokenizer: transformers.models.xlm_roberta.tokenization_xlm_roberta_fast.XLMRobertaTokenizerFast
+     decoder_sequence_max_length: int
+
+     def __call__(self, batch: list[DataItem], split: Split) -> Batch:
+         images = [di.image for di in batch]
+         images = self.processor(
+             images, random_padding=split == Split.train, return_tensors="pt"
+         ).pixel_values
+
+         target_token_ids = self.tokenizer(
+             [di.target_string for di in batch],
+             add_special_tokens=False,
+             max_length=self.decoder_sequence_max_length,
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+         ).input_ids
+         labels = replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
+             self.tokenizer, target_token_ids
+         )
+
+         data_indices = [di.data_index for di in batch]
+
+         return Batch(images=images, labels=labels, data_indices=data_indices)
+
+
+ def build_model(config: types.SimpleNamespace or object) -> Model:
+     donut_processor = transformers.DonutProcessor.from_pretrained(
+         config.pretrained_model_name
+     )
+     donut_processor.image_processor.size = dict(
+         width=config.image_width, height=config.image_height
+     )
+     donut_processor.image_processor.do_align_long_axis = False
+
+     tokenizer = donut_processor.tokenizer
+
+     encoder_decoder_config = transformers.VisionEncoderDecoderConfig.from_pretrained(
+         config.pretrained_model_name
+     )
+     encoder_decoder_config.encoder.image_size = (
+         config.image_width,
+         config.image_height,
+     )
+
+     encoder_decoder = transformers.VisionEncoderDecoderModel.from_pretrained(
+         config.pretrained_model_name, config=encoder_decoder_config
+     )
+     encoder_decoder_config.pad_token_id = tokenizer.pad_token_id
+     encoder_decoder_config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(
+         get_extra_tokens().benetech_prompt
+     )
+     encoder_decoder_config.bos_token_id = encoder_decoder_config.decoder_start_token_id
+     encoder_decoder_config.eos_token_id = tokenizer.convert_tokens_to_ids(
+         get_extra_tokens().benetech_prompt_end
+     )
+
+     extra_tokens = list(get_extra_tokens().__dict__.values())
+     add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, extra_tokens)
+     unknown_dataset_tokens = load_pickle_or_build_object_and_save(
+         config.unknown_tokens_for_tokenizer_path,
+         lambda: list(find_unknown_tokens_for_tokenizer(tokenizer).keys()),
+     )
+     add_unknown_tokens_to_tokenizer(tokenizer, encoder_decoder, unknown_dataset_tokens)
+     tokenizer.eos_token_id = encoder_decoder_config.eos_token_id
+
+     batch_collate_function = BatchCollateFunction(
+         processor=donut_processor,
+         tokenizer=tokenizer,
+         decoder_sequence_max_length=config.decoder_sequence_max_length,
+     )
+
+     return Model(
+         processor=donut_processor,
+         tokenizer=tokenizer,
+         encoder_decoder=encoder_decoder,
+         batch_collate_function=batch_collate_function,
+         config=config,
+     )
+
+
+ def generate_token_strings(
+     model: Model, images: torch.Tensor, skip_special_tokens=True
+ ) -> list[str]:
+     decoder_output = model.encoder_decoder.generate(
+         images,
+         max_length=10
+         if model.config.debug
+         else model.config.decoder_sequence_max_length,
+         eos_token_id=model.tokenizer.eos_token_id,
+         return_dict_in_generate=True,
+     )
+     return model.tokenizer.batch_decode(
+         decoder_output.sequences, skip_special_tokens=skip_special_tokens
+     )
+
+
+ def predict_string(image, model: Model):
+     image = model.processor(
+         image, random_padding=False, return_tensors="pt"
+     ).pixel_values
+     string = generate_token_strings(model, image)[0]
+     return string
+
+
+ class LightningModule(pl.LightningModule):
+     def __init__(self, config):
+         super().__init__()
+         self.save_hyperparameters()
+         self.model = build_model(config)
+         self.encoder_decoder = self.model.encoder_decoder
+
+     def training_step(self, batch: Batch, batch_idx: int) -> torch.Tensor:
+         loss = self.compute_loss(batch)
+         self.log("train_loss", loss)
+         return loss
+
+     def validation_step(self, batch: Batch, batch_idx: int, dataset_idx: int = 0):
+         loss = self.compute_loss(batch)
+         self.log("val_loss", loss)
+
+     def compute_loss(self, batch: Batch) -> torch.Tensor:
+         outputs = self.encoder_decoder(pixel_values=batch.images, labels=batch.labels)
+         loss = outputs.loss
+         return loss
+
+     def configure_optimizers(self) -> torch.optim.Optimizer:
+         optimizer = torch.optim.Adam(
+             self.parameters(), lr=self.hparams["config"].learning_rate
+         )
+         return optimizer

requirements.txt DELETED
@@ -1,3 +0,0 @@
1
- gradio==3.27.0
2
- torch==2.0.0
3
- transformers==4.26.1
 
 
 
 
train.py ADDED
@@ -0,0 +1,114 @@
+ import os
+
+ import pandas as pd
+ import pytorch_lightning as pl
+ import transformers
+ import wandb
+
+ from config import CONFIG
+ from data import (
+     get_annotation_ground_truth_str_from_image_index,
+     load_train_image_ids,
+     build_dataloader,
+     Split,
+     Batch,
+ )
+ from metrics import benetech_score_string_prediction
+ from model import generate_token_strings, LightningModule
+ from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus
+
+
+ class MetricsCallback(pl.callbacks.Callback):
+     def on_validation_batch_start(
+         self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
+     ):
+         predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
+
+         for expected_data_index, predicted_string in zip(
+             batch.data_indices, predicted_strings, strict=True
+         ):
+             benetech_score = benetech_score_string_prediction(
+                 expected_data_index=expected_data_index,
+                 predicted_string=predicted_string,
+             )
+             wandb.log(dict(benetech_score=benetech_score))
+
+         ground_truth_strings = [
+             get_annotation_ground_truth_str_from_image_index(i)
+             for i in batch.data_indices
+         ]
+         string_ids = [load_train_image_ids()[i] for i in batch.data_indices]
+         strings_dataframe = pd.DataFrame(
+             dict(
+                 string_ids=string_ids,
+                 ground_truth=ground_truth_strings,
+                 predicted=predicted_strings,
+             )
+         )
+         wandb.log(dict(strings=wandb.Table(dataframe=strings_dataframe)))
+
+
+ class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
+     def __init__(
+         self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
+     ):
+         super().__init__()
+         self.pretrained_models = pretrained_models
+
+     def save_checkpoint(self, checkpoint, path, storage_options=None):
+         for pretrained_model in self.pretrained_models:
+             pretrained_model.save_pretrained(path)
+
+     def load_checkpoint(self, path, storage_options=None):
+         self.pretrained_models = [
+             pm.from_pretrained(path) for pm in self.pretrained_models
+         ]
+
+     def remove_checkpoint(self, path):
+         os.remove(path)
+
+
+ def train():
+     set_tokenizers_parallelism(False)
+     set_torch_device_order_pci_bus()
+
+     pl_module = LightningModule(CONFIG)
+
+     model_checkpoint = pl.callbacks.ModelCheckpoint(
+         dirpath=CONFIG.training_directory,
+         monitor="val_loss",
+         save_top_k=CONFIG.save_top_k_checkpoints,
+     )
+     metrics_callback = MetricsCallback()
+
+     logger = pl.loggers.WandbLogger(
+         project=CONFIG.wandb_project_name, save_dir=CONFIG.training_directory
+     )
+
+     plugin = TransformersPreTrainedModelsCheckpointIO(
+         [pl_module.model.processor, pl_module.model.encoder_decoder]
+     )
+
+     trainer = pl.Trainer(
+         accelerator=CONFIG.accelerator,
+         devices=CONFIG.devices,
+         plugins=[plugin],
+         callbacks=[model_checkpoint, metrics_callback],
+         logger=logger,
+         limit_train_batches=CONFIG.limit_train_batches,
+         limit_val_batches=CONFIG.limit_val_batches,
+     )
+
+     trainer.fit(
+         model=pl_module,
+         train_dataloaders=build_dataloader(
+             Split.train, pl_module.model.batch_collate_function
+         ),
+         val_dataloaders=build_dataloader(
+             Split.val, pl_module.model.batch_collate_function
+         ),
+     )
+
+
+ if __name__ == "__main__":
+     train()

utils.py ADDED
@@ -0,0 +1,23 @@
+ import os
+ import pickle
+ from typing import Callable, TypeVar
+
+ T = TypeVar("T")
+
+
+ def set_tokenizers_parallelism(enable: bool):
+     os.environ["TOKENIZERS_PARALLELISM"] = "true" if enable else "false"
+
+
+ def set_torch_device_order_pci_bus():
+     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+
+
+ def load_pickle_or_build_object_and_save(
+     pickle_path: str, build_object: Callable[[], T], overwrite=False
+ ) -> T:
+     if overwrite or not os.path.exists(pickle_path):
+         pickle.dump(build_object(), open(pickle_path, "wb"))
+     else:
+         print(f"Reusing object {pickle_path}.")
+     return pickle.load(open(pickle_path, "rb"))