Matej commited on
Commit
fe6f0ef
1 Parent(s): d282048
Files changed (5) hide show
  1. .gitignore +4 -0
  2. README.md +1 -1
  3. classes.txt +101 -0
  4. my_app.py +53 -0
  5. requirements.txt +97 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .ipynb_checkpoints
2
+ flagged
3
+ model_checkpoints
4
+ saved_model
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
- app_file: app.py
9
  pinned: false
10
  ---
11
 
 
5
  colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.1.2
8
+ app_file: my_app.py
9
  pinned: false
10
  ---
11
 
classes.txt ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apple_pie
2
+ baby_back_ribs
3
+ baklava
4
+ beef_carpaccio
5
+ beef_tartare
6
+ beet_salad
7
+ beignets
8
+ bibimbap
9
+ bread_pudding
10
+ breakfast_burrito
11
+ bruschetta
12
+ caesar_salad
13
+ cannoli
14
+ caprese_salad
15
+ carrot_cake
16
+ ceviche
17
+ cheesecake
18
+ cheese_plate
19
+ chicken_curry
20
+ chicken_quesadilla
21
+ chicken_wings
22
+ chocolate_cake
23
+ chocolate_mousse
24
+ churros
25
+ clam_chowder
26
+ club_sandwich
27
+ crab_cakes
28
+ creme_brulee
29
+ croque_madame
30
+ cup_cakes
31
+ deviled_eggs
32
+ donuts
33
+ dumplings
34
+ edamame
35
+ eggs_benedict
36
+ escargots
37
+ falafel
38
+ filet_mignon
39
+ fish_and_chips
40
+ foie_gras
41
+ french_fries
42
+ french_onion_soup
43
+ french_toast
44
+ fried_calamari
45
+ fried_rice
46
+ frozen_yogurt
47
+ garlic_bread
48
+ gnocchi
49
+ greek_salad
50
+ grilled_cheese_sandwich
51
+ grilled_salmon
52
+ guacamole
53
+ gyoza
54
+ hamburger
55
+ hot_and_sour_soup
56
+ hot_dog
57
+ huevos_rancheros
58
+ hummus
59
+ ice_cream
60
+ lasagna
61
+ lobster_bisque
62
+ lobster_roll_sandwich
63
+ macaroni_and_cheese
64
+ macarons
65
+ miso_soup
66
+ mussels
67
+ nachos
68
+ omelette
69
+ onion_rings
70
+ oysters
71
+ pad_thai
72
+ paella
73
+ pancakes
74
+ panna_cotta
75
+ peking_duck
76
+ pho
77
+ pizza
78
+ pork_chop
79
+ poutine
80
+ prime_rib
81
+ pulled_pork_sandwich
82
+ ramen
83
+ ravioli
84
+ red_velvet_cake
85
+ risotto
86
+ samosa
87
+ sashimi
88
+ scallops
89
+ seaweed_salad
90
+ shrimp_and_grits
91
+ spaghetti_bolognese
92
+ spaghetti_carbonara
93
+ spring_rolls
94
+ steak
95
+ strawberry_shortcake
96
+ sushi
97
+ tacos
98
+ takoyaki
99
+ tiramisu
100
+ tuna_tartare
101
+ waffles
my_app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+
6
+ # Load your trained models
7
+ model1 = tf.keras.models.load_model('model/FoodVisionFineTuneAug/')
8
+ model2 = tf.keras.models.load_model('model/FoodVisionFineTune/')
9
+
10
+ with open('classes.txt', 'r') as f:
11
+ classes = [line.strip() for line in f]
12
+
13
+ # Add information about the models
14
+ model1_info = """
15
+ ### Model 1 Information
16
+
17
+ This model is based on the EfficientNetB0 architecture and was trained on the Food101 dataset.
18
+ """
19
+
20
+ model2_info = """
21
+ ### Model 2 Information
22
+
23
+ This model is based on the EfficientNetB0 architecture and was trained on augmented data, providing improved generalization.
24
+ """
25
+
26
+ def preprocess(image: Image.Image):
27
+ # Convert numpy array to PIL Image
28
+ image = Image.fromarray((image * 255).astype(np.uint8))
29
+ image = image.resize((224, 224)) # replace with the input size of your models
30
+ image = np.array(image)
31
+ # image = image / 255.0 # normalize if you've done so while training
32
+ image = np.expand_dims(image, axis=0)
33
+ return image
34
+
35
+ def predict(model_selection, image: Image.Image):
36
+ # Choose the model based on the dropdown selection
37
+ model = model1 if model_selection == "EfficentNetB0 Fine Tune" else model2
38
+
39
+ image = preprocess(image)
40
+ prediction = model.predict(image)
41
+ predicted_class = classes[np.argmax(prediction)]
42
+ confidence = np.max(prediction)
43
+ return predicted_class, confidence
44
+
45
+ iface = gr.Interface(
46
+ fn=predict,
47
+ inputs=[gr.Dropdown(["EfficentNetB0 Fine Tune", "EfficentNetB0 Fine Tune Augmented"]), gr.Image()],
48
+ outputs=[gr.Textbox(label="Predicted Class"), gr.Textbox(label="Confidence")],
49
+ title="Transfer Learning Mini Project",
50
+ description=f"{model1_info}\n\n{model2_info}",
51
+ )
52
+
53
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==2.0.0
2
+ aiofiles==23.2.1
3
+ altair==5.1.2
4
+ annotated-types==0.6.0
5
+ anyio==3.7.1
6
+ astunparse==1.6.3
7
+ attrs==23.1.0
8
+ cachetools==5.3.2
9
+ certifi==2023.7.22
10
+ charset-normalizer==3.3.2
11
+ click==8.1.7
12
+ colorama==0.4.6
13
+ contourpy==1.2.0
14
+ cycler==0.12.1
15
+ exceptiongroup==1.1.3
16
+ fastapi==0.104.1
17
+ ffmpy==0.3.1
18
+ filelock==3.13.1
19
+ flatbuffers==23.5.26
20
+ fonttools==4.44.0
21
+ fsspec==2023.10.0
22
+ gast==0.5.4
23
+ google-auth==2.23.4
24
+ google-auth-oauthlib==1.0.0
25
+ google-pasta==0.2.0
26
+ gradio==4.1.2
27
+ gradio_client==0.7.0
28
+ grpcio==1.59.2
29
+ h11==0.14.0
30
+ h5py==3.10.0
31
+ httpcore==1.0.1
32
+ httpx==0.25.1
33
+ huggingface-hub==0.19.0
34
+ idna==3.4
35
+ importlib-metadata==6.8.0
36
+ importlib-resources==6.1.1
37
+ Jinja2==3.1.2
38
+ jsonschema==4.19.2
39
+ jsonschema-specifications==2023.7.1
40
+ keras==2.14.0
41
+ kiwisolver==1.4.5
42
+ libclang==16.0.6
43
+ Markdown==3.5.1
44
+ markdown-it-py==3.0.0
45
+ MarkupSafe==2.1.3
46
+ matplotlib==3.8.1
47
+ mdurl==0.1.2
48
+ ml-dtypes==0.2.0
49
+ numpy==1.26.1
50
+ oauthlib==3.2.2
51
+ opt-einsum==3.3.0
52
+ orjson==3.9.10
53
+ packaging==23.2
54
+ pandas==2.1.2
55
+ Pillow==10.1.0
56
+ protobuf==4.25.0
57
+ pyasn1==0.5.0
58
+ pyasn1-modules==0.3.0
59
+ pydantic==2.4.2
60
+ pydantic_core==2.10.1
61
+ pydub==0.25.1
62
+ Pygments==2.16.1
63
+ pyparsing==3.1.1
64
+ python-dateutil==2.8.2
65
+ python-multipart==0.0.6
66
+ pytz==2023.3.post1
67
+ PyYAML==6.0.1
68
+ referencing==0.30.2
69
+ requests==2.31.0
70
+ requests-oauthlib==1.3.1
71
+ rich==13.6.0
72
+ rpds-py==0.12.0
73
+ rsa==4.9
74
+ semantic-version==2.10.0
75
+ shellingham==1.5.4
76
+ six==1.16.0
77
+ sniffio==1.3.0
78
+ starlette==0.27.0
79
+ tensorboard==2.14.1
80
+ tensorboard-data-server==0.7.2
81
+ tensorflow==2.14.0
82
+ tensorflow-estimator==2.14.0
83
+ tensorflow-intel==2.14.0
84
+ tensorflow-io-gcs-filesystem==0.31.0
85
+ termcolor==2.3.0
86
+ tomlkit==0.12.0
87
+ toolz==0.12.0
88
+ tqdm==4.66.1
89
+ typer==0.9.0
90
+ typing_extensions==4.8.0
91
+ tzdata==2023.3
92
+ urllib3==2.0.7
93
+ uvicorn==0.24.0.post1
94
+ websockets==11.0.3
95
+ Werkzeug==3.0.1
96
+ wrapt==1.14.1
97
+ zipp==3.17.0