Commit
Β·
5bcc73a
1
Parent(s):
f93b005
Create streamlit app
Browse files- .gitignore +4 -0
- .streamlit/config.toml +5 -0
- README.md +12 -5
- app.py +36 -0
- report.py +115 -0
- requirements.txt +84 -0
- save_image.py +23 -0
- scheduler.py +191 -0
- search.py +230 -0
- vision.py +29 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.streamlit/secrets.toml
|
2 |
+
env/
|
3 |
+
__pycache__/
|
4 |
+
flagged_rows/
|
.streamlit/config.toml
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[browser]
|
2 |
+
gatherUsageStats = false
|
3 |
+
|
4 |
+
[server]
|
5 |
+
maxUploadSize = 5
|
README.md
CHANGED
@@ -1,12 +1,19 @@
|
|
1 |
---
|
2 |
-
title: Search
|
3 |
-
emoji:
|
4 |
colorFrom: gray
|
5 |
colorTo: green
|
6 |
-
sdk:
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
-
pinned:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
---
|
11 |
|
12 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Grascii Search
|
3 |
+
emoji: π
|
4 |
colorFrom: gray
|
5 |
colorTo: green
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.40.2
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
+
models:
|
11 |
+
- grascii/gregg-vision-v0.2.1
|
12 |
+
datasets:
|
13 |
+
- grascii/gregg-preanniversary-words
|
14 |
+
preload_from_hub:
|
15 |
+
- grascii/gregg-vision-v0.2.1
|
16 |
+
- grascii/gregg-preanniversary-words
|
17 |
---
|
18 |
|
19 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st

# st.set_page_config must be the first Streamlit call in the script, which
# is why it runs before the remaining imports below (hence the noqa E402).
st.set_page_config(
    page_title="Grascii Search",
    menu_items={
        "About": """
        Web interface for [grascii](https://github.com/grascii/grascii)'s
        search utility

        Image search powered by [gregg-vision-v0.2.1](https://huggingface.co/grascii/gregg-vision-v0.2.1)
        """
    },
)

import pandas as pd  # noqa E402
from search import write_grascii_search, write_reverse_search  # noqa E402

pd.options.mode.copy_on_write = True

# Session defaults:
# - "report_submitted" is set by the report dialog so the NEXT rerun can
#   show a thank-you toast (a dialog cannot toast after st.rerun()).
# - "grascii" holds the current query, shared between the text box and the
#   image-recognition path in search.py.
if "report_submitted" not in st.session_state:
    st.session_state["report_submitted"] = False

if "grascii" not in st.session_state:
    st.session_state["grascii"] = ""

if st.session_state["report_submitted"]:
    st.toast("Thanks for the report!")
    st.session_state["report_submitted"] = False

# Two independent search modes, one per tab.
tab1, tab2 = st.tabs(["Grascii", "Reverse"])

with tab1:
    write_grascii_search()

with tab2:
    write_reverse_search()
|
report.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from pathlib import Path
|
3 |
+
from uuid import uuid4
|
4 |
+
import csv
|
5 |
+
from datetime import datetime, timezone
|
6 |
+
|
7 |
+
from huggingface_hub import CommitScheduler
|
8 |
+
|
9 |
+
|
10 |
+
# Directory where flagged rows accumulate before the CommitScheduler
# uploads them to the feedback dataset repo.
CSV_DATASET_DIR = Path("flagged_rows")
CSV_DATASET_DIR.mkdir(parents=True, exist_ok=True)

# Unique per-process CSV file name so concurrent app replicas never write
# to the same file.
CSV_DATASET_PATH = CSV_DATASET_DIR / f"train-{uuid4()}.csv"

# Module-level flag: has the CSV header row been written to this file yet?
# Set by write_header() on first use.
wrote_header = False
|
16 |
+
|
17 |
+
|
18 |
+
def write_header(writer):
    """Emit the CSV column-header row and record that it has been written.

    Sets the module-level ``wrote_header`` flag so callers emit the header
    at most once per session file.
    """
    columns = (
        "date",
        "grascii",
        "longhand",
        "incorrect_grascii",
        "incorrect_longhand",
        "incorrect_shorthand",
        "improperly_cropped",
        "extraneous_marks",
    )
    writer.writerow(columns)
    global wrote_header
    wrote_header = True
