hassaanik Jeff28 commited on
Commit
92ce65f
·
verified ·
1 Parent(s): f926a86

Update app.py (#2)

Browse files

- Update app.py (21922409acafb8da66d91c517e8c9ca61cb38b99)


Co-authored-by: Geofrey Kamau <Jeff28@users.noreply.huggingface.co>

Files changed (1) hide show
  1. app.py +51 -51
app.py CHANGED
@@ -1,51 +1,51 @@
1
- from flask import Flask, render_template, request
2
- from PIL import Image
3
- from io import BytesIO
4
- import base64
5
- from predict import predict_potato, predict_tomato
6
- from model import potato_model, tomato_model
7
- import torch
8
-
9
- app = Flask(__name__)
10
-
11
- # Load models
12
- potato_model.load_state_dict(torch.load("models\\potato_model_statedict__f.pth", map_location=torch.device('cpu')))
13
- tomato_model.load_state_dict(torch.load("models\\tomato_model_statedict__f.pth", map_location=torch.device('cpu')))
14
-
15
- # potato_model = torch.load("Models\\potato_model_statedict__f.pth", map_location=torch.device('cpu'))
16
- # potato_model.load_state_dict(torch.load("Models\\potato_model_statedict__f.pth", map_location=torch.device('cpu')))
17
- # tomato_model = torch.load("Models\\tomato_model_statedict__f.pth", map_location=torch.device('cpu'))
18
- # potato_model.load_state_dict(torch.load("Models\\tomato_model_statedict__f.pth", map_location=torch.device('cpu')))
19
-
20
-
21
- @app.route('/')
22
- def home():
23
- # Default to potato model
24
- return render_template('index.html', model_type='potato')
25
-
26
- @app.route('/predict', methods=['POST'])
27
- def predict():
28
- # Get the selected model type
29
- model_type = request.form['model_type']
30
-
31
- # Get the image file from the request
32
- file = request.files['file']
33
-
34
- if model_type == 'tomato':
35
- class_name, probability, image = predict_tomato(file, tomato_model)
36
- background_image = r'static\\tomato_background.jpg'
37
-
38
- else:
39
- class_name, probability, image = predict_potato(file, potato_model)
40
- background_image = r'static\\potato_background.webp'
41
-
42
- # Convert image to base64 format
43
- buffered = BytesIO()
44
- image.save(buffered, format="JPEG")
45
- img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
46
-
47
- # Pass the base64 encoded image and background image to the frontend
48
- return render_template('index.html', image=img_str, class_name=class_name, probability=f"{probability * 100:.2f}%", background_image=background_image)
49
-
50
- if __name__ == '__main__':
51
- app.run(debug=True)
 
1
+ from flask import Flask, render_template, request
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import base64
5
+ from predict import predict_potato, predict_tomato
6
+ from model import potato_model, tomato_model
7
+ import torch
8
+
9
+ app = Flask(__name__)
10
+
11
+ # Load models
12
+ potato_model.load_state_dict(torch.load("models/potato_model_statedict__f.pth", map_location=torch.device('cpu')))
13
+ tomato_model.load_state_dict(torch.load("models/tomato_model_statedict__f.pth", map_location=torch.device('cpu')))
14
+
15
+ # potato_model = torch.load("Models\\potato_model_statedict__f.pth", map_location=torch.device('cpu'))
16
+ # potato_model.load_state_dict(torch.load("Models\\potato_model_statedict__f.pth", map_location=torch.device('cpu')))
17
+ # tomato_model = torch.load("Models\\tomato_model_statedict__f.pth", map_location=torch.device('cpu'))
18
+ # potato_model.load_state_dict(torch.load("Models\\tomato_model_statedict__f.pth", map_location=torch.device('cpu')))
19
+
20
+
21
+ @app.route('/')
22
+ def home():
23
+ # Default to potato model
24
+ return render_template('index.html', model_type='potato')
25
+
26
+ @app.route('/predict', methods=['POST'])
27
+ def predict():
28
+ # Get the selected model type
29
+ model_type = request.form['model_type']
30
+
31
+ # Get the image file from the request
32
+ file = request.files['file']
33
+
34
+ if model_type == 'tomato':
35
+ class_name, probability, image = predict_tomato(file, tomato_model)
36
+ background_image = r'static/tomato_background.jpg'
37
+
38
+ else:
39
+ class_name, probability, image = predict_potato(file, potato_model)
40
+ background_image = r'static/potato_background.webp'
41
+
42
+ # Convert image to base64 format
43
+ buffered = BytesIO()
44
+ image.save(buffered, format="JPEG")
45
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
46
+
47
+ # Pass the base64 encoded image and background image to the frontend
48
+ return render_template('index.html', image=img_str, class_name=class_name, probability=f"{probability * 100:.2f}%", background_image=background_image)
49
+
50
+ if __name__ == '__main__':
51
+ app.run(debug=True)