chanicpanic committed on
Commit 5bcc73a · 1 Parent(s): f93b005

Create streamlit app

Files changed (10)
  1. .gitignore +4 -0
  2. .streamlit/config.toml +5 -0
  3. README.md +12 -5
  4. app.py +36 -0
  5. report.py +115 -0
  6. requirements.txt +84 -0
  7. save_image.py +23 -0
  8. scheduler.py +191 -0
  9. search.py +230 -0
  10. vision.py +29 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ .streamlit/secrets.toml
+ env/
+ __pycache__/
+ flagged_rows/
.streamlit/config.toml ADDED
@@ -0,0 +1,5 @@
+ [browser]
+ gatherUsageStats = false
+
+ [server]
+ maxUploadSize = 5
README.md CHANGED
@@ -1,12 +1,19 @@
  ---
- title: Search
- emoji: 👀
+ title: Grascii Search
+ emoji: 🔎
  colorFrom: gray
  colorTo: green
- sdk: gradio
- sdk_version: 4.31.0
+ sdk: streamlit
+ sdk_version: 1.40.2
  app_file: app.py
- pinned: false
+ pinned: true
+ models:
+ - grascii/gregg-vision-v0.2.1
+ datasets:
+ - grascii/gregg-preanniversary-words
+ preload_from_hub:
+ - grascii/gregg-vision-v0.2.1
+ - grascii/gregg-preanniversary-words
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,36 @@
+ import streamlit as st
+
+ st.set_page_config(
+     page_title="Grascii Search",
+     menu_items={
+         "About": """
+ Web interface for [grascii](https://github.com/grascii/grascii)'s
+ search utility
+
+ Image search powered by [gregg-vision-v0.2.1](https://huggingface.co/grascii/gregg-vision-v0.2.1)
+ """
+     },
+ )
+
+ import pandas as pd  # noqa E402
+ from search import write_grascii_search, write_reverse_search  # noqa E402
+
+ pd.options.mode.copy_on_write = True
+
+ if "report_submitted" not in st.session_state:
+     st.session_state["report_submitted"] = False
+
+ if "grascii" not in st.session_state:
+     st.session_state["grascii"] = ""
+
+ if st.session_state["report_submitted"]:
+     st.toast("Thanks for the report!")
+     st.session_state["report_submitted"] = False
+
+ tab1, tab2 = st.tabs(["Grascii", "Reverse"])
+
+ with tab1:
+     write_grascii_search()
+
+ with tab2:
+     write_reverse_search()
report.py ADDED
@@ -0,0 +1,115 @@
+ import streamlit as st
+ from pathlib import Path
+ from uuid import uuid4
+ import csv
+ from datetime import datetime, timezone
+
+ from huggingface_hub import CommitScheduler
+
+
+ CSV_DATASET_DIR = Path("flagged_rows")
+ CSV_DATASET_DIR.mkdir(parents=True, exist_ok=True)
+
+ CSV_DATASET_PATH = CSV_DATASET_DIR / f"train-{uuid4()}.csv"
+
+ wrote_header = False
+
+
+ def write_header(writer):
+     writer.writerow(
+         [
+             "date",
+             "grascii",
+             "longhand",
+             "incorrect_grascii",
+             "incorrect_longhand",
+             "incorrect_shorthand",
+             "improperly_cropped",
+             "extraneous_marks",
+         ]
+     )
+     global wrote_header
+     wrote_header = True
+
+
+ scheduler = CommitScheduler(
+     repo_id=st.secrets.FEEDBACK_REPO,
+     repo_type="dataset",
+     folder_path=CSV_DATASET_DIR,
+     path_in_repo="data",
+     every=15,
+     token=st.secrets.HF_TOKEN,
+ )
+
+
+ @st.dialog("Flag Results for Review", width="large")
+ def report_dialog(data):
+     st.write("Please select one or more reasons for flagging each row:")
+
+     report_df = data
+     report_df["3"] = True
+     report_df["4"] = False
+     report_df["5"] = False
+     report_df["6"] = False
+     report_df["7"] = False
+     report_df["8"] = False
+     final_report = st.data_editor(
+         report_df,
+         hide_index=True,
+         column_config={
+             "0": "Grascii",
+             "1": "Longhand",
+             "2": st.column_config.ImageColumn("Shorthand", width="medium"),
+             "3": st.column_config.CheckboxColumn("Flag"),
+             "4": st.column_config.CheckboxColumn("Grascii is incorrect"),
+             "5": st.column_config.CheckboxColumn("Longhand is incorrect"),
+             "6": st.column_config.CheckboxColumn("Shorthand image is incorrect"),
+             "7": st.column_config.CheckboxColumn(
+                 "Shorthand image is improperly cropped"
+             ),
+             "8": st.column_config.CheckboxColumn(
+                 "Shorthand image contains extraneous marks"
+             ),
+         },
+         disabled=["0", "1", "2"],
+         use_container_width=True,
+     )
+
+     st.write(
+         "If you decide that a listed row does not need to be flagged, uncheck its 'Flag' box to prevent it from being included in the submission."
+     )
+
+     if st.button("Submit"):
+         with scheduler.lock:
+             with open(CSV_DATASET_PATH, "a", newline="") as f:
+                 writer = csv.writer(f, dialect="unix")
+
+                 def write_row(row):
+                     if not wrote_header:
+                         write_header(writer)
+                     if row.iloc[3] and any(
+                         [
+                             row.iloc[4],
+                             row.iloc[5],
+                             row.iloc[6],
+                             row.iloc[7],
+                             row.iloc[8],
+                         ]
+                     ):
+                         writer.writerow(
+                             [
+                                 datetime.now(timezone.utc).date(),
+                                 row.iloc[0],
+                                 row.iloc[1],
+                                 1 if row.iloc[4] else 0,
+                                 1 if row.iloc[5] else 0,
+                                 1 if row.iloc[6] else 0,
+                                 1 if row.iloc[7] else 0,
+                                 1 if row.iloc[8] else 0,
+                             ]
+                         )
+
+                 final_report.apply(write_row, axis=1)
+
+         st.session_state["report_submitted"] = True
+         st.rerun()
requirements.txt ADDED
@@ -0,0 +1,84 @@
+ accelerate==1.2.0
+ aiohappyeyeballs==2.4.4
+ aiohttp==3.11.10
+ aiosignal==1.3.1
+ altair==5.5.0
+ async-timeout==5.0.1
+ attrs==24.2.0
+ blinker==1.9.0
+ cachetools==5.5.0
+ certifi==2024.8.30
+ charset-normalizer==3.4.0
+ click==8.1.7
+ datasets==3.1.0
+ dill==0.3.8
+ filelock==3.16.1
+ frozenlist==1.5.0
+ fsspec==2024.9.0
+ gitdb==4.0.11
+ GitPython==3.1.43
+ grascii==0.6.0
+ huggingface-hub==0.26.5
+ idna==3.10
+ Jinja2==3.1.4
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ lark==1.2.2
+ markdown-it-py==3.0.0
+ MarkupSafe==3.0.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.1.0
+ multiprocess==0.70.16
+ narwhals==1.16.0
+ networkx==3.4.2
+ numpy==2.2.0
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.0.0
+ platformdirs==4.3.6
+ propcache==0.2.1
+ protobuf==5.29.1
+ psutil==6.1.0
+ pyarrow==18.1.0
+ pydeck==0.9.1
+ Pygments==2.18.0
+ python-dateutil==2.9.0.post0
+ pytz==2024.2
+ PyYAML==6.0.2
+ referencing==0.35.1
+ regex==2024.11.6
+ requests==2.32.3
+ rich==13.9.4
+ rpds-py==0.22.3
+ safetensors==0.4.5
+ six==1.17.0
+ smmap==5.0.1
+ streamlit==1.40.2
+ sympy==1.13.1
+ tenacity==9.0.0
+ tokenizers==0.21.0
+ toml==0.10.2
+ torch==2.5.1
+ tornado==6.4.2
+ tqdm==4.67.1
+ transformers==4.47.0
+ triton==3.1.0
+ typing_extensions==4.12.2
+ tzdata==2024.2
+ urllib3==2.2.3
+ watchdog==6.0.0
+ xxhash==3.5.0
+ yarl==1.18.3
save_image.py ADDED
@@ -0,0 +1,23 @@
+ from hashlib import sha256
+ import streamlit as st
+ from datetime import datetime, timezone
+ from scheduler import ParquetScheduler
+
+
+ scheduler = ParquetScheduler(
+     repo_id=st.secrets.IMAGES_REPO,
+     token=st.secrets.HF_TOKEN,
+     every=15,
+ )
+
+
+ @st.cache_data(ttl=3600)
+ def save_image(data, prediction):
+     scheduler.append(
+         {
+             "date": datetime.now(timezone.utc).date(),
+             "image": data,
+             "prediction": prediction,
+             "sha256": sha256(data).hexdigest(),
+         }
+     )
scheduler.py ADDED
@@ -0,0 +1,191 @@
+ """
+ MIT License
+
+ Copyright (c) 2023 hysts
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
+ """
+
+ import json
+ import tempfile
+ import uuid
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Union
+
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ from huggingface_hub import CommitScheduler, HfApi
+
+
+ class ParquetScheduler(CommitScheduler):
+     """
+     Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
+     call will result in 1 row in your final dataset.
+
+     ```py
+     # Start scheduler
+     >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")
+
+     # Append some data to be uploaded
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     >>> scheduler.append({...})
+     ```
+
+     The scheduler will automatically infer the schema from the data it pushes.
+     Optionally, you can manually set the schema yourself:
+
+     ```py
+     >>> scheduler = ParquetScheduler(
+     ...     repo_id="my-parquet-dataset",
+     ...     schema={
+     ...         "prompt": {"_type": "Value", "dtype": "string"},
+     ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
+     ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
+     ...         "image": {"_type": "Image"},
+     ...     },
+     ... )
+
+     See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
+     possible values.
+     """
+
+     def __init__(
+         self,
+         *,
+         repo_id: str,
+         schema: Optional[Dict[str, Dict[str, str]]] = None,
+         every: Union[int, float] = 5,
+         revision: Optional[str] = None,
+         private: bool = False,
+         token: Optional[str] = None,
+         allow_patterns: Union[List[str], str, None] = None,
+         ignore_patterns: Union[List[str], str, None] = None,
+         hf_api: Optional[HfApi] = None,
+     ) -> None:
+         super().__init__(
+             repo_id=repo_id,
+             folder_path=tempfile.tempdir,  # not used by the scheduler
+             every=every,
+             repo_type="dataset",
+             revision=revision,
+             private=private,
+             token=token,
+             allow_patterns=allow_patterns,
+             ignore_patterns=ignore_patterns,
+             hf_api=hf_api,
+         )
+
+         self._rows: List[Dict[str, Any]] = []
+         self._schema = schema
+
+     def append(self, row: Dict[str, Any]) -> None:
+         """Add a new item to be uploaded."""
+         with self.lock:
+             self._rows.append(row)
+
+     def push_to_hub(self):
+         # Check for new rows to push
+         with self.lock:
+             rows = self._rows
+             self._rows = []
+         if not rows:
+             return
+         print(f"Got {len(rows)} item(s) to commit.")
+
+         # Load images + create 'features' config for datasets library
+         schema: Dict[str, Dict] = self._schema or {}
+         path_to_cleanup: List[Path] = []
+         for row in rows:
+             for key, value in row.items():
+                 # Infer schema (for `datasets` library)
+                 if key not in schema:
+                     schema[key] = _infer_schema(key, value)
+
+                 # Load binary files if necessary
+                 if schema[key]["_type"] in ("Image", "Audio"):
+                     if isinstance(value, bytes):
+                         row[key] = {
+                             "path": "",
+                             "bytes": value,
+                         }
+                     else:
+                         # It's an image or audio: we load the bytes and remember to cleanup the file
+                         file_path = Path(value)
+                         if file_path.is_file():
+                             row[key] = {
+                                 "path": file_path.name,
+                                 "bytes": file_path.read_bytes(),
+                             }
+                             path_to_cleanup.append(file_path)
+
+         # Complete rows if needed
+         for row in rows:
+             for feature in schema:
+                 if feature not in row:
+                     row[feature] = None
+
+         # Export items to Arrow format
+         table = pa.Table.from_pylist(rows)
+
+         # Add metadata (used by datasets library)
+         table = table.replace_schema_metadata(
+             {"huggingface": json.dumps({"info": {"features": schema}})}
+         )
+
+         # Write to parquet file
+         archive_file = tempfile.NamedTemporaryFile()
+         pq.write_table(table, archive_file.name)
+
+         # Upload
+         self.api.upload_file(
+             repo_id=self.repo_id,
+             repo_type=self.repo_type,
+             revision=self.revision,
+             path_in_repo=f"{uuid.uuid4()}.parquet",
+             path_or_fileobj=archive_file.name,
+         )
+         print("Commit completed.")
+
+         # Cleanup
+         archive_file.close()
+         for path in path_to_cleanup:
+             path.unlink(missing_ok=True)
+
+
+ def _infer_schema(key: str, value: Any) -> Dict[str, str]:
+     """Infer schema for the `datasets` library.
+
+     See
+     https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
+     """
+     if "image" in key:
+         return {"_type": "Image"}
+     if "audio" in key:
+         return {"_type": "Audio"}
+     if isinstance(value, int):
+         return {"_type": "Value", "dtype": "int64"}
+     if isinstance(value, float):
+         return {"_type": "Value", "dtype": "float64"}
+     if isinstance(value, bool):
+         return {"_type": "Value", "dtype": "bool"}
+     if isinstance(value, bytes):
+         return {"_type": "Value", "dtype": "binary"}
+     # Otherwise in last resort => convert it to a string
+     return {"_type": "Value", "dtype": "string"}
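Note: the `ParquetScheduler` docstring above says the schema is inferred automatically from appended rows. As a point of reference only (not part of this commit, and using placeholder values), this is roughly what that inference yields for a row shaped like the one `save_image.py` appends:

```py
# Illustration only: field names follow save_image.py; values are placeholders.
from scheduler import _infer_schema

row = {
    "date": "2024-12-09",    # plain string -> {"_type": "Value", "dtype": "string"}
    "image": b"\x89PNG...",  # key contains "image" -> {"_type": "Image"}
    "prediction": "TN",      # plain string -> {"_type": "Value", "dtype": "string"}
    "sha256": "ab12cd",      # plain string -> {"_type": "Value", "dtype": "string"}
}
schema = {key: _infer_schema(key, value) for key, value in row.items()}
# push_to_hub() then wraps raw bytes for Image/Audio features as
# {"path": "", "bytes": value} and uploads each batch as a uniquely named .parquet file.
```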
search.py ADDED
@@ -0,0 +1,230 @@
+ import base64
+ from io import BytesIO
+ import numpy as np
+ import streamlit as st
+ from PIL import Image
+ import pandas as pd
+ from datasets import load_dataset
+ from grascii import GrasciiSearcher, InvalidGrascii, ReverseSearcher
+ from report import report_dialog
+ from vision import run_vision
+ from save_image import save_image
+
+
+ @st.cache_data(show_spinner="Loading shorthand images")
+ def load_images():
+     ds = load_dataset(
+         "grascii/gregg-preanniversary-words", split="train", token=st.secrets.HF_TOKEN
+     )
+     image_map = {}
+     for row in ds:
+         buffered = BytesIO()
+         row["image"].save(buffered, format="PNG")
+         b64 = base64.b64encode(buffered.getvalue())
+         image_map[row["longhand"]] = "data:image/png;base64," + b64.decode("utf-8")
+     return image_map
+
+
+ image_map = load_images()
+
+
+ def set_grascii():
+     if "grascii_text_box" in st.session_state:
+         st.session_state["grascii"] = st.session_state["grascii_text_box"]
+
+
+ def write_grascii_search():
+     searcher = GrasciiSearcher()
+     grascii_results = []
+
+     search_by = st.radio("Search by", ["text", "image (beta)"], horizontal=True)
+
+     with st.form("Grascii Search"):
+         placeholder = st.empty()
+         if search_by == "text":
+             placeholder.text_input(
+                 "Grascii", value=st.session_state["grascii"], key="grascii_text_box"
+             )
+         else:
+             image_data = placeholder.file_uploader(
+                 "Image",
+                 type=["png", "jpg"],
+                 help="""
+ Upload an image of a shorthand form.
+
+ At this time, minimal preprocessing is performed on images
+ before running them through the model. For best results,
+ upload an image:
+
+ - of a closely cropped, single shorthand form
+ - with the shorthand written in black on a white background
+ - that does not contain marks beside the shorthand form
+ """,
+             )
+             save = st.checkbox(
+                 "Save images I upload for potential inclusion in open-source datasets used to train and improve models"
+             )
+
+             if image_data:
+                 image = Image.open(image_data).convert("RGBA")
+                 background = Image.new("RGBA", image.size, (255, 255, 255))
+                 alpha_composite = Image.alpha_composite(background, image)
+
+                 arr = np.array([alpha_composite.convert("L")])
+                 tokens = run_vision(arr)
+                 st.session_state["grascii"] = "".join(tokens)
+                 if save:
+                     save_image(image_data.getvalue(), "-".join(tokens))
+
+         with st.expander("Options"):
+             interpretation = st.radio(
+                 "Interpretation",
+                 ["best", "all"],
+                 horizontal=True,
+                 help="""
+ How to interpret ambiguous Grascii strings.
+
+ - best: Only search using the best interpretation
+ - all: Search using all possible interpretations.
+ """,
+             )
+             uncertainty = st.slider(
+                 "Uncertainty",
+                 min_value=0,
+                 max_value=2,
+                 help="The uncertainty of the strokes in the Grascii string",
+             )
+             fix_first = st.checkbox(
+                 "Fix First", help="Apply an uncertainty of 0 to the first token"
+             )
+             search_mode = st.selectbox(
+                 "Search Mode",
+                 ["match", "start", "contain"],
+                 help="""
+ The type of search to perform.
+
+ - match: Search for entries that closely match the Grascii string
+ - start: Search for entries that start with the Grascii string
+ - contain: Search for entries that contain the Grascii string
+ """,
+             )
+             annotation_mode = st.selectbox(
+                 "Annotation Mode",
+                 ["strict", "retain", "discard"],
+                 index=2,
+                 help="""
+ How to handle Grascii annotations.
+
+ - discard: Annotations are discarded.
+ Search results may contain annotations in any location.
+ - retain: Annotations in the input must appear in search results.
+ Other annotations may appear in the results.
+ - strict: Annotations in the input must appear in search results.
+ Other annotations may not appear in the results.
+ """,
+             )
+             aspirate_mode = st.selectbox(
+                 "Aspirate Mode",
+                 ["strict", "retain", "discard"],
+                 index=2,
+                 help="""
+ How to handle Grascii aspirates (').
+
+ - discard: Aspirates are discarded.
+ Search results may contain aspirates in any location.
+ - retain: Aspirates in the input must appear in search results.
+ Other aspirates may appear in the results.
+ - strict: Aspirates in the input must appear in search results.
+ Other aspirates may not appear in the results.
+ """,
+             )
+             disjoiner_mode = st.selectbox(
+                 "Disjoiner Mode",
+                 ["strict", "retain", "discard"],
+                 index=0,
+                 help="""
+ How to handle Grascii disjoiners (^).
+
+ - discard: Disjoiners are discarded.
+ Search results may contain disjoiners in any location.
+ - retain: Disjoiners in the input must appear in search results.
+ Other disjoiners may appear in the results.
+ - strict: Disjoiners in the input must appear in search results.
+ Other disjoiners may not appear in the results.
+ """,
+             )
+
+         st.form_submit_button("Search", on_click=set_grascii)
+
+     grascii = st.session_state["grascii"]
+
+     try:
+         grascii_results = searcher.sorted_search(
+             grascii=grascii,
+             interpretation=interpretation,
+             uncertainty=uncertainty,
+             fix_first=fix_first,
+             search_mode=search_mode,
+             annotation_mode=annotation_mode,
+             aspirate_mode=aspirate_mode,
+             disjoiner_mode=disjoiner_mode,
+         )
+     except InvalidGrascii as e:
+         if grascii:
+             st.error(f"Invalid Grascii\n```\n{e.context}\n```")
+     else:
+         write_results(grascii_results, grascii.upper(), "grascii")
+
+
+ @st.fragment
+ def write_results(results, term, key_prefix):
+     rows = map(
+         lambda r: [
+             r.entry.grascii,
+             r.entry.translation,
+             image_map.get(r.entry.translation),
+         ],
+         results,
+     )
+     data = pd.DataFrame(rows)
+
+     r = "Results" if len(data) != 1 else "Result"
+     st.write(f'{len(data)} {r} for "{term}"')
+
+     event = st.dataframe(
+         data,
+         use_container_width=True,
+         column_config={
+             "0": "Grascii",
+             "1": "Longhand",
+             "2": st.column_config.ImageColumn("Shorthand", width="medium"),
+         },
+         selection_mode="multi-row",
+         on_select="rerun",
+         key=key_prefix + "_data_frame",
+     )
+     selected_rows = event.selection.rows
+
+     if st.button(
+         "Flag Selected Rows",
+         key=key_prefix + "_report_button",
+         disabled=len(selected_rows) == 0,
+     ):
+         report_dialog(data.iloc[selected_rows])
+
+
+ def write_reverse_search():
+     searcher = ReverseSearcher()
+     reverse_results = []
+
+     with st.form("Reverse Search"):
+         word = st.text_input("Word(s)")
+
+         st.form_submit_button("Search")
+
+     if word:
+         reverse_results = searcher.sorted_search(
+             reverse=word,
+         )
+     if word:
+         write_results(reverse_results, word, "reverse")
vision.py ADDED
@@ -0,0 +1,29 @@
+ import streamlit as st
+ from transformers import (
+     PreTrainedTokenizerFast,
+     VisionEncoderDecoderModel,
+     ViTImageProcessor,
+ )
+
+ model_name = "grascii/gregg-vision-v0.2.1"
+
+
+ @st.cache_resource(show_spinner=f"Loading {model_name}")
+ def load_model():
+     model = VisionEncoderDecoderModel.from_pretrained(
+         model_name, token=st.secrets.HF_TOKEN
+     )
+     tokenizer = PreTrainedTokenizerFast.from_pretrained(
+         model_name,
+         token=st.secrets.HF_TOKEN,
+     )
+     processor = ViTImageProcessor.from_pretrained(model_name, token=st.secrets.HF_TOKEN)
+     return model, tokenizer, processor
+
+
+ @st.cache_data(ttl=3600, show_spinner=f"Running {model_name}")
+ def run_vision(image):
+     model, tokenizer, processor = load_model()
+     pixel_values = processor(image, return_tensors="pt").pixel_values
+     generated = model.generate(pixel_values, max_new_tokens=12)[0]
+     return tokenizer.convert_ids_to_tokens(generated, skip_special_tokens=True)
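For reference, a minimal sketch (not part of this commit) of how `run_vision` is driven; it mirrors the preprocessing done in search.py, with "form.png" standing in for an uploaded file:

```py
# Sketch only: mirrors the preprocessing used in search.py before run_vision is called.
import numpy as np
from PIL import Image

from vision import run_vision

image = Image.open("form.png").convert("RGBA")                # placeholder input path
background = Image.new("RGBA", image.size, (255, 255, 255))   # composite onto white
gray = Image.alpha_composite(background, image).convert("L")  # grayscale shorthand form
tokens = run_vision(np.array([gray]))                         # list of Grascii tokens
grascii = "".join(tokens)                                     # joined into a search string
```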