# Author: diamantrsd
# Last commit: "Update app.py" (5547ccf)
import gradio as gr
import tensorflow as tf
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from keras.preprocessing import image
from tensorflow.image import resize
from keras.layers import BatchNormalization
# Load the image-classification model. custom_objects maps the name
# 'BatchNormalization' stored in the .h5 file to the BatchNormalization layer
# class imported above, so deserialization can resolve it.
custom_objects = {'BatchNormalization': BatchNormalization}
try:
    model = tf.keras.models.load_model('model1025_ver1.h5', custom_objects=custom_objects)
except Exception as e:
    # Without the classifier, every call to predict_image would fail later
    # with a NameError on `model`. Report the cause and stop here instead of
    # launching an interface that cannot work.
    print(f"An error occurred: {e}")
    raise
# GPT-2 copywriting model and its matching tokenizer. Both are loaded from the
# same pretrained checkpoint id, which is kept in one place so they stay in sync.
_COPYWRITING_REPO = "diamantrsd/copywriting-otomatis"
tokenizer = GPT2Tokenizer.from_pretrained(_COPYWRITING_REPO)
gpt2_model = GPT2LMHeadModel.from_pretrained(_COPYWRITING_REPO)
# Maps the classifier's output index (argmax position) to a fashion category
# label. Built once at import time instead of on every request.
CLASS_LABELS = {
    0: 'Backpack',
    1: 'Celana Panjang',
    2: 'Celana Pendek',
    3: 'Dompet',
    4: 'Dress',
    5: 'kacamata',
    6: 'Kaos',
    7: 'Kaos Kaki',
    8: 'Kemeja',
    9: 'Outerwear',
    10: 'Sandal',
    11: 'Sepatu',
    12: 'Sepatu Flat',
    13: 'Tas',
    14: 'Topi',
}


def predict_image(img_path, keyword=""):
    """Classify a fashion image and generate copywriting for its category.

    Args:
        img_path: Path to the uploaded image file, or None when nothing was
            uploaded.
        keyword: Optional extra keyword forwarded to generate_copywriting.
            None is treated like an empty string.

    Returns:
        A ``(class_label, copywriting)`` tuple, one value per output
        component. When no image is given, the first element is an error
        message and the second is an empty string.
    """
    if img_path is None:
        # Always return two values. The interface declares two outputs, and
        # a bare string would not map onto them.
        return "Error: Belum upload gambar. Silakan upload gambar", ""

    # Load and resize to the 224x224 input size, add a batch axis, then divide
    # by 255 (presumably scaling 8-bit pixels into [0, 1], as during training).
    img = image.load_img(img_path, target_size=(224, 224))
    img_array = np.expand_dims(image.img_to_array(img), axis=0) / 255.0

    predictions = model.predict(img_array)
    # predictions[0] holds the scores for the single image in the batch.
    predicted_class = int(np.argmax(predictions[0]))
    class_label = CLASS_LABELS[predicted_class]

    copywriting = generate_copywriting(class_label, keyword or "")
    return class_label, copywriting
# Function to generate copywriting
def generate_copywriting(class_label, keyword):
    """Generate copywriting text for a product category with the GPT-2 model.

    Args:
        class_label: Predicted category label; it is lower-cased in the prompt.
        keyword: Optional extra keyword (None or "" means none). When present,
            it is stripped, lower-cased and prepended to the prompt as
            ``"<keyword>, <class_label>"``.

    Returns:
        The decoded generation, stripped of surrounding whitespace. The echoed
        keyword and its ", " separator are removed.
    """
    keyword = (keyword or "").strip().lower()
    label = class_label.lower()
    prompt = f"{keyword}, {label}" if keyword else label

    # Calling the tokenizer (rather than .encode) also returns the attention
    # mask, so generate() does not have to infer it.
    inputs = tokenizer(prompt, return_tensors="pt")

    output = gpt2_model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_length=100,  # total length in tokens, prompt included
        no_repeat_ngram_size=3,
        # top_k / top_p only take effect in sampling mode. Without
        # do_sample=True, generate() decodes greedily and ignores them.
        do_sample=True,
        top_k=35,
        top_p=0.50,
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=1,
    )
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # The model echoes the prompt. Remove the keyword in the lower-cased form
    # actually used in the prompt (matching the raw user input failed whenever
    # it contained capitals), plus the ", " separator that follows it.
    if keyword:
        generated_text = generated_text.replace(keyword, "", 1).lstrip(", ")
    return generated_text.strip()
# Input components: the uploaded image (passed to predict_image as a file
# path) and an optional free-text keyword.
image_input = gr.Image(label="Upload Gambar Fashion Kamu", type="filepath")
keyword_input = gr.Textbox(label="Keyword Tambahan (Opsional)")
_inputs = [image_input, keyword_input]

# Output components, in the same order as predict_image's two return values.
output_label = gr.Textbox(label="Kelas Prediksi")
output_copywriting = gr.Textbox(label="Copywriting")
_outputs = [output_label, output_copywriting]

# Wire the prediction function into a Gradio Interface, then launch it.
iface = gr.Interface(
    fn=predict_image,
    inputs=_inputs,
    outputs=_outputs,
    title="Copywriting Otomatis",
)
iface.launch()