Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from prefix_clip import generate_caption | |
| from gpt2_story_gen import generate_story | |
# Fetch the two pretrained CLIP-prefix captioning checkpoints from the
# Hugging Face Hub at import time; each name holds the local file path
# returned by hf_hub_download, which main() passes to generate_caption.
conceptual_weights = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights",
    filename="conceptual_weights.pt",
)
coco_weights = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights",
    filename="coco_weights.pt",
)
def main(pil_image, genre, model, n_stories, use_beam_search=False):
    """Caption an image with CLIP-prefix captioning, then expand it into stories.

    Args:
        pil_image: The uploaded image (the Gradio input is declared with
            ``type="pil"``).
        genre: Story genre name; lower-cased before being passed to
            ``generate_story``.
        model: Which captioning checkpoint to use, ``"coco"`` or
            ``"conceptual"`` (case-insensitive).
        n_stories: Number of stories requested; forwarded unchanged to
            ``generate_story``.
        use_beam_search: Forwarded to ``generate_caption``.

    Returns:
        Whatever ``generate_story`` returns; the interface displays it in a
        Textbox.

    Raises:
        ValueError: If ``model`` is missing or not one of the known
            checkpoints, or if no genre was selected.
    """
    # Validate inputs before doing any model work. Previously an unknown or
    # unselected model left ``model_file`` unbound (UnboundLocalError), and a
    # missing genre crashed on ``None.lower()``.
    model_key = (model or "").lower()
    if model_key == "coco":
        model_file = coco_weights
    elif model_key == "conceptual":
        model_file = conceptual_weights
    else:
        raise ValueError(
            f"Unknown captioning model {model!r}; expected 'coco' or 'conceptual'."
        )
    if not genre:
        raise ValueError("No story genre selected.")

    image_caption = generate_caption(
        model_path=model_file,
        pil_image=pil_image,
        use_beam_search=use_beam_search,
    )
    story = generate_story(image_caption, pil_image, genre.lower(), n_stories)
    return story
if __name__ == "__main__":
    title = "Image to Story"
    article = "Combines the power of [clip prefix captioning](https://github.com/rmokady/CLIP_prefix_caption) with [gpt2 story generator](https://huggingface.co/pranavpsv/genre-story-generator-v2) to create stories of different genres from image"
    description = "Drop an image and generate stories of different genre based on that image"
    # Build the web UI: the four input components below map positionally to
    # main(pil_image, genre, model, n_stories).
    interface = gr.Interface(
        main,
        title=title,
        description=description,
        article=article,
        inputs=[
            gr.inputs.Image(type="pil", source="upload", label="Input"),
            gr.inputs.Dropdown(
                type="value",
                label="Story genre",
                choices=[
                    "superhero",
                    "action",
                    "drama",
                    "horror",
                    "thriller",
                    "sci_fi",
                ],
            ),
            gr.inputs.Radio(choices=["coco", "conceptual"], label="Model"),
            gr.inputs.Dropdown(choices=[1, 2, 3], label="No. of stories", type="value"),
        ],
        outputs=gr.outputs.Textbox(label="Generated story"),
        # Each example row must supply one value per input component (four
        # here); the previous rows omitted the "No. of stories" value.
        examples=[
            ["car.jpg", "drama", "conceptual", 1],
            ["gangster.jpg", "action", "coco", 1],
        ],
        enable_queue=True,
    )
    interface.launch()