Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model # Thêm dòng này để import load_model | |
from PIL import Image | |
import numpy as np | |
# Tải mô hình YOLOv8 (YOLOv8) | |
yolo_model = torch.hub.load("ultralytics/yolov5", "yolov5s") # Lưu ý bạn có thể thay đổi mô hình nếu có | |
# Tải mô hình CNN | |
cnn_model = load_model('cnn_food_classifier.h5') # Bây giờ load_model sẽ hoạt động | |
# Lớp phân loại món ăn cho CNN | |
food_classes = ['Ca hu kho', 'Canh cai', 'Canh chua', 'Com trang', 'Dau hu sot ca', 'Ga chien', 'Rau muong xao', 'Thit kho', 'Thit kho trung', 'Trung chien'] | |
# Hàm để nhận diện món ăn và dự đoán giá tiền | |
def predict_food(image): | |
# Dự đoán với YOLOv8 | |
results = yolo_model(image) | |
detected_classes = results.names | |
predicted_classes = [detected_classes[class_id] for class_id in results.pred[0][:, -1].cpu().numpy().astype(int)] | |
# Dự đoán với CNN | |
image = np.array(image.resize((224, 224))) # Resize ảnh cho phù hợp với input của CNN | |
image = np.expand_dims(image, axis=0) # Thêm chiều batch | |
image = image / 255.0 # Chuẩn hóa ảnh | |
cnn_predictions = cnn_model.predict(image) # Dự đoán với mô hình CNN | |
# Xử lý kết quả dự đoán | |
food_prediction = food_classes[np.argmax(cnn_predictions)] | |
return {"food_class": food_prediction, "detected_classes": predicted_classes} | |