Add first version
- .gitignore +151 -0
- README.md +4 -3
- app.py +78 -0
- artifacts/examples/basketball.jpg +0 -0
- artifacts/examples/cassowary.jpg +0 -0
- artifacts/examples/colosseum.jpg +0 -0
- artifacts/examples/desk.jpg +0 -0
- artifacts/examples/kitchen.jpg +0 -0
- artifacts/examples/log.csv +11 -0
- artifacts/examples/monkey.jpg +0 -0
- artifacts/examples/park.jpg +0 -0
- artifacts/examples/ramen.jpg +0 -0
- artifacts/examples/sagrada.jpg +0 -0
- artifacts/examples/venice.jpg +0 -0
- artifacts/models/databases/.gitkeep +0 -0
- artifacts/models/retrieval/indices.json +3 -0
- flagged/.gitkeep +0 -0
- requirements.txt +10 -0
- src/nn.py +330 -0
- src/retrieval.py +42 -0
- src/transforms.py +506 -0
.gitignore
ADDED
@@ -0,0 +1,151 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# pytype type checking
+.pytype/
+
+# Pyre type checker
+.pyre/
+
+### VisualStudioCode
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+**/.vscode
+
+# JetBrains
+.idea/
+
+# Data & Models
+*.h5
+*.tar
+*.tar.gz
+
+# Template
+/artifacts/models/databases/*/
README.md
CHANGED
@@ -1,10 +1,11 @@
 ---
-title:
+title: Vocabulary-free Image Classification
 emoji: 🌍
-colorFrom:
-colorTo:
+colorFrom: green
+colorTo: yellow
 sdk: gradio
 sdk_version: 3.33.1
+python_version: 3.9
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,78 @@
+from typing import Optional
+
+import gradio as gr
+import torch
+
+from src.nn import CaSED
+
+PAPER_TITLE = "Vocabulary-free Image Classification"
+PAPER_DESCRIPTION = """
+
+
+<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 1rem;">
+    <a href="https://github.com/altndrr/vic" style="margin-right: 0.5rem;">
+        <img src="https://img.shields.io/badge/code-github.altndrr%2Fvic-blue.svg"/>
+    </a>
+    <a href="https://arxiv.org/abs/2306.00917" style="margin-right: 0.5rem;">
+        <img src="https://img.shields.io/badge/paper-arXiv%3A2306.00917-B31B1B.svg"/>
+    </a>
+    <a href="https://altndrr.github.io/vic/" style="margin-right: 0.5rem;">
+        <img src="https://img.shields.io/badge/website-gh--pages.altndrr%2Fvic-success.svg"/>
+    </a>
+</div>
+
+
+Vocabulary-free Image Classification aims to assign a class to an image *without* prior knowledge
+of the list of class names, thus operating on the semantic class space that contains all the
+possible concepts. Our proposed method CaSED finds the best-matching category within the
+unconstrained semantic space by leveraging multimodal data from large vision-language databases.
+We first retrieve the semantically most similar captions from a database, from which we extract a
+set of candidate categories by applying text parsing and filtering techniques. We further score
+the candidates using the multimodal aligned representation of the large pre-trained VLM, *i.e.*
+CLIP, to obtain the best-matching category, using *alpha* as a hyperparameter to control the
+trade-off between the visual and textual similarity.
+"""
+PAPER_URL = "https://arxiv.org/abs/2306.00917"
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = CaSED().to(DEVICE).eval()
+
+
+def vic(filename: str, alpha: Optional[float] = None):
+    # get the outputs of the model
+    vocabulary, scores = model(filename, alpha=alpha)
+    confidences = dict(zip(vocabulary, scores))
+
+    return confidences
+
+def resize_image(image, max_size: int = 256):
+    """Resize image so that its shorter side equals `max_size`, keeping the aspect ratio."""
+    width, height = image.size
+    if width > height:
+        ratio = width / height
+        new_width = max_size * ratio
+        new_height = max_size
+    else:
+        ratio = height / width
+        new_width = max_size
+        new_height = max_size * ratio
+    return image.resize((int(new_width), int(new_height)))
+
+
+demo = gr.Interface(
+    fn=vic,
+    inputs=[
+        gr.Image(type="filepath", label="input"),
+        gr.Slider(0.0, 1.0, value=0.5, label="alpha"),
+    ],
+    outputs=[gr.Label(num_top_classes=5, label="output")],
+    title=PAPER_TITLE,
+    description=PAPER_DESCRIPTION,
+    article=f"Check out <a href={PAPER_URL}>the original paper</a> for more information.",
+    examples="./artifacts/examples/",
+    allow_flagging="never",
+    theme=gr.themes.Soft(),
+)
+
+demo.launch(share=False)
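
Note: `vic` can also be called directly, without the Gradio UI. A minimal sketch, assuming the CC12M retrieval database has already been fetched by `CaSED.prepare_data()` and using one of the bundled example images:

confidences = vic("artifacts/examples/ramen.jpg", alpha=0.7)
top5 = sorted(confidences.items(), key=lambda kv: kv[1], reverse=True)[:5]
for name, score in top5:
    print(f"{name}: {score:.3f}")

A higher alpha weighs the visual similarity more heavily; a lower one leans on the similarity to the retrieved captions.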
artifacts/examples/basketball.jpg
ADDED
artifacts/examples/cassowary.jpg
ADDED
artifacts/examples/colosseum.jpg
ADDED
artifacts/examples/desk.jpg
ADDED
artifacts/examples/kitchen.jpg
ADDED
artifacts/examples/log.csv
ADDED
@@ -0,0 +1,11 @@
+image_fp
+basketball.jpg
+cassowary.jpg
+colosseum.jpg
+desk.jpg
+kitchen.jpg
+monkey.jpg
+park.jpg
+ramen.jpg
+sagrada.jpg
+venice.jpg
artifacts/examples/monkey.jpg
ADDED
artifacts/examples/park.jpg
ADDED
artifacts/examples/ramen.jpg
ADDED
artifacts/examples/sagrada.jpg
ADDED
artifacts/examples/venice.jpg
ADDED
artifacts/models/databases/.gitkeep
ADDED
File without changes
artifacts/models/retrieval/indices.json
ADDED
@@ -0,0 +1,3 @@
+{
+    "ViT-L-14_CC12M": "./artifacts/models/databases/cc12m/vit-l-14/"
+}
flagged/.gitkeep
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,10 @@
+torch==2.0.1
+torchvision==0.15.2
+faiss-cpu==1.7.4
+flair==0.12.2
+gradio==3.33.1
+gdown==4.4.0
+inflect==6.0.4
+nltk==3.8.1
+open_clip_torch==2.20.0
+transformers==4.26.1
src/nn.py
ADDED
@@ -0,0 +1,330 @@
+import json
+import tarfile
+from pathlib import Path
+from typing import Optional
+
+import faiss
+import gdown
+import numpy as np
+import open_clip
+import torch
+from open_clip.transformer import Transformer
+from PIL import Image
+
+from src.retrieval import ArrowMetadataProvider, meta_to_dict
+from src.transforms import TextCompose, default_vocabulary_transforms
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+RETRIEVAL_DATABASES = {
+    "cc12m": "https://drive.google.com/uc?id=1HyM4mnKSxF0sqzAe-KZL8y-cQWRPiuXn&confirm=t",
+}
+
+
+class CaSED(torch.nn.Module):
+    """Torch module for Category Search from External Databases (CaSED).
+
+    Args:
+        index_name (str): Name of the faiss index to use.
+        vocabulary_transforms (TextCompose): List of transforms to apply to the vocabulary.
+        model_name (str): Name of the CLIP model to use. Defaults to "ViT-L-14".
+        pretrained (str): Pretrained weights to use for the CLIP model. Defaults to "openai".
+
+    Extra hparams:
+        alpha (float): Weight for the average of the image and text predictions. Defaults to 0.5.
+        artifact_dir (str): Path to the directory where the databases are stored. Defaults to
+            "artifacts/".
+        retrieval_num_results (int): Number of results to return. Defaults to 10.
+        vocabulary_prompt (str): Prompt to use for the vocabulary. Defaults to "{}".
+        tau (float): Temperature to use for the classifier. Defaults to 1.0.
+    """
+
+    def __init__(
+        self,
+        index_name: str = "ViT-L-14_CC12M",
+        vocabulary_transforms: TextCompose = default_vocabulary_transforms(),
+        model_name: str = "ViT-L-14",
+        pretrained: str = "openai",
+        vocabulary_prompt: str = "{}",
+        **kwargs,
+    ):
+        super().__init__()
+        self._prev_vocab_words = None
+        self._prev_used_prompts = None
+        self._prev_vocab_words_z = None
+
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            model_name, pretrained=pretrained, device="cpu"
+        )
+        tokenizer = open_clip.get_tokenizer(model_name)
+        self.tokenizer = tokenizer
+        self.preprocess = preprocess
+
+        kwargs["alpha"] = kwargs.get("alpha", 0.5)
+        kwargs["artifact_dir"] = kwargs.get("artifact_dir", "artifacts/")
+        kwargs["retrieval_num_results"] = kwargs.get("retrieval_num_results", 10)
+        vocabulary_prompt = kwargs.get("vocabulary_prompt", "{}")
+        kwargs["vocabulary_prompts"] = [vocabulary_prompt]
+        kwargs["tau"] = kwargs.get("tau", 1.0)
+        self.hparams = kwargs
+
+        language_encoder = LanguageTransformer(
+            model.transformer,
+            model.token_embedding,
+            model.positional_embedding,
+            model.ln_final,
+            model.text_projection,
+            model.attn_mask,
+        )
+        scale = model.logit_scale.exp().item()
+        classifier = NearestNeighboursClassifier(scale=scale, tau=self.hparams["tau"])
+
+        self.index_name = index_name
+        self.vocabulary_transforms = vocabulary_transforms
+        self.vision_encoder = model.visual
+        self.language_encoder = language_encoder
+        self.classifier = classifier
+
+        # download databases
+        self.prepare_data()
+
+        # load faiss indices
+        indices_list_dir = Path(self.hparams["artifact_dir"]) / "models" / "retrieval"
+        indices_fp = indices_list_dir / "indices.json"
+        self.indices = json.load(open(indices_fp, "r"))
+
+        # load faiss indices and metadata providers
+        self.resources = {}
+        for name, index_fp in self.indices.items():
+            text_index_fp = Path(index_fp) / "text.index"
+            metadata_fp = Path(index_fp) / "metadata/"
+
+            text_index = faiss.read_index(
+                str(text_index_fp), faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY
+            )
+            metadata_provider = ArrowMetadataProvider(metadata_fp)
+
+            self.resources[name] = {
+                "device": DEVICE,
+                "model": model_name,
+                "text_index": text_index,
+                "metadata_provider": metadata_provider,
+            }
+
+    def prepare_data(self):
+        """Download data if needed."""
+        databases_path = Path(self.hparams["artifact_dir"]) / "models" / "databases"
+
+        for name, url in RETRIEVAL_DATABASES.items():
+            database_path = Path(databases_path, name)
+            if database_path.exists():
+                continue
+
+            # download data
+            target_path = Path(databases_path, name + ".tar.gz")
+            try:
+                gdown.download(url, str(target_path), quiet=False)
+                tar = tarfile.open(target_path, "r:gz")
+                tar.extractall(target_path.parent)
+                tar.close()
+                target_path.unlink()
+            except FileNotFoundError:
+                print(f"Could not download {url}.")
+                print(f"Please download it manually and place it in {target_path.parent}.")
+
+    @torch.no_grad()
+    def query_index(self, sample_z: torch.Tensor) -> list[str]:
+        # get the index
+        resources = self.resources[self.index_name]
+        text_index = resources["text_index"]
+        metadata_provider = resources["metadata_provider"]
+
+        # query the index
+        sample_z = sample_z.squeeze(0)
+        sample_z = sample_z / sample_z.norm(dim=-1, keepdim=True)
+        query_input = sample_z.cpu().detach().numpy().tolist()
+        query = np.expand_dims(np.array(query_input).astype("float32"), 0)
+
+        distances, idxs, _ = text_index.search_and_reconstruct(
+            query, self.hparams["retrieval_num_results"]
+        )
+        results = idxs[0]
+        nb_results = np.where(results == -1)[0]
+        nb_results = nb_results[0] if len(nb_results) > 0 else len(results)
+        indices = results[:nb_results]
+        distances = distances[0][:nb_results]
+
+        if len(distances) == 0:
+            return []
+
+        # get the metadata
+        results = []
+        metadata = metadata_provider.get(indices[:20], ["caption"])
+        for key, (d, i) in enumerate(zip(distances, indices)):
+            output = {}
+            meta = None if key + 1 > len(metadata) else metadata[key]
+            if meta is not None:
+                output.update(meta_to_dict(meta))
+            output["id"] = i.item()
+            output["similarity"] = d.item()
+            results.append(output)
+
+        # get the captions only
+        vocabularies = [result["caption"] for result in results]
+
+        return vocabularies
+
+    @torch.no_grad()
+    def encode_vocabulary(self, vocabulary: list, use_prompts: bool = False) -> torch.Tensor:
+        """Encode a vocabulary.
+
+        Args:
+            vocabulary (list): List of words.
+        """
+        # check if vocabulary has changed
+        if vocabulary == self._prev_vocab_words and use_prompts == self._prev_used_prompts:
+            return self._prev_vocab_words_z
+
+        # tokenize vocabulary
+        classes = [c.replace("_", " ") for c in vocabulary]
+        prompts = self.hparams["vocabulary_prompts"] if use_prompts else ["{}"]
+        texts_views = [[p.format(c) for c in classes] for p in prompts]
+        tokenized_texts_views = [
+            torch.cat([self.tokenizer(prompt) for prompt in class_prompts])
+            for class_prompts in texts_views
+        ]
+        tokenized_texts_views = torch.stack(tokenized_texts_views).to(DEVICE)
+
+        # encode vocabulary
+        T, C, _ = tokenized_texts_views.shape
+        texts_z_views = self.language_encoder(tokenized_texts_views.view(T * C, -1))
+        texts_z_views = texts_z_views.view(T, C, -1)
+        texts_z_views = texts_z_views / texts_z_views.norm(dim=-1, keepdim=True)
+
+        # cache vocabulary
+        self._prev_vocab_words = vocabulary
+        self._prev_used_prompts = use_prompts
+        self._prev_vocab_words_z = texts_z_views
+
+        return texts_z_views
+
+    @torch.no_grad()
+    def forward(self, image_fp: str, alpha: Optional[float] = None) -> tuple[list[str], list[float]]:
+        image = self.preprocess(Image.open(image_fp)).unsqueeze(0)
+        image_z = self.vision_encoder(image.to(DEVICE))
+
+        # get the vocabulary
+        vocabulary = self.query_index(image_z)
+
+        # generate a single text embedding from the unfiltered vocabulary
+        unfiltered_vocabulary_z = self.encode_vocabulary(vocabulary).squeeze(0)
+        text_z = unfiltered_vocabulary_z.mean(dim=0)
+        text_z = text_z / text_z.norm(dim=-1, keepdim=True)
+        text_z = text_z.unsqueeze(0)
+
+        # filter the vocabulary, embed it, and get its mean embedding
+        vocabulary = self.vocabulary_transforms(vocabulary) or ["object"]
+        vocabulary_z = self.encode_vocabulary(vocabulary, use_prompts=True)
+        mean_vocabulary_z = vocabulary_z.mean(dim=0)
+        mean_vocabulary_z = mean_vocabulary_z / mean_vocabulary_z.norm(dim=-1, keepdim=True)
+
+        # get the image and text predictions
+        image_p = self.classifier(image_z, vocabulary_z)
+        text_p = self.classifier(text_z, vocabulary_z)
+
+        # average the image and text predictions
+        alpha = alpha or self.hparams["alpha"]
+        sample_p = alpha * image_p + (1 - alpha) * text_p
+
+        # get the scores
+        sample_p = sample_p.cpu()
+        scores = sample_p[0].tolist()
+
+        del image_z, unfiltered_vocabulary_z, text_z, vocabulary_z, mean_vocabulary_z
+        del image_p, text_p, sample_p
+
+        return vocabulary, scores
+
+
+class NearestNeighboursClassifier(torch.nn.Module):
+    """Nearest neighbours classifier.
+
+    It computes the similarity between the query and the supports using the
+    cosine similarity and then applies a softmax to obtain the logits.
+
+    Args:
+        scale (float): Scale for the logits of the query. Defaults to 1.0.
+        tau (float): Temperature for the softmax. Defaults to 1.0.
+    """
+
+    def __init__(self, scale: float = 1.0, tau: float = 1.0):
+        super().__init__()
+        self.scale = scale
+        self.tau = tau
+
+    def forward(self, query: torch.Tensor, supports: torch.Tensor):
+        query = query / query.norm(dim=-1, keepdim=True)
+        supports = supports / supports.norm(dim=-1, keepdim=True)
+
+        if supports.dim() == 2:
+            supports = supports.unsqueeze(0)
+
+        Q, _ = query.shape
+        N, C, _ = supports.shape
+
+        supports = supports.mean(dim=0)
+        supports = supports / supports.norm(dim=-1, keepdim=True)
+        similarity = self.scale * query @ supports.T
+        similarity = similarity / self.tau if self.tau != 1.0 else similarity
+        logits = similarity.softmax(dim=-1)
+
+        return logits
+
+
+class LanguageTransformer(torch.nn.Module):
+    """Language Transformer for CLIP.
+
+    Args:
+        transformer (Transformer): Transformer model.
+        token_embedding (torch.nn.Embedding): Token embedding.
+        positional_embedding (torch.nn.Parameter): Positional embedding.
+        ln_final (torch.nn.LayerNorm): Layer norm.
+        text_projection (torch.nn.Parameter): Text projection.
+    """
+
+    def __init__(
+        self,
+        model: Transformer,
+        token_embedding: torch.nn.Embedding,
+        positional_embedding: torch.nn.Parameter,
+        ln_final: torch.nn.LayerNorm,
+        text_projection: torch.nn.Parameter,
+        attn_mask: torch.Tensor,
+    ):
+        super().__init__()
+        self.transformer = model
+        self.token_embedding = token_embedding
+        self.positional_embedding = positional_embedding
+        self.ln_final = ln_final
+        self.text_projection = text_projection
+
+        self.register_buffer("attn_mask", attn_mask, persistent=False)
+
+    def forward(self, text: torch.Tensor) -> torch.Tensor:
+        """Forward pass for the text encoder."""
+        cast_dtype = self.transformer.get_cast_dtype()
+
+        x = self.token_embedding(text).to(cast_dtype)
+
+        x = x + self.positional_embedding.to(cast_dtype)
+        x = x.permute(1, 0, 2)
+        x = self.transformer(x, attn_mask=self.attn_mask)
+        x = x.permute(1, 0, 2)
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
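
To make the scoring step concrete, here is a minimal sketch of the alpha-weighted fusion performed in `CaSED.forward`. The random tensors are stand-ins for real CLIP embeddings (768 is the ViT-L-14 embedding width; the 4 candidate categories are arbitrary), and `scale=100.0` approximates CLIP's `logit_scale.exp()`:

import torch

image_z = torch.randn(1, 768)          # image embedding
text_z = torch.randn(1, 768)           # mean embedding of the retrieved captions
vocabulary_z = torch.randn(1, 4, 768)  # one prompt view x 4 candidate categories

classifier = NearestNeighboursClassifier(scale=100.0, tau=1.0)
image_p = classifier(image_z, vocabulary_z)  # (1, 4) scores over the candidates
text_p = classifier(text_z, vocabulary_z)

alpha = 0.5  # trade-off between visual and textual similarity
sample_p = alpha * image_p + (1 - alpha) * text_p
print(sample_p[0].tolist())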
src/retrieval.py
ADDED
@@ -0,0 +1,42 @@
+from pathlib import Path
+
+import numpy as np
+import pyarrow as pa
+
+
+class ArrowMetadataProvider:
+    """The arrow metadata provider provides metadata from contiguous ids using arrow.
+
+    Code taken from:
+    https://github.dev/rom1504/clip-retrieval
+    """
+
+    def __init__(self, arrow_folder):
+        arrow_files = [str(a) for a in sorted(Path(arrow_folder).glob("**/*")) if a.is_file()]
+        self.table = pa.concat_tables(
+            [
+                pa.ipc.RecordBatchFileReader(pa.memory_map(arrow_file, "r")).read_all()
+                for arrow_file in arrow_files
+            ]
+        )
+
+    def get(self, ids, cols=None):
+        """Implement the get method from the arrow metadata provider, get metadata from ids."""
+        if cols is None:
+            cols = self.table.schema.names
+        else:
+            cols = list(set(self.table.schema.names) & set(cols))
+        t = pa.concat_tables([self.table[i:(i + 1)] for i in ids])
+        return t.select(cols).to_pandas().to_dict("records")
+
+
+def meta_to_dict(meta):
+    """Convert a metadata record to a plain dictionary."""
+    output = {}
+    for k, v in meta.items():
+        if isinstance(v, bytes):
+            v = v.decode()
+        elif type(v).__module__ == np.__name__:
+            v = v.item()
+        output[k] = v
+    return output
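
As a quick illustration, `meta_to_dict` normalizes raw arrow records (byte strings and numpy scalars) into plain Python values; the record below is invented for the example:

import numpy as np

meta = {"caption": b"a bowl of ramen", "similarity": np.float32(0.87)}
print(meta_to_dict(meta))  # {'caption': 'a bowl of ramen', 'similarity': 0.8700000047683716}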
src/transforms.py
ADDED
@@ -0,0 +1,506 @@
+import re
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Union, cast
+
+import inflect
+import nltk
+import numpy as np
+import PIL.Image
+import torch
+import torchvision.transforms as T
+import torchvision.transforms.functional as F
+from flair.data import Sentence
+from flair.models import SequenceTagger
+
+__all__ = [
+    "DynamicResize",
+    "DropFileExtensions",
+    "DropNonAlpha",
+    "DropShortWords",
+    "DropSpecialCharacters",
+    "DropTokens",
+    "DropURLs",
+    "DropWords",
+    "FilterPOS",
+    "FrequencyMinWordCount",
+    "FrequencyTopK",
+    "ReplaceSeparators",
+    "ToRGBTensor",
+    "ToLowercase",
+    "ToSingular",
+]
+
+
+class BaseTextTransform(ABC):
+    """Base class for string transforms."""
+
+    @abstractmethod
+    def __call__(self, text: str):
+        raise NotImplementedError
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+class DynamicResize(T.Resize):
+    """Resize the input PIL Image to the given size.
+
+    Extends the torchvision Resize transform to dynamically evaluate the second dimension of the
+    output size based on the aspect ratio of the first input image.
+    """
+
+    def forward(self, img):
+        if isinstance(self.size, int):
+            _, h, w = F.get_dimensions(img)
+            aspect_ratio = w / h
+            side = self.size
+
+            if aspect_ratio < 1.0:
+                self.size = int(side / aspect_ratio), side
+            else:
+                self.size = side, int(side * aspect_ratio)
+
+        return super().forward(img)
+
+
+class DropFileExtensions(BaseTextTransform):
+    """Remove file extensions from the input text."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove file extensions from.
+        """
+        text = re.sub(r"\.\w+", "", text)
+
+        return text
+
+
+class DropNonAlpha(BaseTextTransform):
+    """Remove non-alpha words from the input text."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove non-alpha words from.
+        """
+        text = re.sub(r"[^a-zA-Z\s]", "", text)
+
+        return text
+
+
+class DropShortWords(BaseTextTransform):
+    """Remove short words from the input text.
+
+    Args:
+        min_length (int): Minimum length of words to keep.
+    """
+
+    def __init__(self, min_length: int) -> None:
+        super().__init__()
+        self.min_length = min_length
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove short words from.
+        """
+        text = " ".join([word for word in text.split() if len(word) >= self.min_length])
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(min_length={self.min_length})"
+
+
+class DropSpecialCharacters(BaseTextTransform):
+    """Remove special characters from the input text.
+
+    Special characters are defined as any character that is not a word character, whitespace,
+    hyphen, period, apostrophe, or ampersand.
+    """
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove special characters from.
+        """
+        text = re.sub(r"[^\w\s\-\.\'\&]", "", text)
+
+        return text
+
+
+class DropTokens(BaseTextTransform):
+    """Remove tokens from the input text.
+
+    Tokens are defined as strings enclosed in angle brackets, e.g. <token>.
+    """
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove tokens from.
+        """
+        text = re.sub(r"<[^>]+>", "", text)
+
+        return text
+
+
+class DropURLs(BaseTextTransform):
+    """Remove URLs from the input text."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove URLs from.
+        """
+        text = re.sub(r"http\S+", "", text)
+
+        return text
+
+
+class DropWords(BaseTextTransform):
+    """Remove words from the input text.
+
+    It is case-insensitive and supports singular and plural forms of the words.
+    """
+
+    def __init__(self, words: list[str]) -> None:
+        super().__init__()
+        self.words = words
+        self.pattern = r"\b(?:{})\b".format("|".join(words))
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove words from.
+        """
+        text = re.sub(self.pattern, "", text, flags=re.IGNORECASE)
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(pattern={self.pattern})"
+
+
+class FilterPOS(BaseTextTransform):
+    """Filter words by POS tags.
+
+    Args:
+        tags (list): List of POS tags to filter on. With the "nltk" engine, words whose tag is in
+            the list are removed; with the "flair" engine, only words whose tag is in the list
+            are kept.
+        engine (str): POS tagger to use. Must be one of "nltk" or "flair". Defaults to "nltk".
+        keep_compound_nouns (bool): Whether to keep composed words. Defaults to True.
+    """
+
+    def __init__(self, tags: list, engine: str = "nltk", keep_compound_nouns: bool = True) -> None:
+        super().__init__()
+        self.tags = tags
+        self.engine = engine
+        self.keep_compound_nouns = keep_compound_nouns
+
+        if engine == "nltk":
+            nltk.download("averaged_perceptron_tagger", quiet=True)
+            nltk.download("punkt", quiet=True)
+            self.tagger = lambda x: nltk.pos_tag(nltk.word_tokenize(x))
+        elif engine == "flair":
+            self.tagger = SequenceTagger.load("flair/pos-english-fast").predict
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove words with specific POS tags from.
+        """
+        if self.engine == "nltk":
+            word_tags = self.tagger(text)
+            text = " ".join([word for word, tag in word_tags if tag not in self.tags])
+        elif self.engine == "flair":
+            sentence = Sentence(text)
+            self.tagger(sentence)
+            text = " ".join([token.text for token in sentence.tokens if token.tag in self.tags])
+
+        if self.keep_compound_nouns:
+            compound_nouns = []
+
+            if self.engine == "nltk":
+                for i in range(len(word_tags) - 1):
+                    if word_tags[i][1] == "NN" and word_tags[i + 1][1] == "NN":
+                        # if they are the same word, skip
+                        if word_tags[i][0] == word_tags[i + 1][0]:
+                            continue
+
+                        compound_noun = word_tags[i][0] + "_" + word_tags[i + 1][0]
+                        compound_nouns.append(compound_noun)
+            elif self.engine == "flair":
+                for i in range(len(sentence.tokens) - 1):
+                    if sentence.tokens[i].tag == "NN" and sentence.tokens[i + 1].tag == "NN":
+                        # if they are the same word, skip
+                        if sentence.tokens[i].text == sentence.tokens[i + 1].text:
+                            continue
+
+                        compound_noun = sentence.tokens[i].text + "_" + sentence.tokens[i + 1].text
+                        compound_nouns.append(compound_noun)
+
+            text = " ".join([text, " ".join(compound_nouns)])
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(tags={self.tags}, engine={self.engine})"
+
+
+class FrequencyMinWordCount(BaseTextTransform):
+    """Keep only words that occur more than a minimum number of times in the input text.
+
+    If the threshold is too strong and no words pass the threshold, the threshold is reduced to
+    the most frequent word.
+
+    Args:
+        min_count (int): Minimum number of occurrences of a word to keep.
+    """
+
+    def __init__(self, min_count: int) -> None:
+        super().__init__()
+        self.min_count = min_count
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove infrequent words from.
+        """
+        if self.min_count <= 1:
+            return text
+
+        words = text.split()
+        word_counts = {word: words.count(word) for word in words}
+
+        # if nothing passes the threshold, reduce the threshold to the most frequent word
+        max_word_count = max(word_counts.values() or [0])
+        min_count = max_word_count if self.min_count > max_word_count else self.min_count
+
+        text = " ".join([word for word in words if word_counts[word] >= min_count])
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(min_count={self.min_count})"
+
+
+class FrequencyTopK(BaseTextTransform):
+    """Keep only the top k most frequent words in the input text.
+
+    In case of a tie, all words with the same count as the last word are kept.
+
+    Args:
+        top_k (int): Number of top words to keep.
+    """
+
+    def __init__(self, top_k: int) -> None:
+        super().__init__()
+        self.top_k = top_k
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove infrequent words from.
+        """
+        if self.top_k < 1:
+            return text
+
+        words = text.split()
+        word_counts = {word: words.count(word) for word in words}
+        top_words = sorted(word_counts, key=word_counts.get, reverse=True)
+
+        # in case of a tie, keep all words with the same count
+        top_words = top_words[: self.top_k]
+        top_words = [word for word in top_words if word_counts[word] == word_counts[top_words[-1]]]
+
+        text = " ".join([word for word in words if word in top_words])
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(top_k={self.top_k})"
+
+
+class ReplaceSeparators(BaseTextTransform):
+    """Replace underscores and dashes with spaces."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to replace separators in.
+        """
+        text = re.sub(r"[_\-]", " ", text)
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+class RemoveDuplicates(BaseTextTransform):
+    """Remove duplicate words from the input text."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to remove duplicate words from.
+        """
+        text = " ".join(list(set(text.split())))
+
+        return text
+
+
+class TextCompose:
+    """Compose several transforms together.
+
+    It differs from the torchvision.transforms.Compose class in that it applies the transforms to
+    a string instead of a PIL Image or Tensor. In addition, it automatically joins the list of
+    input strings into a single string and splits the output string into a list of words.
+
+    Args:
+        transforms (list): List of transforms to compose.
+    """
+
+    def __init__(self, transforms: list[BaseTextTransform]) -> None:
+        self.transforms = transforms
+
+    def __call__(self, text: Union[str, list[str]]) -> Any:
+        if isinstance(text, list):
+            text = " ".join(text)
+
+        for t in self.transforms:
+            text = t(text)
+        return text.split()
+
+    def __repr__(self) -> str:
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += f"    {t}"
+        format_string += "\n)"
+        return format_string
+
+
+class ToRGBTensor(T.ToTensor):
+    """Convert a `PIL Image` or `numpy.ndarray` to tensor.
+
+    Compared with the torchvision `ToTensor` transform, it converts images with a single channel to
+    RGB images. In addition, the conversion to tensor is done only if the input is not already a
+    tensor.
+    """
+
+    def __call__(self, pic: Union[PIL.Image.Image, np.ndarray, torch.Tensor]):
+        """
+        Args:
+            pic (PIL Image | numpy.ndarray | torch.Tensor): Image to be converted to tensor.
+        """
+        img = pic if isinstance(pic, torch.Tensor) else F.to_tensor(pic)
+        img = cast(torch.Tensor, img)
+
+        if img.shape[0] == 1:
+            img = img.repeat(3, 1, 1)
+
+        return img
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+class ToLowercase(BaseTextTransform):
+    """Convert text to lowercase."""
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to convert to lowercase.
+        """
+        text = text.lower()
+
+        return text
+
+
+class ToSingular(BaseTextTransform):
+    """Convert plural words to singular form."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.transform = inflect.engine().singular_noun
+
+    def __call__(self, text: str):
+        """
+        Args:
+            text (str): Text to convert to singular form.
+        """
+        words = text.split()
+        for i, word in enumerate(words):
+            if not word.endswith("s"):
+                continue
+
+            if word[-2:] in ["ss", "us", "is"]:
+                continue
+
+            if word[-3:] in ["ies", "oes"]:
+                continue
+
+            words[i] = self.transform(word) or word
+
+        text = " ".join(words)
+
+        return text
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+def default_preprocess(size: Optional[int] = None) -> T.Compose:
+    """Preprocess input images with preprocessing transforms.
+
+    Args:
+        size (int): Size to resize image to.
+    """
+    transforms = []
+    if size is not None:
+        transforms.append(DynamicResize(size, interpolation=T.InterpolationMode.BICUBIC))
+    transforms.append(ToRGBTensor())
+    transforms = T.Compose(transforms)
+
+    return transforms
+
+
+def default_vocabulary_transforms() -> TextCompose:
+    """Preprocess input text with preprocessing transforms."""
+    words_to_drop = [
+        "image",
+        "photo",
+        "picture",
+        "thumbnail",
+        "logo",
+        "symbol",
+        "clipart",
+        "portrait",
+        "painting",
+        "illustration",
+        "icon",
+        "profile",
+    ]
+    pos_tags = ["NN", "NNS", "NNP", "NNPS", "JJ", "JJR", "JJS", "VBG", "VBN"]
+
+    transforms = []
+    transforms.append(DropTokens())
+    transforms.append(DropURLs())
+    transforms.append(DropSpecialCharacters())
+    transforms.append(DropFileExtensions())
+    transforms.append(ReplaceSeparators())
+    transforms.append(DropShortWords(min_length=3))
+    transforms.append(DropNonAlpha())
+    transforms.append(ToLowercase())
+    transforms.append(ToSingular())
+    transforms.append(DropWords(words=words_to_drop))
+    transforms.append(FrequencyMinWordCount(min_count=2))
+    transforms.append(FilterPOS(tags=pos_tags, engine="flair", keep_compound_nouns=False))
+    transforms.append(RemoveDuplicates())
+
+    transforms = TextCompose(transforms)
+
+    return transforms
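
As an illustration of how these transforms compose, the sketch below runs a reduced pipeline over two invented captions (the flair POS tagger is omitted so that no model has to be downloaded):

pipeline = TextCompose([
    DropURLs(),
    DropSpecialCharacters(),
    ReplaceSeparators(),
    DropShortWords(min_length=3),
    DropNonAlpha(),
    ToLowercase(),
    ToSingular(),
    DropWords(words=["image", "photo"]),
    RemoveDuplicates(),
])

captions = [
    "A photo of two bowls of ramen http://example.com/ramen.jpg",
    "Bowls of ramen on a wooden_table",
]
print(pipeline(captions))
# e.g. ['two', 'bowl', 'ramen', 'wooden', 'table'] (order varies: RemoveDuplicates uses a set)

This mirrors `default_vocabulary_transforms`, which additionally filters candidates by frequency and part of speech before deduplicating.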