diamantrsd's picture
Update app.py
db0625e verified
raw
history blame
No virus
2.67 kB
import gradio as gr
import tensorflow as tf
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
from tensorflow.image import resize
# Keras classifier that scores an image against the product categories.
_IMAGE_MODEL_PATH = "mobilenetfashion_v2.h5"
image_classification_model = tf.keras.models.load_model(_IMAGE_MODEL_PATH)

# GPT-2 checkpoint used for copywriting; the tokenizer and the language
# model are both fetched under the same model name.
gpt2_model_name = "diamantrsd/copywriting-otomatis"
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model_name)
gpt2_model = GPT2LMHeadModel.from_pretrained(gpt2_model_name)
def classify_and_generate_text(image, keywords=""):
    """Classify a product image and generate copywriting text for it.

    Args:
        image: The image handed over by the Gradio image input. ``.astype``
            is called on it, so an array-like with that method is expected.
            ``None`` (nothing uploaded yet) is tolerated.
        keywords: Free-text keywords inserted into the generation prompt.
            ``None`` is treated as an empty string.

    Returns:
        The generated text on success, a short hint when no image was
        supplied, or a string of the form ``"Error: <message>"`` when any
        step raises.
    """
    # The callback can run before an image has been provided; answer with a
    # readable hint instead of surfacing an AttributeError on None.
    if image is None:
        return "Please upload an image."
    if keywords is None:
        keywords = ""

    try:
        # Scale pixel values by 1/255 as float32.
        img_array = image.astype('float32') / 255.0

        # Resize to the 224x224 input size used for classification.
        img_array_resized = resize(img_array, (224, 224))

        # Add a leading batch axis and run the classifier.
        class_label = image_classification_model.predict(
            np.expand_dims(img_array_resized, axis=0)
        )

        # Translate the prediction into a category name.
        category = map_class_label_to_category(class_label)

        # Produce the copy from the category and the user's keywords.
        return generate_text_with_gpt2(category, keywords)
    except Exception as e:
        # UI boundary: report the failure as text rather than crashing the app.
        return f"Error: {str(e)}"
def map_class_label_to_category(class_label):
    """Return the category name of the top-scoring class for the first row.

    The arg-max of ``class_label`` is taken along its last axis, and the
    result for the first entry is used as an index into the label table.
    """
    # Label table; the position of each name is the class index it stands for.
    category_names = (
        'Backpack',
        'Celana Panjang',
        'Celana Pendek',
        'Dompet',
        'Dress',
        'Kacamata',
        'Kaos',
        'Kaos Kaki',
        'Kemeja',
        'Outerwear',
        'Sandal',
        'Sepatu',
        'Sepatu Flat',
        'Tas',
        'Topi',
    )
    best_per_row = np.argmax(class_label, axis=-1)
    first_row_best = best_per_row[0]
    return category_names[first_row_best]
def generate_text_with_gpt2(product_category, keywords):
    """Generate copywriting text for a product category and keywords.

    Args:
        product_category: Category name placed in the prompt.
        keywords: Keyword text placed in the prompt.

    Returns:
        The decoded first generated sequence with special tokens removed.
        The sequence returned by ``generate`` starts with the prompt tokens,
        so the prompt text is part of the returned string.
    """
    prompt = f"Produk: {product_category}, Keywords: {keywords}, Copywriting:"

    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask, which is passed explicitly because the pad token id
    # is set to the EOS token id below.
    encoded = gpt2_tokenizer(prompt, return_tensors="pt")

    # Adjust parameters as needed.
    # NOTE: ``max_length`` counts the prompt tokens as well, so long keyword
    # strings leave less room for newly generated text.
    max_length = 50
    no_repeat_ngram_size = 3
    top_k = 50
    top_p = 0.95

    output = gpt2_model.generate(
        encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        # top_k / top_p only take effect when sampling is enabled; without
        # this flag decoding is greedy and both settings are ignored.
        do_sample=True,
        no_repeat_ngram_size=no_repeat_ngram_size,
        top_k=top_k,
        top_p=top_p,
        pad_token_id=gpt2_tokenizer.eos_token_id,
        num_return_sequences=1,
    )

    return gpt2_tokenizer.decode(output[0], skip_special_tokens=True)
# Assemble the Gradio UI: an RGB image and a keywords textbox go in,
# plain text comes out, and the interface is configured as live.
image_input = gr.Image(image_mode="RGB")
keywords_input = gr.Textbox(placeholder="", label="Keywords")

iface = gr.Interface(
    fn=classify_and_generate_text,
    inputs=[image_input, keywords_input],
    outputs="text",
    live=True,
)

# Start the app.
iface.launch()