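# Hugging Face page metadata captured with this file (not part of the program):
# author: evelyncsb
# commit message: add features
# commit: 624ee8e
# views: raw | history | blame
# scan status: No virus
# file size: 1.46 kB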
import gradio as gr
import os
import skimage
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
from imagebind import data
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
import torch.nn as nn
import pickle
# Inference is pinned to CPU; the commented expression would select a GPU
# when one is available.
device = "cpu"  # "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the pretrained ImageBind "huge" model once at import time, switch it
# to eval mode, and move it to the chosen device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)


def _load_pickle(path):
    """Unpickle and return the object stored at *path*, closing the file afterwards.

    NOTE: unpickling can execute arbitrary code; only use this on the
    trusted asset files bundled with the app.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


# Precomputed image embeddings and the list of image paths they were built
# from. Row i of image_features is assumed to correspond to image_paths[i]
# (generate_image relies on this) — TODO confirm against the script that
# produced these pickles.
image_features = _load_pickle("./assets/image_features_norm_2.pkl")
image_paths = _load_pickle("./assets/image_paths.pkl")
def generate_image(text):
    """Return the indexed image whose embedding best matches *text*.

    Nothing is actually generated. The text is embedded with the ImageBind
    model, scored against the precomputed ``image_features`` by dot product,
    and the highest-scoring image is loaded from ``./assets/images``.

    Args:
        text: Query string from the Gradio text box.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.
    """
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text([text], device),
    }
    with torch.no_grad():
        embeddings = model(inputs)

    text_features = embeddings[ModalityType.TEXT]
    # L2-normalise along the last dimension. The dot product below is then a
    # cosine similarity, provided image_features was normalised the same way
    # (the pickle name "image_features_norm" suggests it was — TODO confirm).
    text_features /= text_features.norm(dim=-1, keepdim=True)

    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

    # Take the index of the highest score. There is a single query row, so
    # the flattened argmax equals the column (image) index.
    index_img = int(np.argmax(similarity))
    img_name = os.path.basename(image_paths[index_img])
    img_path = os.path.join("./assets/images", img_name)

    # Close the file handle opened by Image.open. convert() loads the pixel
    # data and returns a new image, so the result stays valid after the
    # file is closed.
    with Image.open(img_path) as img:
        return img.convert("RGB")
# Gradio UI: one text box in, one image out, backed by generate_image.
_interface_settings = dict(
    fn=generate_image,
    inputs="text",
    outputs="image",
    title="Texto para Imagem",
    description="Digite um texto e obtenha uma imagem com o texto.",
)
iface = gr.Interface(**_interface_settings)

# Start the Gradio web server.
iface.launch()