Spaces:
Runtime error
Runtime error
dkoshman
committed on
Commit
•
1b4da0d
1
Parent(s):
96feb73
improved interface
Browse files
- app.py +18 -9
- data_preprocessing.py +7 -4
app.py
CHANGED
@@ -9,7 +9,6 @@ import torchvision.transforms as T
|
|
9 |
|
10 |
MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
|
11 |
|
12 |
-
# TODO: make faster
|
13 |
transformer = torch.load(MODEL_PATH)
|
14 |
image_transform = T.Compose((
|
15 |
T.ToTensor(),
|
@@ -18,12 +17,22 @@ image_transform = T.Compose((
|
|
18 |
random_magnitude=0)
|
19 |
))
|
20 |
|
21 |
-
st.
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
image = image.convert("RGB")
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
9 |
|
10 |
MODEL_PATH = RESOURCES + "/model_2tcuvfsj.pt"
|
11 |
|
|
|
12 |
transformer = torch.load(MODEL_PATH)
|
13 |
image_transform = T.Compose((
|
14 |
T.ToTensor(),
|
|
|
17 |
random_magnitude=0)
|
18 |
))
|
19 |
|
20 |
+
st.title("Image to TeX")
|
21 |
+
|
22 |
+
st.image("resources/frontend/fraction_derivative.png", width=500)
|
23 |
+
st.image("resources/frontend/positional_encoding.png")
|
24 |
+
st.image("resources/frontend/taylor_sequence_expanded.png")
|
25 |
+
# st.image("resources/frontend/taylor_sequence.png")
|
26 |
+
# st.image("resources/frontend/maclaurin_series.png")
|
27 |
+
# st.image("resources/frontend/gauss_distribution.png")
|
28 |
+
|
29 |
+
image_file = st.file_uploader("Upload an image with equation", type=([".png", ".jpg", ".jpeg"]))
|
30 |
+
|
31 |
+
if image_file is not None:
|
32 |
+
image = PIL.Image.open(image_file)
|
33 |
image = image.convert("RGB")
|
34 |
+
texs = beam_search_decode(transformer, image, image_transform=image_transform)
|
35 |
+
# streamlit latex doesn't support boldmath
|
36 |
+
tex = texs[0].replace("\\boldmath", "")
|
37 |
+
st.latex(tex)
|
38 |
+
st.markdown(tex)
|
data_preprocessing.py
CHANGED
@@ -74,14 +74,16 @@ class RandomizeImageTransform(object):
|
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
|
76 |
self.transform = T.Compose((
|
77 |
-
T.ColorJitter(brightness=random_magnitude / 10,
|
78 |
-
|
|
|
|
|
79 |
T.Resize(height, max_size=width),
|
80 |
T.Grayscale(),
|
81 |
T.functional.invert,
|
82 |
T.CenterCrop((height, width)),
|
83 |
torch.Tensor.contiguous,
|
84 |
-
T.RandAugment(magnitude=random_magnitude),
|
85 |
T.ConvertImageDtype(torch.float32)
|
86 |
))
|
87 |
|
@@ -133,7 +135,8 @@ class LatexImageDataModule(pl.LightningDataModule):
|
|
133 |
super().__init__()
|
134 |
|
135 |
dataset = TexImageDataset(root_dir=DATA_DIR,
|
136 |
-
image_transform=RandomizeImageTransform(image_width, image_height,
|
|
|
137 |
tex_transform=ExtractEquationFromTexTransform())
|
138 |
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
139 |
dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
|
|
|
74 |
|
75 |
def __init__(self, width, height, random_magnitude):
    """Build the image preprocessing pipeline and store it in ``self.transform``.

    The pipeline resizes with ``T.Resize(height, max_size=width)``, converts
    to grayscale, inverts, center-crops to ``(height, width)``, makes the
    tensor contiguous and converts it to ``torch.float32``.

    When ``random_magnitude`` is non-zero, a ``T.ColorJitter`` step is placed
    at the start and a ``T.RandAugment`` step just before the dtype
    conversion. When it is zero, both augmentation steps are left out.

    Args:
        width: Target width, used as ``max_size`` for the resize and as the
            crop width.
        height: Target height, used as the resize size and the crop height.
        random_magnitude: Augmentation strength; ``0`` disables augmentation.
            Jitter factors are ``random_magnitude / 10`` (hue is capped at
            0.5).
    """
    # NOTE(review): the input is presumably an image tensor, since
    # torch.Tensor.contiguous is applied mid-pipeline -- confirm with callers.
    augment = random_magnitude != 0
    steps = []

    # The previous version wrapped the augmentations in
    # ``lambda x: x if random_magnitude == 0 else T.ColorJitter(...)``, which
    # returned the transform object itself instead of applying it to ``x``
    # whenever random_magnitude was non-zero. Adding the transforms to the
    # pipeline conditionally applies them properly.
    if augment:
        jitter = random_magnitude / 10
        steps.append(T.ColorJitter(brightness=jitter,
                                   contrast=jitter,
                                   saturation=jitter,
                                   hue=min(0.5, jitter)))

    steps.extend((
        T.Resize(height, max_size=width),
        T.Grayscale(),
        T.functional.invert,
        T.CenterCrop((height, width)),
        torch.Tensor.contiguous,
    ))

    if augment:
        steps.append(T.RandAugment(magnitude=random_magnitude))

    steps.append(T.ConvertImageDtype(torch.float32))
    self.transform = T.Compose(steps)
|
89 |
|
|
|
135 |
super().__init__()
|
136 |
|
137 |
dataset = TexImageDataset(root_dir=DATA_DIR,
|
138 |
+
image_transform=RandomizeImageTransform(image_width, image_height,
|
139 |
+
random_magnitude),
|
140 |
tex_transform=ExtractEquationFromTexTransform())
|
141 |
self.train_dataset, self.val_dataset, self.test_dataset = torch.utils.data.random_split(
|
142 |
dataset, [len(dataset) * 18 // 20, len(dataset) // 20, len(dataset) // 20])
|