dkoshman committed on
Commit
1b4da0d
1 Parent(s): 96feb73

improved interface

Files changed (2)
  1. app.py +18 -9
  2. data_preprocessing.py +7 -4
app.py CHANGED
@@ -9,7 +9,6 @@ import torchvision.transforms as T
 
 MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
 
-# TODO: make faster
 transformer = torch.load(MODEL_PATH)
 image_transform = T.Compose((
     T.ToTensor(),
@@ -18,12 +17,22 @@ image_transform = T.Compose((
         random_magnitude=0)
 ))
 
-st.markdown("### Image to TeX")
-st.image("resources/frontend/latex_example_1.png")
-file_png = st.file_uploader("Upload a PNG image", type=([".png"]))
-if file_png is not None:
-    image = PIL.Image.open(file_png)
+st.title("Image to TeX")
+
+st.image("resources/frontend/fraction_derivative.png", width=500)
+st.image("resources/frontend/positional_encoding.png")
+st.image("resources/frontend/taylor_sequence_expanded.png")
+# st.image("resources/frontend/taylor_sequence.png")
+# st.image("resources/frontend/maclaurin_series.png")
+# st.image("resources/frontend/gauss_distribution.png")
+
+image_file = st.file_uploader("Upload an image with equation", type=([".png", ".jpg", ".jpeg"]))
+
+if image_file is not None:
+    image = PIL.Image.open(image_file)
     image = image.convert("RGB")
-    tex = beam_search_decode(transformer, image, image_transform=image_transform)
-    st.latex(tex[0])
-    st.text(tex[0])
+    texs = beam_search_decode(transformer, image, image_transform=image_transform)
+    # streamlit latex doesn't support boldmath
+    tex = texs[0].replace("\\boldmath", "")
+    st.latex(tex)
+    st.markdown(tex)
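
For context, a minimal sketch of the decode-and-display path the new interface exercises, shown outside Streamlit. It assumes the transformer, image_transform, and beam_search_decode objects from the diff above; the input file name is a placeholder.

import PIL.Image

image = PIL.Image.open("equation.png").convert("RGB")  # placeholder for the uploaded file
texs = beam_search_decode(transformer, image, image_transform=image_transform)
# beam_search_decode returns ranked hypotheses; keep the best one and drop
# \boldmath, which st.latex does not render (per the comment in the diff)
tex = texs[0].replace("\\boldmath", "")
print(tex)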
data_preprocessing.py CHANGED
@@ -74,14 +74,16 @@ class RandomizeImageTransform(object):
 
     def __init__(self, width, height, random_magnitude):
         self.transform = T.Compose((
-            T.ColorJitter(brightness=random_magnitude / 10, contrast=random_magnitude / 10,
-                          saturation=random_magnitude / 10, hue=min(0.5, random_magnitude / 10)),
+            lambda x: x if random_magnitude == 0 else T.ColorJitter(brightness=random_magnitude / 10,
+                                                                    contrast=random_magnitude / 10,
+                                                                    saturation=random_magnitude / 10,
+                                                                    hue=min(0.5, random_magnitude / 10))(x),
             T.Resize(height, max_size=width),
             T.Grayscale(),
             T.functional.invert,
             T.CenterCrop((height, width)),
             torch.Tensor.contiguous,
-            T.RandAugment(magnitude=random_magnitude),
+            lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude)(x),
             T.ConvertImageDtype(torch.float32)
         ))
 
@@ -133,7 +135,8 @@ class LatexImageDataModule(pl.LightningDataModule):
         super().__init__()
 
         dataset = TexImageDataset(root_dir=DATA_DIR,
-                                  image_transform=RandomizeImageTransform(image_width, image_height, random_magnitude),
+                                  image_transform=RandomizeImageTransform(image_width, image_height,
+                                                                          random_magnitude),
                                   tex_transform=ExtractEquationFromTexTransform())
         self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
             dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
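
The preprocessing change makes the augmentation steps conditional: when random_magnitude is 0 they collapse to identity lambdas, so validation and inference images pass through deterministically. A minimal, self-contained sketch of that pattern follows; the image shape, sizes, and magnitude value are illustrative assumptions, not the repo's settings.

import torch
import torchvision.transforms as T

def make_transform(width, height, random_magnitude):
    return T.Compose((
        # identity when magnitude is zero, otherwise apply the augmentation to x
        lambda x: x if random_magnitude == 0 else T.RandAugment(magnitude=random_magnitude)(x),
        T.Resize(height, max_size=width),
        T.CenterCrop((height, width)),
        T.ConvertImageDtype(torch.float32),
    ))

image = torch.randint(0, 256, (3, 64, 128), dtype=torch.uint8)  # stand-in for a rendered equation
deterministic = make_transform(width=1024, height=128, random_magnitude=0)
augmented = make_transform(width=1024, height=128, random_magnitude=5)
print(deterministic(image).shape, augmented(image).shape)  # both (3, 128, 1024)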