import gradio as gr

from prefix_clip import download_pretrained_model, generate_caption
from gpt2_story_gen import generate_story

def main(pil_image, genre, model="Conceptual", use_beam_search=False):
    """Caption the image with CLIP prefix captioning, then expand the caption
    into a story of the requested genre with the GPT-2 story generator."""
    model_file = "pretrained_weights.pt"

    # Fetch the selected pretrained captioning weights to a local file.
    download_pretrained_model(model.lower(), file_to_save=model_file)

    # Generate a caption for the uploaded image.
    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )

    # Turn the caption into a genre-conditioned story.
    story = generate_story(image_caption, genre.lower())
    return story
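

# The two helper modules live in this Space's repository. Their interfaces,
# as assumed here purely from the calls above (a sketch, not the actual files):
#
#   prefix_clip.download_pretrained_model(model_name: str, file_to_save: str) -> None
#       downloads the selected CLIP-prefix captioning weights to `file_to_save`
#   prefix_clip.generate_caption(model_path: str, pil_image, use_beam_search: bool) -> str
#       returns a caption string for the given PIL image
#   gpt2_story_gen.generate_story(caption: str, genre: str) -> str
#       returns a story in the given genre, seeded with the caption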
if __name__ == "__main__":
    title = "Image to Story"
    article = (
        "Combines the power of [clip prefix captioning](https://github.com/rmokady/CLIP_prefix_caption) "
        "with the [gpt2 story generator](https://huggingface.co/pranavpsv/genre-story-generator-v2) "
        "to create stories of different genres from an image."
    )
    description = "Drop in an image and generate stories of different genres based on that image."

    # Note: this targets the legacy gr.inputs / gr.outputs component API;
    # newer Gradio releases expose gr.Image, gr.Dropdown and gr.Textbox directly.
    interface = gr.Interface(
        main,
        title=title,
        description=description,
        article=article,
        inputs=[
            gr.inputs.Image(type="pil", source="upload", label="Input"),
            gr.inputs.Dropdown(
                type="value",
                label="Story genre",
                choices=[
                    "superhero",
                    "action",
                    "drama",
                    "horror",
                    "thriller",
                    "sci_fi",
                ],
            ),
        ],
        outputs=gr.outputs.Textbox(label="Generated story"),
        examples=[["image.jpg", "action"]],
        enable_queue=True,
    )
    interface.launch()
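
# A minimal (assumed) way to exercise the pipeline without the web UI, e.g. for
# quick local testing; "image.jpg" refers to the example image bundled above:
#
#   from PIL import Image
#   story = main(Image.open("image.jpg"), genre="action", use_beam_search=True)
#   print(story)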