|
33 |
+
|
34 |
+
|
35 |
+
# Background scheduler that commits the flagged-row CSV directory to the
# feedback dataset repo on the Hub every 15 minutes. Writers must hold
# scheduler.lock so an upload never sees a half-written file.
scheduler = CommitScheduler(
    repo_id=st.secrets.FEEDBACK_REPO,
    repo_type="dataset",
    folder_path=CSV_DATASET_DIR,
    path_in_repo="data",
    every=15,
    token=st.secrets.HF_TOKEN,
)
|
43 |
+
|
44 |
+
|
45 |
+
@st.dialog("Flag Results for Review", width="large")
def report_dialog(data):
    """Modal dialog for flagging search-result rows for human review.

    ``data`` is the DataFrame slice of selected result rows (columns
    "0" = grascii, "1" = longhand, "2" = shorthand image data URI).
    Confirmed rows are appended to the session CSV under ``scheduler.lock``
    so the background CommitScheduler can upload the file safely.
    """
    st.write("Please select one or more reasons for flagging each row:")

    report_df = data
    # Columns "3".."8" back the checkbox columns of the editor below:
    # "3" = keep-this-row flag, "4".."8" = the individual reasons.
    report_df["3"] = True
    report_df["4"] = False
    report_df["5"] = False
    report_df["6"] = False
    report_df["7"] = False
    report_df["8"] = False
    final_report = st.data_editor(
        report_df,
        hide_index=True,
        column_config={
            "0": "Grascii",
            "1": "Longhand",
            "2": st.column_config.ImageColumn("Shorthand", width="medium"),
            "3": st.column_config.CheckboxColumn("Flag"),
            "4": st.column_config.CheckboxColumn("Grascii is incorrect"),
            "5": st.column_config.CheckboxColumn("Longhand is incorrect"),
            "6": st.column_config.CheckboxColumn("Shorthand image is incorrect"),
            "7": st.column_config.CheckboxColumn(
                "Shorthand image is improperly cropped"
            ),
            "8": st.column_config.CheckboxColumn(
                "Shorthand image contains extraneous marks"
            ),
        },
        # The identifying columns are read-only; only checkboxes are editable.
        disabled=["0", "1", "2"],
        use_container_width=True,
    )

    st.write(
        "If you decide that a listed row does not need to be flagged, uncheck its 'Flag' box to prevent it from being included in the submission."
    )

    if st.button("Submit"):
        # Hold the scheduler lock for the whole write so a background commit
        # cannot upload a partially written CSV.
        with scheduler.lock:
            with open(CSV_DATASET_PATH, "a", newline="") as f:
                writer = csv.writer(f, dialect="unix")

                def write_row(row):
                    # Lazily emit the header before the first data row.
                    if not wrote_header:
                        write_header(writer)
                    # Persist only rows that are still flagged AND have at
                    # least one reason checked.
                    if row.iloc[3] and any(
                        [
                            row.iloc[4],
                            row.iloc[5],
                            row.iloc[6],
                            row.iloc[7],
                            row.iloc[8],
                        ]
                    ):
                        writer.writerow(
                            [
                                datetime.now(timezone.utc).date(),
                                row.iloc[0],
                                row.iloc[1],
                                # Booleans stored as 0/1 for CSV compactness.
                                1 if row.iloc[4] else 0,
                                1 if row.iloc[5] else 0,
                                1 if row.iloc[6] else 0,
                                1 if row.iloc[7] else 0,
                                1 if row.iloc[8] else 0,
                            ]
                        )

                final_report.apply(write_row, axis=1)

        # Toast is shown by app.py on the rerun triggered below.
        st.session_state["report_submitted"] = True
        st.rerun()
