bipin committed
Commit • 70b2a7d
1 Parent(s): ef62727

added more than single story option
Files changed:
- app.py (+4 -3)
- gpt2_story_gen.py (+4 -19)
app.py
CHANGED
@@ -9,7 +9,7 @@ download_pretrained_model('coco', file_to_save=coco_weights)
 download_pretrained_model('conceptual', file_to_save=conceptual_weights)
 
 
-def main(pil_image, genre, model, use_beam_search=False):
+def main(pil_image, genre, model, n_stories, use_beam_search=False):
     if model.lower()=='coco':
         model_file = coco_weights
     elif model.lower()=='conceptual':
@@ -20,7 +20,7 @@ def main(pil_image, genre, model, use_beam_search=False):
         pil_image=pil_image,
         use_beam_search=use_beam_search,
     )
-    story = generate_story(image_caption, pil_image, genre.lower())
+    story = generate_story(image_caption, pil_image, genre.lower(), n_stories)
     return story
 
 
@@ -48,7 +48,8 @@ if __name__ == "__main__":
                 "sci_fi",
             ],
         ),
-        gr.inputs.Radio(choices=["coco", "conceptual"], label="Model")
+        gr.inputs.Radio(choices=["coco", "conceptual"], label="Model"),
+        gr.inputs.Dropdown(choices=[1, 2, 3], label="No. of stories", type="value"),
     ],
     outputs=gr.outputs.Textbox(label="Generated story"),
     examples=[["car.jpg", "drama", "conceptual"], ["gangster.jpg", "action", "coco"]],
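Taken together, the app.py changes thread the new "No. of stories" dropdown value through main() into generate_story(). A minimal sketch of the resulting wiring, assuming the legacy gr.inputs/gr.outputs API the Space already uses; the image input, genre choices, and captioning call shown here are placeholders for code that sits outside the hunks above:

import gradio as gr
from gpt2_story_gen import generate_story

def main(pil_image, genre, model, n_stories, use_beam_search=False):
    # Weight selection (coco vs. conceptual) and the captioning step are
    # unchanged by this commit and elided here.
    image_caption = "caption text from the unchanged captioning code"
    return generate_story(image_caption, pil_image, genre.lower(), n_stories)

interface = gr.Interface(
    main,
    inputs=[
        gr.inputs.Image(type="pil", label="Image"),  # assumed, defined outside the hunks
        gr.inputs.Dropdown(choices=["action", "drama", "sci_fi"], label="Genre"),  # subset of the real choices
        gr.inputs.Radio(choices=["coco", "conceptual"], label="Model"),
        gr.inputs.Dropdown(choices=[1, 2, 3], label="No. of stories", type="value"),  # new in this commit
    ],
    outputs=gr.outputs.Textbox(label="Generated story"),
)

if __name__ == "__main__":
    interface.launch()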
gpt2_story_gen.py
CHANGED
@@ -1,11 +1,7 @@
-from transformers import pipeline
-import torch
+from transformers import pipeline
 
 
-def generate_story(image_caption, image, genre):
-    clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
-    clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
-    clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)
+def generate_story(image_caption, image, genre, n_stories):
 
     story_gen = pipeline(
         "text-generation",
@@ -13,17 +9,6 @@ def generate_story(image_caption, image, genre):
     )
 
     input = f"<BOS> <{genre}> {image_caption}"
-    stories = [story_gen(input)[0]['generated_text'].strip(input) for i in range(
-    clip_ranker_inputs = clip_ranker_processor(
-        text=stories,
-        images=image,
-        truncation=True,
-        return_tensors='pt',
-        padding=True
-    )
-    clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
-    logits_per_image = clip_ranker_outputs.logits_per_image
-    probs = logits_per_image.softmax(dim=1)
-    story = stories[torch.argmax(probs).item()]
+    stories = '\n'.join([f"Story {i+1}\n{story_gen(input)[0]['generated_text'].strip(input)}" for i in range(n_stories)])
 
-    return
+    return stories
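After this commit, generate_story() drops the CLIP re-ranking path and instead joins the requested number of generations into a single labelled string. A rough sketch of how the function now reads; the pipeline's model argument lives in the unchanged middle of the file and is not visible in the hunk, so the pipeline default stands in for it here:

from transformers import pipeline

def generate_story(image_caption, image, genre, n_stories):
    # `image` stays in the signature but is no longer used now that the
    # CLIP ranking of candidate stories has been removed.
    # The real model/tokenizer arguments are defined in the unchanged part of
    # the file; the default checkpoint is used here only to keep the sketch runnable.
    story_gen = pipeline("text-generation")

    input = f"<BOS> <{genre}> {image_caption}"
    # One generation per requested story, each prefixed with a "Story N" header.
    # Note: str.strip(input) trims characters that occur in `input` from both
    # ends of the generated text, not the exact prompt prefix.
    stories = '\n'.join(
        f"Story {i+1}\n{story_gen(input)[0]['generated_text'].strip(input)}"
        for i in range(n_stories)
    )
    return stories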