DimaKoshman committed on
Commit
6a43216
1 Parent(s): ba1576e
Files changed (6) hide show
  1. app.py +1 -3
  2. data.py +6 -8
  3. metrics.py +0 -2
  4. model.py +3 -5
  5. train.py +3 -5
  6. utils.py +1 -3
app.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import gradio
4
  import pandas as pd
5
  from matplotlib import pyplot as plt
@@ -44,7 +42,7 @@ def main():
44
  interface = gradio.Interface(
45
  title="Making graphs accessible",
46
  description="Generate textual representation of a graph\n"
47
- "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
48
  fn=lambda image: predict_string(image, model),
49
  inputs="image",
50
  outputs="text",
 
 
 
1
  import gradio
2
  import pandas as pd
3
  from matplotlib import pyplot as plt
 
42
  interface = gradio.Interface(
43
  title="Making graphs accessible",
44
  description="Generate textual representation of a graph\n"
45
+ "https://www.kaggle.com/competitions/benetech-making-graphs-accessible",
46
  fn=lambda image: predict_string(image, model),
47
  inputs="image",
48
  outputs="text",
data.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import dataclasses
4
  import enum
5
  import functools
@@ -231,23 +229,23 @@ class AnnotatedImage:
231
 
232
  def generate_annotated_images():
233
  for image_id in tqdm.autonotebook.tqdm(
234
- load_train_image_ids(), "Iterating over annotated images"
235
  ):
236
  yield AnnotatedImage.from_image_id(image_id)
237
 
238
 
239
- @functools.lru_cache
240
  def load_train_image_ids() -> list[str]:
241
  train_image_ids = [i.replace(".jpg", "") for i in os.listdir("data/train/images")]
242
  return train_image_ids[: 1000 if CONFIG.debug else None]
243
 
244
 
245
- @functools.lru_cache
246
  def load_test_image_ids() -> list[str]:
247
  return [i.replace(".jpg", "") for i in os.listdir("data/test/images")]
248
 
249
 
250
- @functools.lru_cache
251
  def load_image_annotation(image_id: str) -> dict:
252
  return json.load(open(f"data/train/annotations/{image_id}.json"))
253
 
@@ -309,7 +307,7 @@ def to_token_str(value: str or enum.Enum):
309
  return f"<{string}>"
310
 
311
 
312
- @functools.lru_cache
313
  def get_extra_tokens() -> types.SimpleNamespace:
314
  token_ns = types.SimpleNamespace()
315
 
@@ -333,7 +331,7 @@ def convert_number_to_scientific_string(value: int or float) -> str:
333
 
334
 
335
  def convert_axis_data_to_string(
336
- axis_data: list[str or float], values_type: ValuesType
337
  ) -> str:
338
  formatted_axis_data = []
339
  for value in axis_data:
 
 
 
1
  import dataclasses
2
  import enum
3
  import functools
 
229
 
230
  def generate_annotated_images():
231
  for image_id in tqdm.autonotebook.tqdm(
232
+ load_train_image_ids(), "Iterating over annotated images"
233
  ):
234
  yield AnnotatedImage.from_image_id(image_id)
235
 
236
 
237
+ @functools.cache
238
  def load_train_image_ids() -> list[str]:
239
  train_image_ids = [i.replace(".jpg", "") for i in os.listdir("data/train/images")]
240
  return train_image_ids[: 1000 if CONFIG.debug else None]
241
 
242
 
243
+ @functools.cache
244
  def load_test_image_ids() -> list[str]:
245
  return [i.replace(".jpg", "") for i in os.listdir("data/test/images")]
246
 
247
 
248
+ @functools.cache
249
  def load_image_annotation(image_id: str) -> dict:
250
  return json.load(open(f"data/train/annotations/{image_id}.json"))
251
 
 
307
  return f"<{string}>"
308
 
309
 
310
+ @functools.cache
311
  def get_extra_tokens() -> types.SimpleNamespace:
312
  token_ns = types.SimpleNamespace()
313
 
 
331
 
332
 
333
  def convert_axis_data_to_string(
334
+ axis_data: list[str or float], values_type: ValuesType
335
  ) -> str:
336
  formatted_axis_data = []
337
  for value in axis_data:
metrics.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import numpy as np
4
  import rapidfuzz
5
  import sklearn
 
 
 
1
  import numpy as np
2
  import rapidfuzz
3
  import sklearn
model.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import collections
4
  import dataclasses
5
  import types
@@ -30,7 +28,7 @@ class Model:
30
 
31
 
32
  def add_unknown_tokens_to_tokenizer(
33
- tokenizer, encoder_decoder, unknown_tokens: list[str]
34
  ):
35
  tokenizer.add_tokens(unknown_tokens)
36
  encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
@@ -53,7 +51,7 @@ def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
53
 
54
 
55
  def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
56
- tokenizer, token_ids
57
  ):
58
  token_ids[token_ids == tokenizer.pad_token_id] = -100
59
  return token_ids
@@ -144,7 +142,7 @@ def build_model(config: types.SimpleNamespace or object) -> Model:
144
 
145
 
146
  def generate_token_strings(
147
- model: Model, images: torch.Tensor, skip_special_tokens=True
148
  ) -> list[str]:
149
  decoder_output = model.encoder_decoder.generate(
150
  images,
 
 
 
1
  import collections
2
  import dataclasses
3
  import types
 
28
 
29
 
30
  def add_unknown_tokens_to_tokenizer(
31
+ tokenizer, encoder_decoder, unknown_tokens: list[str]
32
  ):
33
  tokenizer.add_tokens(unknown_tokens)
34
  encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
 
51
 
52
 
53
  def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
54
+ tokenizer, token_ids
55
  ):
56
  token_ids[token_ids == tokenizer.pad_token_id] = -100
57
  return token_ids
 
142
 
143
 
144
  def generate_token_strings(
145
+ model: Model, images: torch.Tensor, skip_special_tokens=True
146
  ) -> list[str]:
147
  decoder_output = model.encoder_decoder.generate(
148
  images,
train.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import os
4
 
5
  import pandas as pd
@@ -22,12 +20,12 @@ from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus
22
 
23
  class MetricsCallback(pl.callbacks.Callback):
24
  def on_validation_batch_start(
25
- self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
26
  ):
27
  predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
28
 
29
  for expected_data_index, predicted_string in zip(
30
- batch.data_indices, predicted_strings, strict=True
31
  ):
32
  benetech_score = benetech_score_string_prediction(
33
  expected_data_index=expected_data_index,
@@ -52,7 +50,7 @@ class MetricsCallback(pl.callbacks.Callback):
52
 
53
  class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
54
  def __init__(
55
- self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
56
  ):
57
  super().__init__()
58
  self.pretrained_models = pretrained_models
 
 
 
1
  import os
2
 
3
  import pandas as pd
 
20
 
21
  class MetricsCallback(pl.callbacks.Callback):
22
  def on_validation_batch_start(
23
+ self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
24
  ):
25
  predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
26
 
27
  for expected_data_index, predicted_string in zip(
28
+ batch.data_indices, predicted_strings, strict=True
29
  ):
30
  benetech_score = benetech_score_string_prediction(
31
  expected_data_index=expected_data_index,
 
50
 
51
  class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
52
  def __init__(
53
+ self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
54
  ):
55
  super().__init__()
56
  self.pretrained_models = pretrained_models
utils.py CHANGED
@@ -1,5 +1,3 @@
1
- from __future__ import annotations
2
-
3
  import os
4
  import pickle
5
  from typing import Callable, TypeVar
@@ -16,7 +14,7 @@ def set_torch_device_order_pci_bus():
16
 
17
 
18
  def load_pickle_or_build_object_and_save(
19
- pickle_path: str, build_object: Callable[[], T], overwrite=False
20
  ) -> T:
21
  if overwrite or not os.path.exists(pickle_path):
22
  pickle.dump(build_object(), open(pickle_path, "wb"))
 
 
 
1
  import os
2
  import pickle
3
  from typing import Callable, TypeVar
 
14
 
15
 
16
  def load_pickle_or_build_object_and_save(
17
+ pickle_path: str, build_object: Callable[[], T], overwrite=False
18
  ) -> T:
19
  if overwrite or not os.path.exists(pickle_path):
20
  pickle.dump(build_object(), open(pickle_path, "wb"))