# Page header carried over from the hosting site (not part of the program):
#   mojtaba-nafez's picture
#   add image projection weights and configs
#   db11718
from models import CLIPModel
from inference import predict_poems_from_image
from utils import get_poem_embeddings
import config as CFG
import json
import torch
import gradio as gr
def greet_user(name: str) -> str:
    """Return the Gradio welcome message addressed to ``name``.

    Args:
        name: Text inserted between "Hello " and " Welcome to Gradio!😎".

    Returns:
        The greeting string.
    """
    return f"Hello {name} Welcome to Gradio!😎"
if __name__ == "__main__":
    # Build the model with the options this demo uses, move it to the
    # configured device and call eval() on it before serving requests.
    model = CLIPModel(
        image_encoder_pretrained=True,
        text_encoder_pretrained=True,
        text_projection_trainable=False,
        is_image_poem_pair=True,
    ).to(CFG.device)
    model.eval()

    # Load the poem records from disk. Each record is read for two keys:
    # 'embeddings' (stacked into one tensor below) and 'beyt' (the text that
    # is shown to the user).
    with open('poem_embeddings.json', encoding="utf-8") as f:
        pe = json.load(f)

    # torch.tensor with an explicit dtype replaces the legacy torch.Tensor
    # constructor; float32 is the dtype that constructor produced.
    poem_embeddings = torch.tensor(
        [p['embeddings'] for p in pe], dtype=torch.float32
    ).to(CFG.device)
    print(poem_embeddings.shape)
    poems = [p['beyt'] for p in pe]

    def gradio_make_predictions(image):
        """Return the predicted beyts for ``image`` joined by newlines.

        Args:
            image: Value supplied by the Gradio image input, which is
                configured with ``type="filepath"``; ``None`` when the user
                submitted without providing an image.

        Returns:
            The strings returned by ``predict_poems_from_image`` (called with
            ``n=10``) joined with newline characters, or an empty string when
            no image was provided.
        """
        # Gradio passes None when nothing was uploaded; skip inference rather
        # than handing None to the prediction helper.
        if image is None:
            return ""
        # Nothing here calls backward(), so gradient tracking is not needed.
        with torch.no_grad():
            beyts = predict_poems_from_image(
                model, poem_embeddings, image, poems, n=10
            )
        return "\n".join(beyts)

    # Override the configured batch size before the app starts.
    # NOTE(review): presumably read by the inference helpers — confirm.
    CFG.batch_size = 512

    # Wire the callback into a single image-in / text-out interface and serve.
    image_input = gr.Image(type="filepath")
    output = gr.Textbox()
    app = gr.Interface(
        fn=gradio_make_predictions, inputs=image_input, outputs=output
    )
    app.launch()