|
requirements.txt
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.2.0
|
2 |
+
aiohappyeyeballs==2.4.4
|
3 |
+
aiohttp==3.11.10
|
4 |
+
aiosignal==1.3.1
|
5 |
+
altair==5.5.0
|
6 |
+
async-timeout==5.0.1
|
7 |
+
attrs==24.2.0
|
8 |
+
blinker==1.9.0
|
9 |
+
cachetools==5.5.0
|
10 |
+
certifi==2024.8.30
|
11 |
+
charset-normalizer==3.4.0
|
12 |
+
click==8.1.7
|
13 |
+
datasets==3.1.0
|
14 |
+
dill==0.3.8
|
15 |
+
filelock==3.16.1
|
16 |
+
frozenlist==1.5.0
|
17 |
+
fsspec==2024.9.0
|
18 |
+
gitdb==4.0.11
|
19 |
+
GitPython==3.1.43
|
20 |
+
grascii==0.6.0
|
21 |
+
huggingface-hub==0.26.5
|
22 |
+
idna==3.10
|
23 |
+
Jinja2==3.1.4
|
24 |
+
jsonschema==4.23.0
|
25 |
+
jsonschema-specifications==2024.10.1
|
26 |
+
lark==1.2.2
|
27 |
+
markdown-it-py==3.0.0
|
28 |
+
MarkupSafe==3.0.2
|
29 |
+
mdurl==0.1.2
|
30 |
+
mpmath==1.3.0
|
31 |
+
multidict==6.1.0
|
32 |
+
multiprocess==0.70.16
|
33 |
+
narwhals==1.16.0
|
34 |
+
networkx==3.4.2
|
35 |
+
numpy==2.2.0
|
36 |
+
nvidia-cublas-cu12==12.4.5.8
|
37 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
38 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
39 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
40 |
+
nvidia-cudnn-cu12==9.1.0.70
|
41 |
+
nvidia-cufft-cu12==11.2.1.3
|
42 |
+
nvidia-curand-cu12==10.3.5.147
|
43 |
+
nvidia-cusolver-cu12==11.6.1.9
|
44 |
+
nvidia-cusparse-cu12==12.3.1.170
|
45 |
+
nvidia-nccl-cu12==2.21.5
|
46 |
+
nvidia-nvjitlink-cu12==12.4.127
|
47 |
+
nvidia-nvtx-cu12==12.4.127
|
48 |
+
packaging==24.2
|
49 |
+
pandas==2.2.3
|
50 |
+
pillow==11.0.0
|
51 |
+
platformdirs==4.3.6
|
52 |
+
propcache==0.2.1
|
53 |
+
protobuf==5.29.1
|
54 |
+
psutil==6.1.0
|
55 |
+
pyarrow==18.1.0
|
56 |
+
pydeck==0.9.1
|
57 |
+
Pygments==2.18.0
|
58 |
+
python-dateutil==2.9.0.post0
|
59 |
+
pytz==2024.2
|
60 |
+
PyYAML==6.0.2
|
61 |
+
referencing==0.35.1
|
62 |
+
regex==2024.11.6
|
63 |
+
requests==2.32.3
|
64 |
+
rich==13.9.4
|
65 |
+
rpds-py==0.22.3
|
66 |
+
safetensors==0.4.5
|
67 |
+
six==1.17.0
|
68 |
+
smmap==5.0.1
|
69 |
+
streamlit==1.40.2
|
70 |
+
sympy==1.13.1
|
71 |
+
tenacity==9.0.0
|
72 |
+
tokenizers==0.21.0
|
73 |
+
toml==0.10.2
|
74 |
+
torch==2.5.1
|
75 |
+
tornado==6.4.2
|
76 |
+
tqdm==4.67.1
|
77 |
+
transformers==4.47.0
|
78 |
+
triton==3.1.0
|
79 |
+
typing_extensions==4.12.2
|
80 |
+
tzdata==2024.2
|
81 |
+
urllib3==2.2.3
|
82 |
+
watchdog==6.0.0
|
83 |
+
xxhash==3.5.0
|
84 |
+
yarl==1.18.3
|
save_image.py
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from hashlib import sha256
import streamlit as st
from datetime import datetime, timezone
from scheduler import ParquetScheduler


