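"""Gradio demo: retrieve the poems (beyts) most similar to an uploaded image,
using a CLIP-style image-poem model and precomputed poem embeddings."""
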
from models import CLIPModel
from inference import predict_poems_from_image
from utils import get_poem_embeddings
import config as CFG
import json
import torch
import gradio as gr

# Leftover hello-world handler; not wired into the interface below.
def greet_user(name):
    return f"Hello {name}, welcome to Gradio!😎"

if __name__ == "__main__":
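    # Build the image-poem CLIP model with pretrained encoders and put it in inference (eval) mode.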
    model = CLIPModel(image_encoder_pretrained=True,
                      text_encoder_pretrained=True,
                      text_projection_trainable=False,
                      is_image_poem_pair=True).to(CFG.device)
    model.eval()
    # Load the precomputed poem embeddings from disk (presumably generated with utils.get_poem_embeddings).
    with open('poem_embeddings.json', encoding="utf-8") as f:
        pe = json.load(f)
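    # Each entry holds one beyt (couplet) and its precomputed embedding vector.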
    
    poem_embeddings = torch.Tensor([p['embeddings'] for p in pe]).to(CFG.device)
    print(poem_embeddings.shape)  # sanity check: should be (num_poems, embedding_dim)
    poems = [p['beyt'] for p in pe]

    def gradio_make_predictions(image):
        # Rank every poem against the uploaded image and return the top 10 beyts, one per line.
        beyts = predict_poems_from_image(model, poem_embeddings, image, poems, n=10)
        return "\n".join(beyts)

    # Batch size presumably consumed inside predict_poems_from_image when embedding the input.
    CFG.batch_size = 512

    image_input = gr.Image(type="filepath")  # hand the uploaded image to the handler as a file path
    output = gr.Textbox()

    app = gr.Interface(fn=gradio_make_predictions, inputs=image_input, outputs=output)
    app.launch()
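    # launch() serves the interface locally (http://127.0.0.1:7860 by default).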