Spaces:
Running
Running
Using a new model trained on 224x224 images
Browse filesRetrained resnet34 using the same FER203 dataset, but
resized these images to 224x224 - these are what the original
model expects, and the accuracy is a bit better, with no data
cleaning steps.
Also, most images people will upload will be higher resolution
so using 224x224 pixels will give better results than using 48x48.
- app.py +3 -2
- environment.yaml +203 -0
- fec224-resnet34-v1.pkl +3 -0
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from fastai.vision.all import *
|
| 3 |
import cv2
|
|
|
|
| 4 |
|
| 5 |
-
learn = load_learner('
|
| 6 |
labels = learn.dls.vocab
|
| 7 |
def predict(img):
|
| 8 |
image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY, dstCn=3 )
|
|
@@ -12,6 +13,6 @@ def predict(img):
|
|
| 12 |
title = "Facial Expression Classifier"
|
| 13 |
description = "A facial expression classifier, trained using the <a href='https://www.kaggle.com/datasets/msambare/fer2013'>FER-2013 dataset</a>. This dataset consists of 28,709 examples of faces: each one is 48x48 grayscale pixels and is labelled with one of the following expressions: anger, disgust, fear, happy, neutral, sad, surprise.<p><p>This was used to train a resnet34 model."
|
| 14 |
examples = ["angryExample.jpg", "disgustExample.jpg", "fearExample.jpg", "happyExample.jpg", "neutralExample.jpg", "sadExample.jpg", "surpriseExample.jpg"]
|
| 15 |
-
iface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(
|
| 16 |
iface.launch()
|
| 17 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from fastai.vision.all import *
|
| 3 |
import cv2
|
| 4 |
+
import PIL
|
| 5 |
|
| 6 |
+
learn = load_learner('fec224-resnet34-v1.pkl')
|
| 7 |
labels = learn.dls.vocab
|
| 8 |
def predict(img):
|
| 9 |
image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY, dstCn=3 )
|
|
|
|
| 13 |
title = "Facial Expression Classifier"
|
| 14 |
description = "A facial expression classifier, trained using the <a href='https://www.kaggle.com/datasets/msambare/fer2013'>FER-2013 dataset</a>. This dataset consists of 28,709 examples of faces: each one is 48x48 grayscale pixels and is labelled with one of the following expressions: anger, disgust, fear, happy, neutral, sad, surprise.<p><p>This was used to train a resnet34 model."
|
| 15 |
examples = ["angryExample.jpg", "disgustExample.jpg", "fearExample.jpg", "happyExample.jpg", "neutralExample.jpg", "sadExample.jpg", "surpriseExample.jpg"]
|
| 16 |
+
iface = gr.Interface(fn=predict, inputs=gr.inputs.Image(shape=(224,224)), outputs=gr.outputs.Label(num_top_classes=3), examples=examples, title=title, description=description,interpretation='default')
|
| 17 |
iface.launch()
|
| 18 |
|
environment.yaml
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: fastai
|
| 2 |
+
channels:
|
| 3 |
+
- fastai
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _openmp_mutex=5.1=1_gnu
|
| 8 |
+
- _pytorch_select=0.2=gpu_0
|
| 9 |
+
- blas=1.0=mkl
|
| 10 |
+
- blosc=1.21.0=h8c45485_0
|
| 11 |
+
- bottleneck=1.3.4=py39hce1f21e_0
|
| 12 |
+
- brotli=1.0.9=h5eee18b_7
|
| 13 |
+
- brotli-bin=1.0.9=h5eee18b_7
|
| 14 |
+
- brotlipy=0.7.0=py39h27cfd23_1003
|
| 15 |
+
- brunsli=0.1=h2531618_0
|
| 16 |
+
- bzip2=1.0.8=h7b6447c_0
|
| 17 |
+
- c-ares=1.18.1=h7f8727e_0
|
| 18 |
+
- ca-certificates=2022.10.11=h06a4308_0
|
| 19 |
+
- catalogue=2.0.7=py39h06a4308_0
|
| 20 |
+
- certifi=2022.12.7=py39h06a4308_0
|
| 21 |
+
- cffi=1.15.1=py39h74dc2b5_0
|
| 22 |
+
- cfitsio=3.470=h5893167_7
|
| 23 |
+
- charls=2.2.0=h2531618_0
|
| 24 |
+
- charset-normalizer=2.0.4=pyhd3eb1b0_0
|
| 25 |
+
- click=8.0.4=py39h06a4308_0
|
| 26 |
+
- cloudpickle=2.0.0=pyhd3eb1b0_0
|
| 27 |
+
- colorama=0.4.6=py39h06a4308_0
|
| 28 |
+
- cryptography=38.0.1=py39h9ce1e76_0
|
| 29 |
+
- cycler=0.11.0=pyhd3eb1b0_0
|
| 30 |
+
- cymem=2.0.6=py39h295c915_0
|
| 31 |
+
- cython-blis=0.7.7=py39hce1f21e_0
|
| 32 |
+
- cytoolz=0.11.0=py39h27cfd23_0
|
| 33 |
+
- dask-core=2022.7.0=py39h06a4308_0
|
| 34 |
+
- dbus=1.13.18=hb2f20db_0
|
| 35 |
+
- expat=2.4.4=h295c915_0
|
| 36 |
+
- fastcore=1.5.27=py_0
|
| 37 |
+
- fastdownload=0.0.7=py_0
|
| 38 |
+
- fastprogress=1.0.3=py_0
|
| 39 |
+
- fftw=3.3.9=h27cfd23_1
|
| 40 |
+
- flit-core=3.6.0=pyhd3eb1b0_0
|
| 41 |
+
- fontconfig=2.13.1=h6c09931_0
|
| 42 |
+
- fonttools=4.25.0=pyhd3eb1b0_0
|
| 43 |
+
- freetype=2.11.0=h70c0345_0
|
| 44 |
+
- fsspec=2022.11.0=py39h06a4308_0
|
| 45 |
+
- future=0.18.2=py39h06a4308_1
|
| 46 |
+
- giflib=5.2.1=h7b6447c_0
|
| 47 |
+
- glib=2.69.1=h4ff587b_1
|
| 48 |
+
- gst-plugins-base=1.14.0=h8213a91_2
|
| 49 |
+
- gstreamer=1.14.0=h28cd5cc_2
|
| 50 |
+
- icu=58.2=he6710b0_3
|
| 51 |
+
- idna=3.4=py39h06a4308_0
|
| 52 |
+
- imagecodecs=2021.8.26=py39h4cda21f_0
|
| 53 |
+
- intel-openmp=2021.4.0=h06a4308_3561
|
| 54 |
+
- jinja2=3.1.2=py39h06a4308_0
|
| 55 |
+
- joblib=1.1.1=py39h06a4308_0
|
| 56 |
+
- jpeg=9e=h7f8727e_0
|
| 57 |
+
- jxrlib=1.1=h7b6447c_2
|
| 58 |
+
- kiwisolver=1.4.2=py39h295c915_0
|
| 59 |
+
- krb5=1.19.2=hac12032_0
|
| 60 |
+
- langcodes=3.3.0=pyhd3eb1b0_0
|
| 61 |
+
- lcms2=2.12=h3be6417_0
|
| 62 |
+
- ld_impl_linux-64=2.38=h1181459_1
|
| 63 |
+
- lerc=3.0=h295c915_0
|
| 64 |
+
- libaec=1.0.4=he6710b0_1
|
| 65 |
+
- libbrotlicommon=1.0.9=h5eee18b_7
|
| 66 |
+
- libbrotlidec=1.0.9=h5eee18b_7
|
| 67 |
+
- libbrotlienc=1.0.9=h5eee18b_7
|
| 68 |
+
- libcurl=7.86.0=h91b91d3_0
|
| 69 |
+
- libdeflate=1.8=h7f8727e_5
|
| 70 |
+
- libedit=3.1.20221030=h5eee18b_0
|
| 71 |
+
- libev=4.33=h7f8727e_1
|
| 72 |
+
- libffi=3.3=he6710b0_2
|
| 73 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 74 |
+
- libgfortran-ng=11.2.0=h00389a5_1
|
| 75 |
+
- libgfortran5=11.2.0=h1234567_1
|
| 76 |
+
- libgomp=11.2.0=h1234567_1
|
| 77 |
+
- libmklml=2019.0.5=h06a4308_0
|
| 78 |
+
- libnghttp2=1.46.0=hce63b2e_0
|
| 79 |
+
- libpng=1.6.37=hbc83047_0
|
| 80 |
+
- libssh2=1.10.0=h8f2d780_0
|
| 81 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 82 |
+
- libtiff=4.2.0=h85742a9_0
|
| 83 |
+
- libuuid=1.0.3=h7f8727e_2
|
| 84 |
+
- libwebp=1.2.4=h11a3e52_0
|
| 85 |
+
- libwebp-base=1.2.4=h5eee18b_0
|
| 86 |
+
- libxcb=1.15=h7f8727e_0
|
| 87 |
+
- libxml2=2.9.14=h74e7548_0
|
| 88 |
+
- libzopfli=1.0.3=he6710b0_0
|
| 89 |
+
- locket=1.0.0=py39h06a4308_0
|
| 90 |
+
- lz4-c=1.9.4=h6a678d5_0
|
| 91 |
+
- markupsafe=2.1.1=py39h7f8727e_0
|
| 92 |
+
- mkl=2021.4.0=h06a4308_640
|
| 93 |
+
- mkl-service=2.4.0=py39h7f8727e_0
|
| 94 |
+
- mkl_fft=1.3.0=py39h54f3939_0
|
| 95 |
+
- mkl_random=1.0.2=py39h63df603_0
|
| 96 |
+
- munkres=1.1.4=py_0
|
| 97 |
+
- murmurhash=1.0.7=py39h295c915_0
|
| 98 |
+
- ncurses=6.3=h5eee18b_3
|
| 99 |
+
- networkx=2.8.4=py39h06a4308_0
|
| 100 |
+
- ninja=1.10.2=h06a4308_5
|
| 101 |
+
- ninja-base=1.10.2=hd09550d_5
|
| 102 |
+
- numexpr=2.7.3=py39hb2eb853_0
|
| 103 |
+
- openjpeg=2.4.0=h3ad879b_0
|
| 104 |
+
- openssl=1.1.1s=h7f8727e_0
|
| 105 |
+
- packaging=22.0=py39h06a4308_0
|
| 106 |
+
- pandas=1.4.2=py39h295c915_0
|
| 107 |
+
- partd=1.2.0=pyhd3eb1b0_1
|
| 108 |
+
- pathy=0.6.1=py39h06a4308_0
|
| 109 |
+
- pcre=8.45=h295c915_0
|
| 110 |
+
- pip=22.3.1=py39h06a4308_0
|
| 111 |
+
- preshed=3.0.6=py39h295c915_0
|
| 112 |
+
- pycparser=2.21=pyhd3eb1b0_0
|
| 113 |
+
- pyopenssl=22.0.0=pyhd3eb1b0_0
|
| 114 |
+
- pyparsing=3.0.9=py39h06a4308_0
|
| 115 |
+
- pyqt=5.9.2=py39h2531618_6
|
| 116 |
+
- pysocks=1.7.1=py39h06a4308_0
|
| 117 |
+
- python=3.9.12=h12debd9_1
|
| 118 |
+
- python-dateutil=2.8.2=pyhd3eb1b0_0
|
| 119 |
+
- pytz=2022.7=py39h06a4308_0
|
| 120 |
+
- pywavelets=1.3.0=py39h7f8727e_0
|
| 121 |
+
- pyyaml=6.0=py39h5eee18b_1
|
| 122 |
+
- qt=5.9.7=h5867ecd_1
|
| 123 |
+
- readline=8.2=h5eee18b_0
|
| 124 |
+
- requests=2.28.1=py39h06a4308_0
|
| 125 |
+
- scikit-learn=1.0.2=py39h51133e4_1
|
| 126 |
+
- scipy=1.6.2=py39h91f5cce_0
|
| 127 |
+
- setuptools=65.5.0=py39h06a4308_0
|
| 128 |
+
- shellingham=1.5.0=py39h06a4308_0
|
| 129 |
+
- sip=4.19.13=py39h295c915_0
|
| 130 |
+
- six=1.16.0=pyhd3eb1b0_1
|
| 131 |
+
- smart_open=5.2.1=py39h06a4308_0
|
| 132 |
+
- snappy=1.1.9=h295c915_0
|
| 133 |
+
- spacy=3.3.0=py39hae6d005_0
|
| 134 |
+
- spacy-legacy=3.0.9=py39h06a4308_0
|
| 135 |
+
- spacy-loggers=1.0.1=pyhd3eb1b0_0
|
| 136 |
+
- sqlite=3.40.0=h5082296_0
|
| 137 |
+
- srsly=2.4.3=py39h295c915_0
|
| 138 |
+
- thinc=8.0.15=py39hae6d005_0
|
| 139 |
+
- threadpoolctl=2.2.0=pyh0d69192_0
|
| 140 |
+
- tifffile=2021.7.2=pyhd3eb1b0_2
|
| 141 |
+
- tk=8.6.12=h1ccaba5_0
|
| 142 |
+
- toolz=0.12.0=py39h06a4308_0
|
| 143 |
+
- tornado=6.1=py39h27cfd23_0
|
| 144 |
+
- tqdm=4.64.1=py39h06a4308_0
|
| 145 |
+
- typer=0.4.1=py39h06a4308_0
|
| 146 |
+
- typing-extensions=4.4.0=py39h06a4308_0
|
| 147 |
+
- typing_extensions=4.4.0=py39h06a4308_0
|
| 148 |
+
- tzdata=2022g=h04d1e81_0
|
| 149 |
+
- urllib3=1.26.13=py39h06a4308_0
|
| 150 |
+
- wasabi=0.9.1=py39h06a4308_0
|
| 151 |
+
- wheel=0.37.1=pyhd3eb1b0_0
|
| 152 |
+
- xz=5.2.8=h5eee18b_0
|
| 153 |
+
- yaml=0.2.5=h7b6447c_0
|
| 154 |
+
- zfp=0.5.5=h295c915_6
|
| 155 |
+
- zlib=1.2.13=h5eee18b_0
|
| 156 |
+
- zstd=1.4.9=haebb681_0
|
| 157 |
+
- pip:
|
| 158 |
+
- aiohttp==3.8.3
|
| 159 |
+
- aiosignal==1.3.1
|
| 160 |
+
- altair==4.2.0
|
| 161 |
+
- anyio==3.6.2
|
| 162 |
+
- async-timeout==4.0.2
|
| 163 |
+
- attrs==22.2.0
|
| 164 |
+
- contourpy==1.0.6
|
| 165 |
+
- entrypoints==0.4
|
| 166 |
+
- fastai==2.7.10
|
| 167 |
+
- fastapi==0.88.0
|
| 168 |
+
- ffmpy==0.3.0
|
| 169 |
+
- frozenlist==1.3.3
|
| 170 |
+
- gradio==3.15.0
|
| 171 |
+
- h11==0.14.0
|
| 172 |
+
- httpcore==0.16.3
|
| 173 |
+
- httpx==0.23.2
|
| 174 |
+
- jsonschema==4.17.3
|
| 175 |
+
- linkify-it-py==1.0.3
|
| 176 |
+
- markdown-it-py==2.1.0
|
| 177 |
+
- matplotlib==3.6.2
|
| 178 |
+
- mdit-py-plugins==0.3.3
|
| 179 |
+
- mdurl==0.1.2
|
| 180 |
+
- multidict==6.0.4
|
| 181 |
+
- numpy==1.22.4
|
| 182 |
+
- nvidia-cublas-cu11==11.10.3.66
|
| 183 |
+
- nvidia-cuda-nvrtc-cu11==11.7.99
|
| 184 |
+
- nvidia-cuda-runtime-cu11==11.7.99
|
| 185 |
+
- nvidia-cudnn-cu11==8.5.0.96
|
| 186 |
+
- opencv-python==4.7.0.68
|
| 187 |
+
- orjson==3.8.3
|
| 188 |
+
- pillow==9.1.1
|
| 189 |
+
- pycryptodome==3.16.0
|
| 190 |
+
- pydantic==1.8.2
|
| 191 |
+
- pydub==0.25.1
|
| 192 |
+
- pyrsistent==0.19.3
|
| 193 |
+
- python-multipart==0.0.5
|
| 194 |
+
- rfc3986==1.5.0
|
| 195 |
+
- sniffio==1.3.0
|
| 196 |
+
- starlette==0.22.0
|
| 197 |
+
- torch==1.13.1
|
| 198 |
+
- torchvision==0.14.1
|
| 199 |
+
- uc-micro-py==1.0.1
|
| 200 |
+
- uvicorn==0.20.0
|
| 201 |
+
- websockets==10.4
|
| 202 |
+
- yarl==1.8.2
|
| 203 |
+
prefix: /home/lbiswas/anaconda3/envs/fastai
|
fec224-resnet34-v1.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:94be543f9d508f82c5ac7e9e314ec217ae3bfceb515a5e51c7c9c3f65fc910b3
|
| 3 |
+
size 87860517
|