# Background scheduler that batches uploaded images and commits them as a
# parquet file to the images dataset repo every 15 minutes.
scheduler = ParquetScheduler(
    repo_id=st.secrets.IMAGES_REPO,
    token=st.secrets.HF_TOKEN,
    every=15,
)


# NOTE(review): st.cache_data here appears to serve as de-duplication —
# within the 1-hour TTL a repeat call with identical (data, prediction)
# is served from cache, so the image is queued only once. Confirm this is
# the intent before changing the decorator.
@st.cache_data(ttl=3600)
def save_image(data, prediction):
    """Queue an uploaded shorthand image for dataset inclusion.

    ``data`` is the raw uploaded image bytes; ``prediction`` is the
    hyphen-joined Grascii token string produced by the vision model.
    """
    scheduler.append(
        {
            "date": datetime.now(timezone.utc).date(),
            "image": data,
            "prediction": prediction,
            # Content hash enables later de-duplication in the dataset.
            "sha256": sha256(data).hexdigest(),
        }
    )
|
scheduler.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
MIT License
|
3 |
+
|
4 |
+
Copyright (c) 2023 hysts
|
5 |
+
|
6 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
of this software and associated documentation files (the "Software"), to deal
|
8 |
+
in the Software without restriction, including without limitation the rights
|
9 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
copies of the Software, and to permit persons to whom the Software is
|
11 |
+
furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
The above copyright notice and this permission notice shall be included in all
|
14 |
+
copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
SOFTWARE.
|
23 |
+
"""
|
24 |
+
|
25 |
+
import json
|
26 |
+
import tempfile
|
27 |
+
import uuid
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Any, Dict, List, Optional, Union
|
30 |
+
|
31 |
+
import pyarrow as pa
|
32 |
+
import pyarrow.parquet as pq
|
33 |
+
from huggingface_hub import CommitScheduler, HfApi
|
34 |
+
|
35 |
+
|
36 |
+
class ParquetScheduler(CommitScheduler):
    """
    Usage: configure the scheduler with a repo id. Once started, you can add data to be uploaded to the Hub. 1 `.append`
    call will result in 1 row in your final dataset.

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(repo_id="my-parquet-dataset")

    # Append some data to be uploaded
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    >>> scheduler.append({...})
    ```

    The scheduler will automatically infer the schema from the data it pushes.
    Optionally, you can manually set the schema yourself:

    ```py
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-parquet-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
    ...         "image": {"_type": "Image"},
    ...     },
    ... )

    See https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value for the list of
    possible values.
    """

    def __init__(
        self,
        *,
        repo_id: str,
        schema: Optional[Dict[str, Dict[str, str]]] = None,
        every: Union[int, float] = 5,
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        super().__init__(
            repo_id=repo_id,
            folder_path=tempfile.tempdir,  # not used by the scheduler
            every=every,
            repo_type="dataset",
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            hf_api=hf_api,
        )

        # Rows buffered since the last push; access is guarded by self.lock.
        self._rows: List[Dict[str, Any]] = []
        # User-supplied feature schema, or None to infer from appended rows.
        self._schema = schema

    def append(self, row: Dict[str, Any]) -> None:
        """Add a new item to be uploaded."""
        with self.lock:
            self._rows.append(row)

    def push_to_hub(self):
        # Check for new rows to push. Swap the buffer out under the lock so
        # append() can keep running while the parquet file is built/uploaded.
        with self.lock:
            rows = self._rows
            self._rows = []
        if not rows:
            return
        print(f"Got {len(rows)} item(s) to commit.")

        # Load images + create 'features' config for datasets library
        schema: Dict[str, Dict] = self._schema or {}
        path_to_cleanup: List[Path] = []
        for row in rows:
            for key, value in row.items():
                # Infer schema (for `datasets` library)
                if key not in schema:
                    schema[key] = _infer_schema(key, value)

                # Load binary files if necessary
                if schema[key]["_type"] in ("Image", "Audio"):
                    if isinstance(value, bytes):
                        # Raw bytes were appended directly; wrap them in the
                        # {path, bytes} dict form the datasets library expects.
                        row[key] = {
                            "path": "",
                            "bytes": value,
                        }
                    else:
                        # It's an image or audio: we load the bytes and remember to cleanup the file
                        file_path = Path(value)
                        if file_path.is_file():
                            row[key] = {
                                "path": file_path.name,
                                "bytes": file_path.read_bytes(),
                            }
                            path_to_cleanup.append(file_path)

        # Complete rows if needed: every row must carry every schema column
        # (missing features become None so the table is rectangular).
        for row in rows:
            for feature in schema:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by datasets library)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": schema}})}
        )

        # Write to parquet file
        archive_file = tempfile.NamedTemporaryFile()
        pq.write_table(table, archive_file.name)

        # Upload under a unique random name so pushes never overwrite
        # each other in the dataset repo.
        self.api.upload_file(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            path_in_repo=f"{uuid.uuid4()}.parquet",
            path_or_fileobj=archive_file.name,
        )
        print("Commit completed.")

        # Cleanup: closing the NamedTemporaryFile deletes it; also remove
        # any source media files whose bytes are now stored in the parquet.
        archive_file.close()
        for path in path_to_cleanup:
            path.unlink(missing_ok=True)
|
170 |
+
|
171 |
+
|
172 |
+
def _infer_schema(key: str, value: Any) -> Dict[str, str]:
|
173 |
+
"""Infer schema for the `datasets` library.
|
174 |
+
|
175 |
+
See
|
176 |
+
https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value.
|
177 |
+
"""
|
178 |
+
if "image" in key:
|
179 |
+
return {"_type": "Image"}
|
180 |
+
if "audio" in key:
|
181 |
+
return {"_type": "Audio"}
|
182 |
+
if isinstance(value, int):
|
183 |
+
return {"_type": "Value", "dtype": "int64"}
|
184 |
+
if isinstance(value, float):
|
185 |
+
return {"_type": "Value", "dtype": "float64"}
|
186 |
+
if isinstance(value, bool):
|
187 |
+
return {"_type": "Value", "dtype": "bool"}
|
188 |
+
if isinstance(value, bytes):
|
189 |
+
return {"_type": "Value", "dtype": "binary"}
|
190 |
+
# Otherwise in last resort => convert it to a string
|
191 |
+
return {"_type": "Value", "dtype": "string"}
|
search.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
from io import BytesIO
|
3 |
+
import numpy as np
|
4 |
+
import streamlit as st
|
5 |
+
from PIL import Image
|
6 |
+
import pandas as pd
|
7 |
+
from datasets import load_dataset
|
8 |
+
from grascii import GrasciiSearcher, InvalidGrascii, ReverseSearcher
|
9 |
+
from report import report_dialog
|
10 |
+
from vision import run_vision
|
11 |
+
from save_image import save_image
|
12 |
+
|
13 |
+
|
14 |
+
@st.cache_data(show_spinner="Loading shorthand images")
def load_images():
    """Build and cache a map from longhand word to a base64 PNG data URI
    of its shorthand image, sourced from the pre-anniversary words dataset.
    """
    dataset = load_dataset(
        "grascii/gregg-preanniversary-words", split="train", token=st.secrets.HF_TOKEN
    )
    images = {}
    for record in dataset:
        # Re-encode each PIL image as PNG and embed it as a data URI so
        # Streamlit image columns can render it without a file server.
        png_buffer = BytesIO()
        record["image"].save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
        images[record["longhand"]] = f"data:image/png;base64,{encoded}"
    return images
|
26 |
+
|
27 |
+
|
28 |
+
# Module-level longhand -> data-URI map; cheap to rebuild on rerun because
# load_images() is memoized by st.cache_data.
image_map = load_images()
|
29 |
+
|
30 |
+
|
31 |
+
def set_grascii():
    """Form callback: sync the Grascii text box into the shared
    st.session_state["grascii"] key used by both search paths."""
    state = st.session_state
    if "grascii_text_box" in state:
        state["grascii"] = state["grascii_text_box"]
|
34 |
+
|
35 |
+
|
36 |
+
def write_grascii_search():
    """Render the Grascii search tab.

    The query may come from a text box or from an uploaded shorthand image
    (run through the vision model); both paths share
    st.session_state["grascii"] so the query survives the mode switch.
    Search options live in an expander inside the form; results are
    rendered via write_results().

    Fixes vs. previous revision: user-facing typos in help texts
    ("intepret" -> "interpret", "asirates" -> "aspirates").
    """
    searcher = GrasciiSearcher()
    grascii_results = []

    search_by = st.radio("Search by", ["text", "image (beta)"], horizontal=True)

    with st.form("Grascii Search"):
        placeholder = st.empty()
        if search_by == "text":
            placeholder.text_input(
                "Grascii", value=st.session_state["grascii"], key="grascii_text_box"
            )
        else:
            image_data = placeholder.file_uploader(
                "Image",
                type=["png", "jpg"],
                help="""
                Upload an image of a shorthand form.

                At this time, minimal preprocessing is performed on images
                before running them through the model. For best results,
                upload an image:

                - of a closely cropped, single shorthand form
                - with the shorthand written in black on a white background
                - that does not contain marks beside the shorthand form
                """,
            )
            save = st.checkbox(
                "Save images I upload for potential inclusion in open-source datasets used to train and improve models"
            )

            if image_data:
                # Flatten transparency onto a white background before the
                # grayscale conversion expected by the vision model.
                image = Image.open(image_data).convert("RGBA")
                background = Image.new("RGBA", image.size, (255, 255, 255))
                alpha_composite = Image.alpha_composite(background, image)

                # Model expects a batch dimension: shape (1, H, W) grayscale.
                arr = np.array([alpha_composite.convert("L")])
                tokens = run_vision(arr)
                st.session_state["grascii"] = "".join(tokens)
                if save:
                    save_image(image_data.getvalue(), "-".join(tokens))

        with st.expander("Options"):
            interpretation = st.radio(
                "Interpretation",
                ["best", "all"],
                horizontal=True,
                help="""
                How to interpret ambiguous Grascii strings.

                - best: Only search using the best interpretation
                - all: Search using all possible interpretations.
                """,
            )
            uncertainty = st.slider(
                "Uncertainty",
                min_value=0,
                max_value=2,
                help="The uncertainty of the strokes in the Grascii string",
            )
            fix_first = st.checkbox(
                "Fix First", help="Apply an uncertainty of 0 to the first token"
            )
            search_mode = st.selectbox(
                "Search Mode",
                ["match", "start", "contain"],
                help="""
                The type of search to perform.

                - match: Search for entries that closely match the Grascii string
                - start: Search for entries that start with the Grascii string
                - contain: Search for entries that contain the Grascii string
                """,
            )
            annotation_mode = st.selectbox(
                "Annotation Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
                How to handle Grascii annotations.

                - discard: Annotations are discarded.
                    Search results may contain annotations in any location.
                - retain: Annotations in the input must appear in search results.
                    Other annotations may appear in the results.
                - strict: Annotations in the input must appear in search results.
                    Other annotations may not appear in the results.
                """,
            )
            aspirate_mode = st.selectbox(
                "Aspirate Mode",
                ["strict", "retain", "discard"],
                index=2,
                help="""
                How to handle Grascii aspirates (').

                - discard: Aspirates are discarded.
                    Search results may contain aspirates in any location.
                - retain: Aspirates in the input must appear in search results.
                    Other aspirates may appear in the results.
                - strict: Aspirates in the input must appear in search results.
                    Other aspirates may not appear in the results.
                """,
            )
            disjoiner_mode = st.selectbox(
                "Disjoiner Mode",
                ["strict", "retain", "discard"],
                index=0,
                help="""
                How to handle Grascii disjoiners (^).

                - discard: Disjoiners are discarded.
                    Search results may contain disjoiners in any location.
                - retain: Disjoiners in the input must appear in search results.
                    Other disjoiners may appear in the results.
                - strict: Disjoiners in the input must appear in search results.
                    Other disjoiners may not appear in the results.
                """,
            )

        st.form_submit_button("Search", on_click=set_grascii)

    grascii = st.session_state["grascii"]

    try:
        grascii_results = searcher.sorted_search(
            grascii=grascii,
            interpretation=interpretation,
            uncertainty=uncertainty,
            fix_first=fix_first,
            search_mode=search_mode,
            annotation_mode=annotation_mode,
            aspirate_mode=aspirate_mode,
            disjoiner_mode=disjoiner_mode,
        )
    except InvalidGrascii as e:
        # Only surface the error for a non-empty query; an empty query is
        # the initial page state, not a user mistake.
        if grascii:
            st.error(f"Invalid Grascii\n```\n{e.context}\n```")
    else:
        write_results(grascii_results, grascii.upper(), "grascii")
