resep-kue-lokal / app.py
RickyIG's picture
Update app.py
f2ef965
raw
history blame
No virus
3.99 kB
import gradio as gr
import json
import requests
import torch
import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchvision.models import mobilenet_v2
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from jcopdl.callback import Callback, set_config
import pandas as pd
import numpy as np
# Pick the first CUDA GPU when one is available, otherwise the CPU.
# NOTE(review): nothing else in this file reads `device`; the model is loaded
# with map_location CPU and classify_image moves its input to 'cpu'.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression: has no effect when the file runs as a script
# (looks like a notebook leftover — confirm before removing).
device
import openai
import os
import time
from pathlib import Path
from PIL import Image
import io
# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)
class CustomMobileNetv2(nn.Module):
    """MobileNetV2 backbone with a replaced classification head.

    The torchvision ``mobilenet_v2`` network is created with pretrained
    weights, every one of its parameters is frozen, and its ``classifier``
    is then replaced by one linear layer followed by log-softmax. The new
    head is created after the freeze, so its parameters keep PyTorch's
    default ``requires_grad=True``.

    Args:
        output_size: Number of output units (classes) of the new head.
    """

    def __init__(self, output_size):
        super().__init__()
        self.mnet = mobilenet_v2(pretrained=True)
        # Freeze first so that only the head assigned below stays trainable.
        self.freeze()
        self.mnet.classifier = nn.Sequential(
            nn.Linear(1280, output_size),
            # The head receives (batch, 1280) features, so the class axis is
            # dim 1. Passing it explicitly avoids the deprecated implicit
            # dimension choice (and its warning) without changing the result.
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Run ``x`` through the wrapped network and return log-probabilities."""
        return self.mnet(x)

    def freeze(self):
        """Disable gradients for every parameter currently in ``self.mnet``."""
        for param in self.mnet.parameters():
            param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every parameter currently in ``self.mnet``."""
        for param in self.mnet.parameters():
            param.requires_grad = True
# Load the saved model from disk with all tensors mapped to the CPU.
# NOTE(review): the filename suggests a whole pickled model object rather than
# a state_dict; if so, unpickling presumably needs CustomMobileNetv2 to be
# importable from this module — confirm.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Newer PyTorch releases may default to `weights_only=True`,
# which rejects full-model pickles — verify against the pinned torch version.
kue_lokal_model = torch.load('rickyig_mobilenetv2_kue_lokal_classifier_entire_model.pth', map_location=torch.device('cpu'))
# Maps the classifier's output index to the cake name; classify_image uses it
# to turn the predicted class index into a label. The position of each name in
# the list below is its class index.
dict_for_inference = dict(enumerate([
    'kue dadar gulung',
    'kue kastengel',
    'kue klepon',
    'kue lapis',
    'kue lumpur',
    'kue putri salju',
    'kue risoles',
    'kue serabi',
]))
def get_completion(prompt, model="gpt-3.5-turbo"):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Text used as the ``content`` of the one user message.
        model: Model name forwarded to ``openai.ChatCompletion.create``.

    Returns:
        The ``content`` field of the first choice's message.
    """
    chat_messages = [{"role": "user", "content": prompt}]
    completion = openai.ChatCompletion.create(
        model=model,
        messages=chat_messages,
        temperature=0,
    )
    first_choice = completion.choices[0]
    return first_choice.message["content"]
def get_response(classify_result):
    """Ask the chat model what ``classify_result`` is and for its recipe.

    Args:
        classify_result: Label inserted twice into the Indonesian prompt.

    Returns:
        Whatever ``get_completion`` returns for the built prompt.
    """
    question = f"Apa itu {classify_result} dan sebutkan resep dari {classify_result}."
    return get_completion(question)
def classify_image(input_image):
    """Classify a cake photo and fetch a description and recipe for the result.

    Args:
        input_image: Image location handed to ``PIL.Image.open`` (the Gradio
            input component in this file is configured with
            ``type="filepath"``).

    Returns:
        A 3-tuple ``(output_json, output_bentuk_text, output_response)``:
        a JSON string with the keys ``predicted_label`` and ``probability``,
        a human-readable summary of the same result, and the text returned
        by ``get_response`` for the predicted label.

    Raises:
        ValueError: If ``input_image`` is ``None`` (nothing was uploaded).
    """
    if input_image is None:
        # Fail with a clear message instead of an opaque error inside PIL.
        raise ValueError("No image was provided.")

    kue_lokal_model.eval()

    # Close the file handle deterministically and force three channels:
    # grayscale, palette or RGBA uploads would otherwise produce a tensor
    # whose channel count the 3-channel network cannot accept.
    with Image.open(input_image) as opened_image:
        img = opened_image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    # unsqueeze(0) adds the batch dimension expected by the model.
    input_data = transform(img).unsqueeze(0).to(device='cpu')

    with torch.no_grad():
        output = kue_lokal_model(input_data)
        # NOTE(review): if the loaded model ends in LogSoftmax, as
        # CustomMobileNetv2 does, this softmax recovers the probabilities.
        probs = torch.nn.functional.softmax(output, dim=1)
        conf, predicted_class = torch.max(probs, 1)

    # Resolve the label and confidence once and reuse them below.
    predicted_label = dict_for_inference[predicted_class.item()]
    probability = conf.item()

    output_dict = {"predicted_label": predicted_label, "probability": probability}
    output_json = json.dumps(output_dict)
    output_bentuk_text = "Hasil Klasifikasi Gambar \nKue : {} \nProbability: {:.2f}%".format(predicted_label, probability * 100)
    output_response = get_response(predicted_label)
    return output_json, output_bentuk_text, output_response
# Build the Gradio UI: one image input and three outputs.
# type="filepath" is the input form that classify_image passes to Image.open.
input_image = gr.Image(label="input_image", type="filepath")
output_json = gr.JSON(label="Output (JSON)")
output_bentuk_text = gr.Textbox(label="Hasil Output")
output_response = gr.Textbox(label="Resep Kue")
# Image file offered as a clickable example in the UI.
example_input_image = "3.jpg"
interface = gr.Interface(
    fn=classify_image,
    inputs=input_image,
    outputs=[output_json, output_bentuk_text, output_response],  # same order as classify_image's return tuple: JSON, summary text, recipe text
    title="Resep Kue Lokal",
    examples=[
        [example_input_image]
    ],
    description="Unggah foto kue lokal dan dapatkan hasil klasifikasi gambar beserta resep kue.<br>Kue yang tersedia: kue dadar gulung, kue kastengel, kue klepon, kue lapis, kue lumpur, kue putri salju, kue risoles, kue serabi.",
)
# Start the Gradio app at import time (there is no __main__ guard).
# NOTE(review): share=True requests a public share link and debug=True enables
# Gradio's debug mode — check the launch() docs for how the hosting
# environment treats these flags.
interface.launch(share=True, debug=True)