Joshua Sundance Bailey commited on
Commit
171c1a6
0 Parent(s):

initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report to help us improve
4
+ title: ''
5
+ labels: bug
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+ Steps to reproduce the behavior:
15
+ 1. Go to '...'
16
+ 2. Click on '....'
17
+ 3. Scroll down to '....'
18
+ 4. See error
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Screenshots**
24
+ If applicable, add screenshots to help explain your problem.
25
+
26
+ **Desktop (please complete the following information):**
27
+ - OS: [e.g. iOS]
28
+ - Browser [e.g. chrome, safari]
29
+ - Version [e.g. 22]
30
+
31
+ **Smartphone (please complete the following information):**
32
+ - Device: [e.g. iPhone6]
33
+ - OS: [e.g. iOS8.1]
34
+ - Browser [e.g. stock browser, safari]
35
+ - Version [e.g. 22]
36
+
37
+ **Additional context**
38
+ Add any other context about the problem here.
.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Feature request
3
+ about: Suggest an idea for this project
4
+ title: ''
5
+ labels: enhancement
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the solution you'd like**
11
+ A clear and concise description of what you want to happen.
12
+
13
+ **Describe alternatives you've considered**
14
+ A clear and concise description of any alternative solutions or features you've considered.
15
+
16
+ **Additional context**
17
+ Add any other context or screenshots about the feature request here.
.github/pull_request_template.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Thank you for contributing!
2
+ Before submitting this PR, please make sure:
3
+
4
+ - [ ] Your code builds clean without any errors or warnings
5
+ - [ ] Your code doesn't break anything we can't fix
6
+ - [ ] You have added appropriate tests
7
+
8
+ Please check one or more of the following to describe the nature of this PR:
9
+ - [ ] New feature
10
+ - [ ] Bug fix
11
+ - [ ] Documentation
12
+ - [ ] Other
.github/workflows/check-file-size-limit.yml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: 10 MB file size limit
2
+ on:
3
+ pull_request:
4
+ branches: [main]
5
+
6
+ jobs:
7
+ check-file-sizes:
8
+ runs-on: ubuntu-latest
9
+ steps:
10
+ - name: Check large files
11
+ uses: ActionsDesk/lfs-warning@v2.0
12
+ with:
13
+ filesizelimit: 10485760 # this is 10MB so we can sync to HF Spaces
14
+ token: ${{ secrets.WORKFLOW_GIT_ACCESS_TOKEN }}
.github/workflows/hf-space.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Push to HuggingFace Space
2
+
3
+ on:
4
+ push:
5
+ branches: [main]
6
+ workflow_dispatch:
7
+
8
+ jobs:
9
+ push-to-huggingface:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v2
13
+ with:
14
+ fetch-depth: 0
15
+ token: ${{ secrets.WORKFLOW_GIT_ACCESS_TOKEN }}
16
+
17
+ - name: Push to HuggingFace Space
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: |
21
+ git push https://joshuasundance:$HF_TOKEN@huggingface.co/spaces/joshuasundance/mtg-coloridentity main
.gitignore ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ hf_cache/
2
+ govgis-nov2023/
3
+ *$py.class
4
+ *.chainlit
5
+ *.chroma
6
+ *.cover
7
+ *.egg
8
+ *.egg-info/
9
+ *.env
10
+ *.langchain.db
11
+ *.log
12
+ *.manifest
13
+ *.mo
14
+ *.pot
15
+ *.py,cover
16
+ *.py[cod]
17
+ *.sage.py
18
+ *.so
19
+ *.spec
20
+ .DS_STORE
21
+ .Python
22
+ .cache
23
+ .coverage
24
+ .coverage.*
25
+ .dmypy.json
26
+ .eggs/
27
+ .env
28
+ .hypothesis/
29
+ .idea
30
+ .installed.cfg
31
+ .ipynb_checkpoints
32
+ .mypy_cache/
33
+ .nox/
34
+ .pyre/
35
+ .pytest_cache/
36
+ .python-version
37
+ .ropeproject
38
+ .ruff_cache/
39
+ .scrapy
40
+ .spyderproject
41
+ .spyproject
42
+ .tox/
43
+ .venv
44
+ .vscode
45
+ .webassets-cache
46
+ /site
47
+ ENV/
48
+ MANIFEST
49
+ __pycache__
50
+ __pycache__/
51
+ __pypackages__/
52
+ build/
53
+ celerybeat-schedule
54
+ celerybeat.pid
55
+ coverage.xml
56
+ credentials.json
57
+ data/
58
+ db.sqlite3
59
+ db.sqlite3-journal
60
+ develop-eggs/
61
+ dist/
62
+ dmypy.json
63
+ docs/_build/
64
+ downloads/
65
+ eggs/
66
+ env.bak/
67
+ env/
68
+ fly.toml
69
+ htmlcov/
70
+ instance/
71
+ ipython_config.py
72
+ junk/
73
+ lib/
74
+ lib64/
75
+ local_settings.py
76
+ models/*.bin
77
+ nosetests.xml
78
+ lab/scratch/
79
+ lab/
80
+ parts/
81
+ pip-delete-this-directory.txt
82
+ pip-log.txt
83
+ pip-wheel-metadata/
84
+ profile_default/
85
+ sdist/
86
+ share/python-wheels/
87
+ storage
88
+ target/
89
+ token.json
90
+ var/
91
+ venv
92
+ venv.bak/
93
+ venv/
94
+ wheels/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Don't know what this file is? See https://pre-commit.com/
2
+ # pip install pre-commit
3
+ # pre-commit install
4
+ # pre-commit autoupdate
5
+ # Apply to all files without commiting:
6
+ # pre-commit run --all-files
7
+ # I recommend running this until you pass all checks, and then commit.
8
+ # Fix what you need to and then let the pre-commit hooks resolve their conflicts.
9
+ # You may need to git add -u between runs.
10
+ exclude: "AI_CHANGELOG.md"
11
+ repos:
12
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
13
+ rev: "v0.1.15"
14
+ hooks:
15
+ - id: ruff
16
+ args: [--fix, --exit-non-zero-on-fix, --ignore, E501]
17
+ - repo: https://github.com/koalaman/shellcheck-precommit
18
+ rev: v0.9.0
19
+ hooks:
20
+ - id: shellcheck
21
+ - repo: https://github.com/pre-commit/pre-commit-hooks
22
+ rev: v4.5.0
23
+ hooks:
24
+ - id: check-ast
25
+ - id: check-builtin-literals
26
+ - id: check-merge-conflict
27
+ - id: check-symlinks
28
+ - id: check-toml
29
+ - id: check-xml
30
+ - id: debug-statements
31
+ - id: check-case-conflict
32
+ - id: check-docstring-first
33
+ - id: check-executables-have-shebangs
34
+ - id: check-json
35
+ # - id: check-yaml
36
+ - id: debug-statements
37
+ - id: fix-byte-order-marker
38
+ - id: detect-private-key
39
+ - id: end-of-file-fixer
40
+ - id: trailing-whitespace
41
+ - id: mixed-line-ending
42
+ - id: requirements-txt-fixer
43
+ - repo: https://github.com/pre-commit/mirrors-mypy
44
+ rev: v1.8.0
45
+ hooks:
46
+ - id: mypy
47
+ additional_dependencies:
48
+ - types-PyYAML
49
+ - repo: https://github.com/asottile/add-trailing-comma
50
+ rev: v3.1.0
51
+ hooks:
52
+ - id: add-trailing-comma
53
+ #- repo: https://github.com/dannysepler/rm_unneeded_f_str
54
+ # rev: v0.2.0
55
+ # hooks:
56
+ # - id: rm-unneeded-f-str
57
+ - repo: https://github.com/psf/black
58
+ rev: 24.1.1
59
+ hooks:
60
+ - id: black
61
+ - repo: https://github.com/PyCQA/bandit
62
+ rev: 1.7.7
63
+ hooks:
64
+ - id: bandit
65
+ args: ["-x", "tests/*.py"]
LICENSE ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Joshua Sundance Bailey
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
6
+
7
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
8
+
9
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: mtg-coloridentity
3
+ emoji: 🧙
4
+ colorFrom: white
5
+ colorTo: red
6
+ sdk: streamlit
7
+ sdk_version: 1.30.0
8
+ app_file: app.py
9
+ pinned: true
10
+ license: mit
11
+ ---
12
+
13
+ # mtg-coloridentity
14
+
15
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
16
+ [![python](https://img.shields.io/badge/Python-3-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org)
17
+
18
+ [![Push to HuggingFace Space](https://github.com/joshuasundance-swca/mtg-coloridentity/actions/workflows/hf-space.yml/badge.svg)](https://github.com/joshuasundance-swca/mtg-coloridentity/actions/workflows/hf-space.yml)
19
+ [![Open HuggingFace Space](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/joshuasundance/mtg-coloridentity)
20
+
21
+ [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit)
22
+ [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v1.json)](https://github.com/charliermarsh/ruff)
23
+ [![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
24
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
25
+
26
+ [![security: bandit](https://img.shields.io/badge/security-bandit-yellow.svg)](https://github.com/PyCQA/bandit)
27
+
28
+
29
+ # mtg-coloridentity
30
+
31
+ 🤖 This README was written by GPT-4. 🤖
32
+
33
+ ## Overview
34
+ This Streamlit app is designed for the multi-label classification of Magic: The Gathering (MTG) cards,
35
+ specifically focusing on their color identity.
36
+ It utilizes a pre-trained model hosted on Hugging Face, `joshuasundance/mtg-coloridentity-multilabel-classification`,
37
+ to predict the color identity of MTG cards based on their names and descriptions.
38
+
39
+ ## Features
40
+ - Interactive UI: Users can input the name and text of any MTG card to get predictions on its color identity.
41
+ - Color Probabilities: The app displays the probability of each color identity (Black, Green, Red, Blue, White) for the given card.
42
+ - Random Card Selection: With a "Roll the Dice" feature, users can load the text of a random MTG card from the dataset.
43
+
44
+ ## How It Works
45
+ The app fetches a pre-trained `SetFit` model from Hugging Face and uses it to
46
+ predict the color identities of MTG cards.
47
+ The model's predictions are displayed as a bar chart,
48
+ showing the probability of each color identity.
49
+
50
+ ## Getting Started
51
+ To run this app locally, clone the repository and ensure you have the following prerequisites installed:
52
+
53
+ - Python 3.x
54
+ - `streamlit`
55
+ - `pandas`
56
+ - `seaborn`
57
+ - `matplotlib`
58
+ - `datasets` and `setfit` from Hugging Face
59
+
60
+ ## Contributions, Support, and Contact
61
+
62
+ Contributions to this project are welcome! Please feel free to submit issues and pull requests.
63
+
64
+ For support, please raise an issue on GitHub or in the HuggingFace space.
65
+
66
+ ## License
67
+
68
+ This project is under the [MIT License](LICENSE.md).
69
+
70
+ ## Acknowledgments
71
+
72
+ Thanks to HuggingFace and `setfit`!
73
+
74
+ ## TODO
75
+ - [ ] make a todo list ;)
76
+ - [ ] improve READMEs
77
+ - [ ] make better model(s)
78
+ - [x] learn in public
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from typing import Sequence
4
+
5
+ import datasets
6
+ import matplotlib.pyplot as plt
7
+ import pandas as pd
8
+ import seaborn as sns
9
+ import streamlit as st
10
+ from setfit import SetFitModel
11
+
12
+ st.set_page_config(
13
+ page_title="mtg-coloridentity-multilabel-classification",
14
+ page_icon="🧙",
15
+ layout="wide",
16
+ initial_sidebar_state="collapsed",
17
+ menu_items=None,
18
+ )
19
+
20
+
21
+ default_hf_home = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
22
+ HF_HOME = os.environ.get("HF_HOME", default_hf_home)
23
+
24
+ coloridentity_model = "joshuasundance/mtg-coloridentity-multilabel-classification"
25
+
26
+ colors = ["B", "G", "R", "U", "W"]
27
+ labels = ["black", "green", "red", "blue", "white"]
28
+
29
+ sns.set()
30
+
31
+ col1, col2 = st.columns(2)
32
+
33
+
34
+ @st.cache_resource
35
+ def get_model(
36
+ model_id: str = coloridentity_model,
37
+ cache_dir: str = HF_HOME,
38
+ **kwargs,
39
+ ) -> SetFitModel:
40
+ return SetFitModel.from_pretrained(model_id, cache_dir=cache_dir, **kwargs)
41
+
42
+
43
+ @st.cache_data
44
+ def get_data(
45
+ repo_id: str = coloridentity_model,
46
+ cache_dir: str = HF_HOME,
47
+ **kwargs,
48
+ ) -> datasets.Dataset:
49
+ dataset_dict = datasets.load_dataset(repo_id, cache_dir=cache_dir, **kwargs)
50
+ return datasets.concatenate_datasets(
51
+ list(dataset_dict.values()),
52
+ )
53
+
54
+
55
+ def get_random_text() -> str:
56
+ return dataset.select([random.randint(0, len(dataset))])[0]["text"] # nosec
57
+
58
+
59
+ @st.cache_data
60
+ def get_preds(input_text: str) -> Sequence[float]:
61
+ return model.predict_proba(input_text)
62
+
63
+
64
+ def prob_bars(preds: Sequence[float]) -> None:
65
+ _preds = (float(p) for p in preds)
66
+ df = pd.DataFrame(
67
+ zip(labels, _preds),
68
+ columns=["Color", "Probability"],
69
+ )
70
+ plt.figure(figsize=(8, 6))
71
+ ax = sns.barplot(x="Color", y="Probability", data=df, palette=labels)
72
+
73
+ # Add data labels on each bar
74
+ for p in ax.patches:
75
+ ax.annotate(
76
+ format(p.get_height(), ".4f"),
77
+ (p.get_x() + p.get_width() / 2.0, p.get_height()),
78
+ ha="center",
79
+ va="center",
80
+ xytext=(0, 9),
81
+ textcoords="offset points",
82
+ )
83
+
84
+ plt.title("Prediction Probabilities")
85
+ plt.xlabel("Color")
86
+ plt.ylabel("Probability")
87
+ st.pyplot(plt.gcf())
88
+
89
+
90
+ model = get_model()
91
+ dataset = get_data()
92
+ default_text = get_random_text()
93
+
94
+ if "input_text" not in st.session_state:
95
+ st.session_state.input_text = default_text
96
+
97
+ with col1:
98
+ if st.button("🎲 Roll the Dice"):
99
+ st.session_state.input_text = get_random_text()
100
+ input_text = st.text_area(
101
+ "Card name and text",
102
+ st.session_state.input_text,
103
+ height=400,
104
+ )
105
+
106
+ preds = get_preds(input_text)
107
+
108
+ with col2:
109
+ prob_bars(preds)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets
2
+ matplotlib
3
+ pandas
4
+ seaborn
5
+ setfit
6
+ streamlit