File size: 1,667 Bytes
cae4936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32613f0
 
 
 
cae4936
 
32613f0
 
 
cae4936
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32613f0
cae4936
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr

from prefix_clip import download_pretrained_model, generate_caption
from gpt2_story_gen import generate_story


def main(pil_image, genre, model="Conceptual", use_beam_search=True):
    """Turn an image into a short story of the requested genre.

    The pipeline has three steps: fetch the captioning weights, caption the
    image, then hand the caption and genre to the story generator.

    Args:
        pil_image: Image forwarded unchanged to ``generate_caption``.
        genre: Genre name; lower-cased before being given to
            ``generate_story``.
        model: Name of the pretrained captioning model; lower-cased before
            being given to ``download_pretrained_model``.
        use_beam_search: Forwarded unchanged to ``generate_caption``.

    Returns:
        Whatever ``generate_story`` returns for the caption and genre.
    """
    weights_path = "pretrained_weights.pt"

    # NOTE(review): this runs on every call; whether it skips the download
    # when the file already exists depends on the helper -- confirm.
    download_pretrained_model(model.lower(), file_to_save=weights_path)

    caption = generate_caption(
        model_path=weights_path,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    return generate_story(caption, genre.lower())


if __name__ == "__main__":
    # Page text handed to gr.Interface as title / description / article.
    title = "Image to Story"
    description = "Drop an image and generate stories of different genre based on that image"
    article = (
        "Combines the power of "
        "[clip prefix captioning](https://github.com/rmokady/CLIP_prefix_caption)"
        " with "
        "[gpt2 story generator](https://huggingface.co/pranavpsv/genre-story-generator-v2)"
        " to create stories of different genres from image"
    )

    # Genres offered in the dropdown.
    story_genres = [
        "superhero",
        "action",
        "drama",
        "horror",
        "thriller",
        "sci_fi",
    ]

    # Input/output components, built in the same order as main()'s
    # positional parameters (image first, then genre).
    image_input = gr.inputs.Image(type="pil", source="upload", label="Input")
    genre_input = gr.inputs.Dropdown(
        type="value",
        label="Story genre",
        choices=story_genres,
    )
    story_output = gr.outputs.Textbox(label="Generated story")

    interface = gr.Interface(
        main,
        title=title,
        description=description,
        article=article,
        inputs=[image_input, genre_input],
        outputs=story_output,
        examples=[["dog_image.jpg", "action"]],
        enable_queue=True,
    )
    interface.launch()