File size: 402 Bytes
70b2a7d
cae4936
 
70b2a7d
cae4936
0843a80
 
 
 
 
cae4936
fb694ab
cae4936
70b2a7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import functools

from transformers import pipeline


_STORY_MODEL = "pranavpsv/genre-story-generator-v2"


@functools.lru_cache(maxsize=1)
def _load_story_generator():
    """Build the genre story text-generation pipeline once and reuse it.

    Loading the model is expensive, so the pipeline is cached instead of
    being rebuilt on every ``generate_story`` call.
    """
    return pipeline("text-generation", _STORY_MODEL)


def _remove_prompt(text, prompt):
    """Return *text* with a leading copy of *prompt* removed, whitespace-trimmed."""
    # NOTE: ``text.strip(prompt)`` would be wrong here: ``str.strip`` treats its
    # argument as a *set of characters* and would also eat story letters that
    # happen to occur anywhere in the prompt.
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()


def generate_story(image_caption, image, genre, n_stories, story_gen=None):
    """Generate ``n_stories`` genre stories seeded by an image caption.

    Args:
        image_caption: Caption text used as the opening of each story prompt.
        image: Unused by this function; accepted for interface compatibility.
        genre: Genre tag inserted into the prompt as ``<genre>``.
        n_stories: Number of stories to generate (``range`` count).
        story_gen: Optional callable with the ``transformers`` text-generation
            pipeline calling convention (``story_gen(prompt)[0]["generated_text"]``).
            Defaults to a cached pipeline for the genre story model.

    Returns:
        The stories joined by blank lines, each headed ``"Story <k>"``, with the
        prompt prefix removed from the generated text. Returns ``""`` when
        ``n_stories`` is 0.
    """
    if story_gen is None:
        story_gen = _load_story_generator()

    # Prompt layout used for this model: "<BOS> <genre> caption".
    prompt = f"<BOS> <{genre}> {image_caption}"

    stories = []
    for i in range(n_stories):
        generated = story_gen(prompt)[0]["generated_text"]
        stories.append(f"Story {i + 1}\n{_remove_prompt(generated, prompt)}")

    return "\n\n".join(stories)