"""Gradio demo: find poem beyts that match a free-text query.

When run as a script, this module loads a pretrained ``PoemTextModel``,
embeds every record of the JSON dataset at ``CFG.dataset_path`` via
``get_poem_embeddings``, and serves a Gradio text-in / text-out interface
backed by ``predict_poems_from_text``.
"""
import json

import gradio as gr

import config as CFG
from inference import predict_poems_from_text
from models import PoemTextModel
from utils import get_poem_embeddings


def greet_user(name):
    """Return a fixed greeting string for *name*.

    Not wired into the app below; kept as a minimal Gradio handler example.
    """
    return "Hello " + name + " Welcome to Gradio!😎"


def main():
    """Build the model and poem embeddings, then launch the Gradio app."""
    model = PoemTextModel(
        poem_encoder_pretrained=True, text_encoder_pretrained=True
    ).to(CFG.device)
    model.eval()

    # Each dataset record is indexed with 'beyt' below, so it is assumed to
    # be a list of dicts carrying that key -- TODO confirm schema.
    with open(CFG.dataset_path, encoding="utf-8") as f:
        dataset = json.load(f)

    # Set the batch size before embedding; presumably consumed by
    # get_poem_embeddings -- verify in utils.
    CFG.batch_size = 512
    model, poem_embeddings = get_poem_embeddings(dataset, model)

    # Candidate texts, built once rather than on every request. Assumed to
    # line up index-for-index with poem_embeddings, and not to be mutated by
    # predict_poems_from_text -- TODO confirm.
    beyts = [data['beyt'] for data in dataset]

    def gradio_make_predictions(text):
        """Return the n=10 predictions for *text*, joined one per line."""
        predictions = predict_poems_from_text(
            model, poem_embeddings, text, beyts, n=10
        )
        return "\n".join(predictions)

    text_input = gr.Textbox(label="Enter the text to find poem beyts for")
    output = gr.Textbox()
    app = gr.Interface(
        fn=gradio_make_predictions, inputs=text_input, outputs=output
    )
    app.launch()


if __name__ == "__main__":
    main()