DimaKoshman committed on
Commit
b78886d
1 Parent(s): 1fa5cf1

fixed annotations

Browse files
Files changed (5) hide show
  1. data.py +4 -2
  2. metrics.py +2 -0
  3. model.py +5 -3
  4. train.py +5 -3
  5. utils.py +3 -1
data.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import dataclasses
2
  import enum
3
  import functools
@@ -229,7 +231,7 @@ class AnnotatedImage:
229
 
230
  def generate_annotated_images():
231
  for image_id in tqdm.autonotebook.tqdm(
232
- load_train_image_ids(), "Iterating over annotated images"
233
  ):
234
  yield AnnotatedImage.from_image_id(image_id)
235
 
@@ -331,7 +333,7 @@ def convert_number_to_scientific_string(value: int or float) -> str:
331
 
332
 
333
  def convert_axis_data_to_string(
334
- axis_data: list[str or float], values_type: ValuesType
335
  ) -> str:
336
  formatted_axis_data = []
337
  for value in axis_data:
 
1
+ from __future__ import annotations
2
+
3
  import dataclasses
4
  import enum
5
  import functools
 
231
 
232
  def generate_annotated_images():
233
  for image_id in tqdm.autonotebook.tqdm(
234
+ load_train_image_ids(), "Iterating over annotated images"
235
  ):
236
  yield AnnotatedImage.from_image_id(image_id)
237
 
 
333
 
334
 
335
  def convert_axis_data_to_string(
336
+ axis_data: list[str or float], values_type: ValuesType
337
  ) -> str:
338
  formatted_axis_data = []
339
  for value in axis_data:
metrics.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import numpy as np
2
  import rapidfuzz
3
  import sklearn
 
1
+ from __future__ import annotations
2
+
3
  import numpy as np
4
  import rapidfuzz
5
  import sklearn
model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import collections
2
  import dataclasses
3
  import types
@@ -28,7 +30,7 @@ class Model:
28
 
29
 
30
  def add_unknown_tokens_to_tokenizer(
31
- tokenizer, encoder_decoder, unknown_tokens: list[str]
32
  ):
33
  tokenizer.add_tokens(unknown_tokens)
34
  encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
@@ -51,7 +53,7 @@ def find_unknown_tokens_for_tokenizer(tokenizer) -> collections.Counter:
51
 
52
 
53
  def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
54
- tokenizer, token_ids
55
  ):
56
  token_ids[token_ids == tokenizer.pad_token_id] = -100
57
  return token_ids
@@ -142,7 +144,7 @@ def build_model(config: types.SimpleNamespace or object) -> Model:
142
 
143
 
144
  def generate_token_strings(
145
- model: Model, images: torch.Tensor, skip_special_tokens=True
146
  ) -> list[str]:
147
  decoder_output = model.encoder_decoder.generate(
148
  images,
 
1
+ from __future__ import annotations
2
+
3
  import collections
4
  import dataclasses
5
  import types
 
30
 
31
 
32
  def add_unknown_tokens_to_tokenizer(
33
+ tokenizer, encoder_decoder, unknown_tokens: list[str]
34
  ):
35
  tokenizer.add_tokens(unknown_tokens)
36
  encoder_decoder.decoder.resize_token_embeddings(len(tokenizer))
 
53
 
54
 
55
  def replace_pad_token_id_with_negative_hundred_for_hf_transformers_automatic_batch_transformation(
56
+ tokenizer, token_ids
57
  ):
58
  token_ids[token_ids == tokenizer.pad_token_id] = -100
59
  return token_ids
 
144
 
145
 
146
  def generate_token_strings(
147
+ model: Model, images: torch.Tensor, skip_special_tokens=True
148
  ) -> list[str]:
149
  decoder_output = model.encoder_decoder.generate(
150
  images,
train.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
 
3
  import pandas as pd
@@ -20,12 +22,12 @@ from utils import set_tokenizers_parallelism, set_torch_device_order_pci_bus
20
 
21
  class MetricsCallback(pl.callbacks.Callback):
22
  def on_validation_batch_start(
23
- self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
24
  ):
25
  predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
26
 
27
  for expected_data_index, predicted_string in zip(
28
- batch.data_indices, predicted_strings, strict=True
29
  ):
30
  benetech_score = benetech_score_string_prediction(
31
  expected_data_index=expected_data_index,
@@ -50,7 +52,7 @@ class MetricsCallback(pl.callbacks.Callback):
50
 
51
  class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
52
  def __init__(
53
- self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
54
  ):
55
  super().__init__()
56
  self.pretrained_models = pretrained_models
 
1
+ from __future__ import annotations
2
+
3
  import os
4
 
5
  import pandas as pd
 
22
 
23
  class MetricsCallback(pl.callbacks.Callback):
24
  def on_validation_batch_start(
25
+ self, trainer, pl_module, batch: Batch, batch_idx, dataloader_idx=0
26
  ):
27
  predicted_strings = generate_token_strings(pl_module.model, images=batch.images)
28
 
29
  for expected_data_index, predicted_string in zip(
30
+ batch.data_indices, predicted_strings, strict=True
31
  ):
32
  benetech_score = benetech_score_string_prediction(
33
  expected_data_index=expected_data_index,
 
52
 
53
  class TransformersPreTrainedModelsCheckpointIO(pl.plugins.CheckpointIO):
54
  def __init__(
55
+ self, pretrained_models: list[transformers.modeling_utils.PreTrainedModel]
56
  ):
57
  super().__init__()
58
  self.pretrained_models = pretrained_models
utils.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import pickle
3
  from typing import Callable, TypeVar
@@ -14,7 +16,7 @@ def set_torch_device_order_pci_bus():
14
 
15
 
16
  def load_pickle_or_build_object_and_save(
17
- pickle_path: str, build_object: Callable[[], T], overwrite=False
18
  ) -> T:
19
  if overwrite or not os.path.exists(pickle_path):
20
  pickle.dump(build_object(), open(pickle_path, "wb"))
 
1
+ from __future__ import annotations
2
+
3
  import os
4
  import pickle
5
  from typing import Callable, TypeVar
 
16
 
17
 
18
  def load_pickle_or_build_object_and_save(
19
+ pickle_path: str, build_object: Callable[[], T], overwrite=False
20
  ) -> T:
21
  if overwrite or not os.path.exists(pickle_path):
22
  pickle.dump(build_object(), open(pickle_path, "wb"))