vincentclaes committed
Commit e45afa6 • 1 Parent(s): 0df1067

have a working model

Files changed (5)
  1. README.md +1 -1
  2. app.py +59 -6
  3. poetry.lock +10 -68
  4. pyproject.toml +1 -0
  5. requirements.txt +1 -0
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: Emoji Predictor
- emoji: 📊
+ emoji: 😎
  colorFrom: pink
  colorTo: indigo
  sdk: gradio
app.py CHANGED
@@ -1,10 +1,16 @@
  import gradio as gr
  import torch
+ import os
+
  from PIL import Image
+ from pathlib import Path
+ from more_itertools import chunked
+
  from transformers import CLIPProcessor, CLIPModel

  checkpoint = "vincentclaes/emoji-predictor"
- no_of_emojis = range(20)
+ x_, _, files = next(os.walk("./emojis"))
+ no_of_emojis = range(len(files))
  emojis_as_images = [Image.open(f"emojis/{i}.png") for i in no_of_emojis]
  K = 4

@@ -12,6 +18,29 @@ processor = CLIPProcessor.from_pretrained(checkpoint)
  model = CLIPModel.from_pretrained(checkpoint)


+ def concat_images(*images):
+     """Generate a 2x2 composite of all supplied images.
+     https://stackoverflow.com/a/71315656/1771155
+     """
+     # Get the widest width.
+     width = max(image.width for image in images)
+     # Get the tallest height.
+     height = max(image.height for image in images)
+     # The composite is a 2x2 grid, so twice the cell width and height.
+     composite = Image.new('RGB', (2*width, 2*height))
+     assert K == 4, "We expect 4 suggestions, other numbers won't work."
+     for i, image in enumerate(images):
+         if i == 0:
+             composite.paste(image, (0, 0))
+         elif i == 1:
+             composite.paste(image, (width, 0))
+         elif i == 2:
+             composite.paste(image, (0, height))
+         elif i == 3:
+             composite.paste(image, (width, height))
+     return composite
+
+
  def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K=4):
      inputs = processor(text=text, images=emojis, return_tensors="pt", padding=True, truncation=True)
      outputs = model(**inputs)
@@ -23,11 +52,35 @@ def get_emoji(text, model=model, processor=processor, emojis=emojis_as_images, K
      predictions_suggestions_for_chunk = [torch.topk(prob, K).indices.tolist() for prob in probs][0]
      predictions_suggestions_for_chunk

-     return [f"emojis/{i}.png" for i in predictions_suggestions_for_chunk]
+     images = [Image.open(f"emojis/{i}.png") for i in predictions_suggestions_for_chunk]
+     images_concat = concat_images(*images)
+     return images_concat


- text = gr.inputs.Textbox()
+ text = gr.inputs.Textbox(placeholder="Enter a text and we will try to predict an emoji...")
  title = "Predicting an Emoji"
- description = "Enter a text and we will try to predict an emoji.\nThe model is a few shot fine tuned CLIP model trained on images of emoji's."
- examples = ["I'm so glad I finally arrived in my holiday resort!"]
- gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Gallery(), examples=examples, title=title).launch()
+ description = """You provide a sentence and our few-shot fine-tuned CLIP model will predict one of the following emojis:
+ \n❤️ 😍 😂 💕 🔥 😊 😎 ✨ 💙 😘 📷 🇺🇸 ☀ 💜 😉 💯 😁 🎄 📸 😜 ☹️ 😭 😔 😡 💢 😤 😳 🙃 😩 😠 🙈 🙄\n
+ """
+ article = """
+ \n
+ +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+ \n
+ We fine-tuned OpenAI's CLIP model on both text (tweets) and images of emojis!\n
+ The current model is fine-tuned on 15 samples per emoji.
+
+ - model: https://huggingface.co/vincentclaes/emoji-predictor \n
+ - dataset: https://huggingface.co/datasets/vincentclaes/emoji-predictor \n
+ - code: https://github.com/vincentclaes/emoji-predictor \n
+ - profile: https://huggingface.co/vincentclaes \n
+ """
+ examples = [
+     "I'm so happy for you!",
+     "I'm not feeling great today.",
+     "This makes me angry!",
+     "Can I follow you?",
+     "I'm so bored right now ...",
+ ]
+ gr.Interface(fn=get_emoji, inputs=text, outputs=gr.Image(shape=(72, 72)),
+              examples=examples, title=title, description=description,
+              article=article).launch()
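The hunks above skip the scoring lines inside `get_emoji` (new lines 47-51), so the route from `outputs` to `probs` is not visible here. As orientation only (this is not the committed code), here is a minimal sketch of how a CLIP checkpoint ranks the emoji images for one sentence, assuming the standard `transformers` `CLIPModel` output field `logits_per_text` and a folder of numbered PNGs like the Space's `emojis/` directory:

```python
# Illustrative sketch, not the committed app.py: rank emoji images for a text with CLIP.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

checkpoint = "vincentclaes/emoji-predictor"
processor = CLIPProcessor.from_pretrained(checkpoint)
model = CLIPModel.from_pretrained(checkpoint)

# Hypothetical count: the Space's description lists 32 candidate emojis (0.png ... 31.png).
emojis = [Image.open(f"emojis/{i}.png") for i in range(32)]

inputs = processor(text=["I'm so happy for you!"], images=emojis,
                   return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = model(**inputs)

# logits_per_text has shape (num_texts, num_images); softmax over the images turns the
# similarities into probabilities, and topk keeps the 4 most likely emoji indices.
probs = outputs.logits_per_text.softmax(dim=-1)
top_4 = torch.topk(probs[0], k=4).indices.tolist()
suggestions = [f"emojis/{i}.png" for i in top_4]
```

In the committed `get_emoji`, the four suggested PNGs are then pasted into a 2x2 grid by `concat_images` and returned as a single image for the `gr.Image` output.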
poetry.lock CHANGED
@@ -155,17 +155,6 @@ category = "main"
  optional = false
  python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"

- [[package]]
- name = "commonmark"
- version = "0.9.1"
- description = "Python parser for the CommonMark Markdown spec"
- category = "main"
- optional = false
- python-versions = "*"
-
- [package.extras]
- test = ["flake8 (==3.7.8)", "hypothesis (==3.55.3)"]
-
  [[package]]
  name = "contourpy"
  version = "1.0.5"
@@ -211,29 +200,6 @@ category = "main"
  optional = false
  python-versions = ">=3.6"

- [[package]]
- name = "docarray"
- version = "0.16.5"
- description = "The data structure for unstructured data"
- category = "main"
- optional = false
- python-versions = "*"
-
- [package.dependencies]
- numpy = "*"
- rich = ">=12.0.0"
-
- [package.extras]
- annlite = ["annlite (>=0.3.10)"]
- benchmark = ["pandas", "seaborn"]
- common = ["protobuf (>=3.13.0)", "lz4", "requests", "matplotlib", "pillow", "fastapi", "uvicorn", "jina-hubble-sdk (>=0.11.0)"]
- elasticsearch = ["elasticsearch (>=8.2.0)"]
- full = ["protobuf (>=3.13.0)", "lz4", "requests", "matplotlib", "pillow", "trimesh", "scipy", "jina-hubble-sdk (>=0.10.0)", "av", "fastapi", "uvicorn", "strawberry-graphql"]
- qdrant = ["qdrant-client (>=0.7.3,<0.8.0)"]
- redis = ["redis (>=4.3.0)"]
- test = ["pytest", "pytest-timeout", "pytest-mock", "pytest-cov", "pytest-repeat", "pytest-reraise", "mock", "pytest-custom-exit-code", "black (==22.3.0)", "tensorflow (==2.7.0)", "paddlepaddle (==2.2.0)", "torch (==1.9.0)", "torchvision (==0.10.0)", "datasets", "onnx", "onnxruntime", "jupyterlab", "transformers (>=4.16.2)", "weaviate-client (>=3.3.0,<3.4.0)", "annlite (>=0.3.10)", "elasticsearch (>=8.2.0)", "redis (>=4.3.0)", "jina"]
- weaviate = ["weaviate-client (>=3.3.0,<3.4.0)"]
-
  [[package]]
  name = "fastapi"
  version = "0.85.0"
@@ -567,6 +533,14 @@ category = "main"
  optional = false
  python-versions = "*"

+ [[package]]
+ name = "more-itertools"
+ version = "8.14.0"
+ description = "More routines for operating on iterables, beyond itertools"
+ category = "main"
+ optional = false
+ python-versions = ">=3.5"
+
  [[package]]
  name = "multidict"
  version = "6.0.2"
@@ -694,17 +668,6 @@ category = "main"
  optional = false
  python-versions = "*"

- [[package]]
- name = "pygments"
- version = "2.13.0"
- description = "Pygments is a syntax highlighting package written in Python."
- category = "main"
- optional = false
- python-versions = ">=3.6"
-
- [package.extras]
- plugins = ["importlib-metadata"]
-
  [[package]]
  name = "pynacl"
  version = "1.5.0"
@@ -809,21 +772,6 @@ idna = {version = "*", optional = true, markers = "extra == \"idna2008\""}
  [package.extras]
  idna2008 = ["idna"]

- [[package]]
- name = "rich"
- version = "12.5.1"
- description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
- category = "main"
- optional = false
- python-versions = ">=3.6.3,<4.0.0"
-
- [package.dependencies]
- commonmark = ">=0.9.0,<0.10.0"
- pygments = ">=2.6.0,<3.0.0"
-
- [package.extras]
- jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
-
  [[package]]
  name = "setuptools-scm"
  version = "7.0.5"
@@ -1051,7 +999,7 @@ multidict = ">=4.0"
  [metadata]
  lock-version = "1.1"
  python-versions = "^3.9"
- content-hash = "d1503a7bf493757c63052449403b2d5ed7275e673eaf4ebfcd1c0930e2fada42"
+ content-hash = "5bc12d64b69b9c1f0f68ae6858e97ba26663256bae5a9172c0f5bb69402f6c62"

  [metadata.files]
  aiohttp = [
@@ -1152,17 +1100,12 @@ click = [
      {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
  ]
  colorama = []
- commonmark = [
-     {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"},
-     {file = "commonmark-0.9.1.tar.gz", hash = "sha256:452f9dc859be7f06631ddcb328b6919c67984aca654e5fefb3914d54691aed60"},
- ]
  contourpy = []
  cryptography = []
  cycler = [
      {file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
      {file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
  ]
- docarray = []
  fastapi = []
  ffmpy = []
  filelock = []
@@ -1225,6 +1168,7 @@ matplotlib = []
  mdit-py-plugins = []
  mdurl = []
  monotonic = []
+ more-itertools = []
  multidict = [
      {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b9e95a740109c6047602f4db4da9949e6c5945cefbad34a1299775ddc9a62e2"},
      {file = "multidict-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac0e27844758d7177989ce406acc6a83c16ed4524ebc363c1f748cba184d89d3"},
@@ -1302,7 +1246,6 @@ pycparser = [
  pycryptodome = []
  pydantic = []
  pydub = []
- pygments = []
  pynacl = []
  pyparsing = [
      {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
@@ -1354,7 +1297,6 @@ pyyaml = [
  regex = []
  requests = []
  rfc3986 = []
- rich = []
  setuptools-scm = []
  six = [
      {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
pyproject.toml CHANGED
@@ -9,6 +9,7 @@ python = "^3.9"
  torch = "^1.12.1"
  gradio = "^3.3.1"
  transformers = "^4.22.1"
+ more-itertools = "^8.14.0"

  [tool.poetry.dev-dependencies]

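Both `pyproject.toml` and `requirements.txt` (below) pin the new `more-itertools` dependency behind the `from more_itertools import chunked` import added to `app.py`. None of the hunks shown here call `chunked` yet; for reference, it batches any iterable into fixed-size lists. A minimal sketch with a hypothetical file list:

```python
# more_itertools.chunked yields fixed-size lists from an iterable; the last chunk may be shorter.
from more_itertools import chunked

emoji_files = [f"emojis/{i}.png" for i in range(10)]  # hypothetical file list
for batch in chunked(emoji_files, 4):
    print(batch)
# ['emojis/0.png', 'emojis/1.png', 'emojis/2.png', 'emojis/3.png']
# ['emojis/4.png', 'emojis/5.png', 'emojis/6.png', 'emojis/7.png']
# ['emojis/8.png', 'emojis/9.png']
```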
requirements.txt CHANGED
@@ -34,6 +34,7 @@ matplotlib==3.6.0
  mdit-py-plugins==0.3.0
  mdurl==0.1.2
  monotonic==1.6
+ more-itertools==8.14.0
  multidict==6.0.2
  numpy==1.23.3
  orjson==3.8.0