|
177 |
+
|
178 |
+
|
179 |
+
@st.fragment
def write_results(results, term, key_prefix):
    """Render search results as a selectable table with a flag-for-review
    button.

    Declared a fragment so that row selection reruns only this section
    instead of the whole page. ``key_prefix`` keeps widget keys unique
    between the Grascii and Reverse tabs.
    """
    table = pd.DataFrame(
        [
            [
                hit.entry.grascii,
                hit.entry.translation,
                image_map.get(hit.entry.translation),
            ]
            for hit in results
        ]
    )

    # Singular/plural result count header.
    noun = "Result" if len(table) == 1 else "Results"
    st.write(f'{len(table)} {noun} for "{term}"')

    event = st.dataframe(
        table,
        use_container_width=True,
        column_config={
            "0": "Grascii",
            "1": "Longhand",
            "2": st.column_config.ImageColumn("Shorthand", width="medium"),
        },
        selection_mode="multi-row",
        on_select="rerun",
        key=key_prefix + "_data_frame",
    )
    selected_rows = event.selection.rows

    flag_clicked = st.button(
        "Flag Selected Rows",
        key=key_prefix + "_report_button",
        disabled=not selected_rows,
    )
    if flag_clicked:
        report_dialog(table.iloc[selected_rows])
|
214 |
+
|
215 |
+
|
216 |
+
def write_reverse_search():
    """Render the reverse-search tab: look up Grascii forms by longhand
    word(s) and display them via write_results().
    """
    searcher = ReverseSearcher()

    with st.form("Reverse Search"):
        word = st.text_input("Word(s)")

        st.form_submit_button("Search")

    # Fix: the previous revision guarded the search and the rendering with
    # two consecutive identical `if word:` checks; one guard covers both.
    if word:
        reverse_results = searcher.sorted_search(
            reverse=word,
        )
        write_results(reverse_results, word, "reverse")
