# memero / app.py
# andreinigo -- "Update app.py" (commit 437e10c), 2.72 kB
# (File-viewer navigation text "raw" / "history blame" preceded the source.)
import os
import textwrap
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
import numpy as np
import torch
from lavis.models import load_model_and_preprocess
import openai
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

# Load the BLIP-2 (OPT-2.7b) checkpoint in eval mode together with its image
# preprocessors. The third return value is not used by this app.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt",
    model_type="pretrain_opt2.7b",
    is_eval=True,
    device=device,
)

# Raises KeyError at start-up when the environment variable is not set.
openai.api_key = os.environ["OPENAI_API_KEY"]
def _measure_text(draw, text, font):
    """Return ``(width, height)`` of *text* as rendered with *font*.

    ``ImageDraw.textsize`` no longer exists in recent Pillow releases, so
    ``textbbox`` is used whenever the installed Pillow provides it and
    ``textsize`` is kept only as a fallback for older installations.
    """
    if hasattr(draw, "textbbox"):
        # Use the right/bottom edges measured from the drawing origin so the
        # result lines up with what ``textsize`` used to report.
        _, _, right, bottom = draw.textbbox((0, 0), text, font=font)
        return right, bottom
    return draw.textsize(text, font=font)


def _overlay_meme_text(pil_image, meme_text):
    """Draw *meme_text* in white, horizontally centred, from the top of *pil_image*.

    The image is modified in place. Nothing is drawn when *meme_text* is
    empty. Long text wraps onto as many lines as it needs.
    """
    if not meme_text:
        return

    min_font_size = 12
    draw = ImageDraw.Draw(pil_image)
    max_width = int(pil_image.width * 0.9)

    # Heuristic: usable width divided by half the character count. Clamp the
    # result so very long text stays legible (and never reaches size 0) and
    # very short text does not produce a font larger than the image.
    max_font_size = max(min_font_size, pil_image.height // 5)
    font_size = int(max_width / (len(meme_text) / 2))
    font_size = max(min_font_size, min(font_size, max_font_size))
    font = ImageFont.truetype("impact.ttf", font_size)

    # Estimate how many characters fit on one line from the width of "A";
    # textwrap rejects a width of 0, so never go below one character.
    char_width = max(1, _measure_text(draw, "A", font)[0])
    chars_per_line = max(1, int(max_width / char_width))

    y = 10
    for line in textwrap.wrap(meme_text, width=chars_per_line):
        line_width, line_height = _measure_text(draw, line, font)
        x = (pil_image.width - line_width) // 2
        draw.text((x, y), line, fill=(255, 255, 255), font=font)
        y += line_height


def generate_caption(image):
    """Turn an input picture into a meme.

    The picture is captioned with the module-level BLIP-2 ``model``, the
    caption is sent to GPT-4 with a Spanish system prompt asking for a short
    meme, and the returned text is drawn onto a copy of the picture.

    Args:
        image: PIL image to caption, or ``None`` when no image was supplied.

    Returns:
        A new RGB PIL image with the meme text drawn on it, or ``None`` when
        *image* is ``None``. The input image is not modified.
    """
    if image is None:
        return None

    # Work on an RGB copy: the text fill is an RGB tuple, and ``convert``
    # returns a new image so the caller's object is left untouched.
    pil_image = image.convert("RGB")

    image_tensor = vis_processors["eval"](image).unsqueeze(0).to(device)
    caption = "\n".join(model.generate({"image": image_tensor}))

    # Use GPT-4 to generate a meme based on the caption.
    # NOTE(review): openai.ChatCompletion looks like the pre-1.0 SDK
    # interface -- confirm the pinned openai version.
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "Escribe un meme chistoso en español para una imagen a partir en la descripción dada por el usuario. No uses emojis. El meme tiene que ser corto y gracioso. El output del asistente solo debe ser el meme. Asegúrate que el meme sea tan bueno que se vuelva viral!"},
            {"role": "user", "content": caption},
        ],
        temperature=0.6,
    )
    # Guard against a missing/blank completion so the layout code never
    # divides by a zero-length text.
    meme_text = (response.choices[0].message.content or "").strip()
    print(meme_text)

    _overlay_meme_text(pil_image, meme_text)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return pil_image
# Single-page UI: upload an image, press the button, get the meme image back.
with gr.Blocks() as demo:
    gr.Markdown("### Memero - Generador de Memes")
    gr.Markdown("Genera un meme en español a partir de una imagen.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Imagen", type="pil")
            btn_caption = gr.Button("Generar meme")
            # The output is an image, so the Textbox-only ``lines`` argument
            # that used to be passed here has been dropped.
            output_image = gr.Image(label="Meme")

    btn_caption.click(
        generate_caption, inputs=[input_image], outputs=[output_image]
    )

demo.launch()