yschneider committed on
Commit
99f3ba3
1 Parent(s): 4804e20

Upload app

Browse files
Files changed (4) hide show
  1. README.md +2 -0
  2. app.py +120 -0
  3. examples/default.jpg +0 -0
  4. requirements.txt +2 -0
README.md CHANGED
@@ -8,6 +8,8 @@ sdk_version: 4.13.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
+ models:
12
+ - Teklia/pylaia-rimes
13
  ---
14
 
15
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from uuid import uuid4
2
+ import gradio as gr
3
+ from laia.scripts.htr.decode_ctc import run as decode
4
+ from laia.common.arguments import CommonArgs, DataArgs, TrainerArgs, DecodeArgs
5
+ import sys
6
+ from tempfile import NamedTemporaryFile, mkdtemp
7
+ from pathlib import Path
8
+ from contextlib import redirect_stdout
9
+ import re
10
+ from huggingface_hub import snapshot_download
11
+
12
# Scratch directory where uploaded line images are written before decoding.
images = Path(mkdtemp())

# A uuid4 string (with dashes) is exactly 36 characters of [-a-z0-9].
IMAGE_ID_PATTERN = r"(?P<image_id>[-a-z0-9]{36})"
CONFIDENCE_PATTERN = r"(?P<confidence>[0-9.]+)"  # For line
TEXT_PATTERN = r"\s*(?P<text>.*)\s*"
# Parses one PyLaia prediction line of the form "<image_id> <confidence> <text>".
LINE_PREDICTION = re.compile(rf"{IMAGE_ID_PATTERN} {CONFIDENCE_PATTERN} {TEXT_PATTERN}")
# Hugging Face model repositories selectable in the UI dropdown.
models_name = ["Teklia/pylaia-rimes"]
19
+
20
# Height (in pixels) every input line image is normalised to before decoding.
DEFAULT_HEIGHT = 128


def get_width(image, height=DEFAULT_HEIGHT):
    """Return the width that preserves *image*'s aspect ratio at *height* pixels."""
    return height * (image.width / image.height)
26
+
27
+
28
def predict(model_name, input_img):
    """Decode the transcription of a text-line image with a PyLaia model.

    Args:
        model_name: Hugging Face repository id of a PyLaia model
            (e.g. ``"Teklia/pylaia-rimes"``).
        input_img: PIL image of a single text line.

    Returns:
        Tuple of (resized input image, ``{"text": ..., "score": ...}``)
        where ``score`` is the line confidence as captured from the
        decoder's output (a string).

    Raises:
        RuntimeError: if the decoder did not emit exactly one prediction line.
    """
    model_dir = Path(snapshot_download(model_name))

    temperature = 2.0
    batch_size = 1

    weights_path = model_dir / "weights.ckpt"
    syms_path = model_dir / "syms.txt"
    language_model_params = {"language_model_weight": 1.0}
    # The presence of tokens.txt marks a model packaged with an n-gram LM.
    use_language_model = (model_dir / "tokens.txt").exists()
    if use_language_model:
        language_model_params.update(
            {
                "language_model_path": str(model_dir / "language_model.arpa.gz"),
                "lexicon_path": str(model_dir / "lexicon.txt"),
                "tokens_path": str(model_dir / "tokens.txt"),
            }
        )

    common_args = CommonArgs(
        checkpoint=str(weights_path.relative_to(model_dir)),
        train_path=str(model_dir),
        experiment_dirname="",
    )
    data_args = DataArgs(batch_size=batch_size, color_mode="L")
    trainer_args = TrainerArgs(
        # Disable progress bar else it messes with frontend display
        progress_bar_refresh_rate=0
    )
    decode_args = DecodeArgs(
        include_img_ids=True,
        join_string="",
        convert_spaces=True,
        print_line_confidence_scores=True,
        print_word_confidence_scores=False,
        temperature=temperature,
        use_language_model=use_language_model,
        **language_model_params,
    )

    with NamedTemporaryFile() as pred_stdout, NamedTemporaryFile() as img_list:
        image_id = uuid4()
        # Resize image to DEFAULT_HEIGHT, preserving the aspect ratio.
        input_img = input_img.resize((int(get_width(input_img)), DEFAULT_HEIGHT))
        input_img.save(str(images / f"{image_id}.jpg"))
        # Export image list
        Path(img_list.name).write_text("\n".join([str(image_id)]))

        # Capture stdout as that's where PyLaia outputs predictions.
        # Manage the capture file with `with` so it is flushed and closed
        # before we read it back (the original leaked this file handle).
        with open(pred_stdout.name, mode="w") as pred_file, redirect_stdout(pred_file):
            decode(
                syms=str(syms_path),
                img_list=img_list.name,
                img_dirs=[str(images)],
                common=common_args,
                data=data_args,
                trainer=trainer_args,
                decode=decode_args,
                num_workers=1,
            )
            # Flush stdout to avoid output buffering
            sys.stdout.flush()
        predictions = Path(pred_stdout.name).read_text().strip().splitlines()
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if len(predictions) != 1:
            raise RuntimeError(
                f"Expected exactly one prediction line, got {len(predictions)}"
            )
        _, score, text = LINE_PREDICTION.match(predictions[0]).groups()
        return input_img, {"text": text, "score": score}
94
+
95
+
96
# Gradio front-end: pick a model, upload a text-line image, and view the
# resized image alongside the decoded text and confidence score.
gradio_app = gr.Interface(
    predict,
    inputs=[
        gr.Dropdown(models_name, value=models_name[0], label="Models"),
        gr.Image(
            label="Upload an image of a line",
            sources=["upload", "clipboard"],
            type="pil",
            height=DEFAULT_HEIGHT,
            width=2000,
        ),
    ],
    outputs=[
        gr.Image(label="Processed Image"),
        gr.JSON(label="Decoded text"),
    ],
    # One example row per file shipped in the examples/ directory.
    examples=[
        ["Teklia/pylaia-rimes", str(filename)]
        for filename in Path("examples").iterdir()
    ],
    title="Decode the transcription of an image using a PyLaia model",
)

if __name__ == "__main__":
    gradio_app.launch()
examples/default.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ pylaia==1.1.0
2
+ teklia_toolbox==0.1.3