bipin commited on
Commit
70b2a7d
1 Parent(s): ef62727

added more than single story option

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. gpt2_story_gen.py +4 -19
app.py CHANGED
@@ -9,7 +9,7 @@ download_pretrained_model('coco', file_to_save=coco_weights)
9
  download_pretrained_model('conceptual', file_to_save=conceptual_weights)
10
 
11
 
12
- def main(pil_image, genre, model, use_beam_search=False):
13
  if model.lower()=='coco':
14
  model_file = coco_weights
15
  elif model.lower()=='conceptual':
@@ -20,7 +20,7 @@ def main(pil_image, genre, model, use_beam_search=False):
20
  pil_image=pil_image,
21
  use_beam_search=use_beam_search,
22
  )
23
- story = generate_story(image_caption, pil_image, genre.lower())
24
  return story
25
 
26
 
@@ -48,7 +48,8 @@ if __name__ == "__main__":
48
  "sci_fi",
49
  ],
50
  ),
51
- gr.inputs.Radio(choices=["coco", "conceptual"], label="Model")
 
52
  ],
53
  outputs=gr.outputs.Textbox(label="Generated story"),
54
  examples=[["car.jpg", "drama", "conceptual"], ["gangster.jpg", "action", "coco"]],
 
9
  download_pretrained_model('conceptual', file_to_save=conceptual_weights)
10
 
11
 
12
+ def main(pil_image, genre, model, n_stories, use_beam_search=False):
13
  if model.lower()=='coco':
14
  model_file = coco_weights
15
  elif model.lower()=='conceptual':
 
20
  pil_image=pil_image,
21
  use_beam_search=use_beam_search,
22
  )
23
+ story = generate_story(image_caption, pil_image, genre.lower(), n_stories)
24
  return story
25
 
26
 
 
48
  "sci_fi",
49
  ],
50
  ),
51
+ gr.inputs.Radio(choices=["coco", "conceptual"], label="Model"),
52
+ gr.inputs.Dropdown(choices=[1, 2, 3], label="No. of stories", type="value"),
53
  ],
54
  outputs=gr.outputs.Textbox(label="Generated story"),
55
  examples=[["car.jpg", "drama", "conceptual"], ["gangster.jpg", "action", "coco"]],
gpt2_story_gen.py CHANGED
@@ -1,11 +1,7 @@
1
- from transformers import pipeline, CLIPProcessor, CLIPModel
2
- import torch
3
 
4
 
5
- def generate_story(image_caption, image, genre):
6
- clip_ranker_checkpoint = "openai/clip-vit-base-patch32"
7
- clip_ranker_processor = CLIPProcessor.from_pretrained(clip_ranker_checkpoint)
8
- clip_ranker_model = CLIPModel.from_pretrained(clip_ranker_checkpoint)
9
 
10
  story_gen = pipeline(
11
  "text-generation",
@@ -13,17 +9,6 @@ def generate_story(image_caption, image, genre):
13
  )
14
 
15
  input = f"<BOS> <{genre}> {image_caption}"
16
- stories = [story_gen(input)[0]['generated_text'].strip(input) for i in range(3)]
17
- clip_ranker_inputs = clip_ranker_processor(
18
- text=stories,
19
- images=image,
20
- truncation=True,
21
- return_tensors='pt',
22
- padding=True
23
- )
24
- clip_ranker_outputs = clip_ranker_model(**clip_ranker_inputs)
25
- logits_per_image = clip_ranker_outputs.logits_per_image
26
- probs = logits_per_image.softmax(dim=1)
27
- story = stories[torch.argmax(probs).item()]
28
 
29
- return story
 
1
+ from transformers import pipeline
 
2
 
3
 
4
+ def generate_story(image_caption, image, genre, n_stories):
 
 
 
5
 
6
  story_gen = pipeline(
7
  "text-generation",
 
9
  )
10
 
11
  input = f"<BOS> <{genre}> {image_caption}"
12
+ stories = '\n'.join([f"Story {i+1}\n{story_gen(input)[0]['generated_text'].strip(input)}" for i in range(n_stories)])
 
 
 
 
 
 
 
 
 
 
 
13
 
14
+ return stories