Spaces:
Runtime error
Runtime error
eddydecena
commited on
Commit
•
54c0802
1
Parent(s):
03e858a
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import tensorflow as tf
|
5 |
+
from keras_tuner import HyperParameters
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
|
8 |
+
from src.models import MakeHyperModel
|
9 |
+
from src.preprocessing import get_data_augmentation
|
10 |
+
from src.config import IMAGE_SIZE
|
11 |
+
|
12 |
+
data_augmentation = get_data_augmentation()
|
13 |
+
cache_dir = os.path.join('hf_hub')
|
14 |
+
|
15 |
+
for f in ['checkpoint', 'checkpoint.data-00000-of-00001', 'checkpoint.index']:
|
16 |
+
print(f)
|
17 |
+
old_name = hf_hub_download(repo_id="eddydecena/cat-vs-dog", filename=f"tuner_model/cat-vs-dog/trial_0484d8d758a5ef7b91ca97d334ba7870/checkpoints/epoch_0/{f}", cache_dir=cache_dir)
|
18 |
+
temp_value = old_name.split('/')
|
19 |
+
temp_value.pop(-1)
|
20 |
+
path = '/'.join(temp_value)
|
21 |
+
os.rename(old_name, os.path.join(path, f))
|
22 |
+
|
23 |
+
latest = tf.train.latest_checkpoint('./tuner_model/cat-vs-dog/trial_0484d8d758a5ef7b91ca97d334ba7870/checkpoints/epoch_0')
|
24 |
+
hypermodel = MakeHyperModel(input_shape=IMAGE_SIZE + (3,), num_classes=2, data_augmentation=data_augmentation)
|
25 |
+
model = hypermodel.build(hp=HyperParameters())
|
26 |
+
model.load_weights(latest).expect_partial()
|
27 |
+
|
28 |
+
def cat_vs_dog(image):
|
29 |
+
img_array = tf.constant(image, dtype=tf.float32)
|
30 |
+
img_array = tf.expand_dims(img_array, 0)
|
31 |
+
predictions = model.predict(img_array)
|
32 |
+
score = predictions[0]
|
33 |
+
return {'cat': float((1 - score)), 'dog': float(score)}
|
34 |
+
|
35 |
+
iface = gr.Interface(
|
36 |
+
cat_vs_dog,
|
37 |
+
gr.inputs.Image(shape=IMAGE_SIZE),
|
38 |
+
gr.outputs.Label(num_top_classes=2),
|
39 |
+
capture_session=True,
|
40 |
+
interpretation="default",
|
41 |
+
examples=[
|
42 |
+
["examples/cat1.jpg"],
|
43 |
+
["examples/cat2.jpg"],
|
44 |
+
["examples/dog1.jpeg"],
|
45 |
+
["examples/dog2.jpeg"]
|
46 |
+
])
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
iface.launch()
|