File size: 2,190 Bytes
cae4936
837545d
cae4936
718c50f
cae4936
 
718c50f
 
8985863
cae4936
70b2a7d
8985863
 
 
 
cae4936
 
 
 
 
 
70b2a7d
cae4936
 
 
 
32613f0
 
 
 
cae4936
 
32613f0
 
 
cae4936
 
 
 
 
 
 
 
 
 
 
 
 
 
70b2a7d
 
cae4936
 
1e66cb4
cae4936
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from huggingface_hub import hf_hub_download

from prefix_clip import generate_caption
from gpt2_story_gen import generate_story

# Fetch the two pretrained CLIP-prefix captioning checkpoints from the Hub at
# import time. The resulting local paths are later handed to
# ``generate_caption`` as ``model_path`` depending on the selected model.
conceptual_weights = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights",
    filename="conceptual_weights.pt",
)
coco_weights = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights",
    filename="coco_weights.pt",
)


def main(pil_image, genre, model, n_stories, use_beam_search=False):
    """Caption an image, then turn the caption into one or more stories.

    Args:
        pil_image: The input image, forwarded to both ``generate_caption`` and
            ``generate_story`` (the UI declares the component with
            ``type="pil"``).
        genre: Story genre name; lower-cased before being passed on.
        model: Captioning weights to use: ``"coco"`` or ``"conceptual"``
            (case-insensitive).
        n_stories: Number of stories requested; forwarded unchanged to
            ``generate_story``.
        use_beam_search: Forwarded to ``generate_caption``.

    Returns:
        The value produced by ``generate_story``.

    Raises:
        ValueError: If ``model`` is missing or is not a supported model name.
    """
    # Normalise defensively: an unselected Radio input can arrive as None,
    # which would otherwise crash on ``None.lower()``.
    model_name = model.lower() if isinstance(model, str) else None
    if model_name == "coco":
        model_file = coco_weights
    elif model_name == "conceptual":
        model_file = conceptual_weights
    else:
        # Without this branch ``model_file`` would be unbound below and the
        # caller would see an opaque UnboundLocalError.
        raise ValueError(
            f"Unknown captioning model {model!r}; expected 'coco' or 'conceptual'."
        )

    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    story = generate_story(image_caption, pil_image, genre.lower(), n_stories)
    return story


if __name__ == "__main__":
    # Text shown around the Gradio interface.
    title = "Image to Story"
    article = "Combines the power of [clip prefix captioning](https://github.com/rmokady/CLIP_prefix_caption) with [gpt2 story generator](https://huggingface.co/pranavpsv/genre-story-generator-v2) to create stories of different genres from image"
    description = "Drop an image and generate stories of different genre based on that image"

    interface = gr.Interface(
        main,
        title=title,
        description=description,
        article=article,
        # Order must match the positional parameters of ``main``:
        # (pil_image, genre, model, n_stories).
        inputs=[
            gr.inputs.Image(type="pil", source="upload", label="Input"),
            gr.inputs.Dropdown(
                type="value",
                label="Story genre",
                choices=[
                    "superhero",
                    "action",
                    "drama",
                    "horror",
                    "thriller",
                    "sci_fi",
                ],
            ),
            gr.inputs.Radio(choices=["coco", "conceptual"], label="Model"),
            gr.inputs.Dropdown(choices=[1, 2, 3], label="No. of stories", type="value"),
        ],
        outputs=gr.outputs.Textbox(label="Generated story"),
        # Each example row supplies one value per input component, in the same
        # order as ``inputs`` above (image, genre, model, number of stories).
        examples=[
            ["car.jpg", "drama", "conceptual", 1],
            ["gangster.jpg", "action", "coco", 1],
        ],
        enable_queue=True,
    )
    interface.launch()