|
vision.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from transformers import (
|
3 |
+
PreTrainedTokenizerFast,
|
4 |
+
VisionEncoderDecoderModel,
|
5 |
+
ViTImageProcessor,
|
6 |
+
)
|
7 |
+
|
8 |
+
# Hub id of the vision model that maps shorthand images to Grascii tokens.
model_name = "grascii/gregg-vision-v0.2.1"
|
9 |
+
|
10 |
+
|
11 |
+
@st.cache_resource(show_spinner=f"Loading {model_name}")
def load_model():
    """Load and cache the vision encoder-decoder model, its tokenizer, and
    the image processor.

    Cached with st.cache_resource so the model weights load once per
    server process and are shared across sessions.
    """
    auth_token = st.secrets.HF_TOKEN
    model = VisionEncoderDecoderModel.from_pretrained(model_name, token=auth_token)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name, token=auth_token)
    processor = ViTImageProcessor.from_pretrained(model_name, token=auth_token)
    return model, tokenizer, processor
|
22 |
+
|
23 |
+
|
24 |
+
@st.cache_data(ttl=3600, show_spinner=f"Running {model_name}")
def run_vision(image):
    """Run the vision model on a grayscale image batch and return the
    predicted token strings with special tokens stripped.

    Memoized for an hour keyed on the image contents, so re-uploading the
    same image does not rerun inference.
    """
    model, tokenizer, processor = load_model()
    pixels = processor(image, return_tensors="pt").pixel_values
    # Take the first (only) sequence from the generated batch.
    token_ids = model.generate(pixels, max_new_tokens=12)[0]
    return tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=True)
|