kiri00 committed on
Commit
797b64f
1 Parent(s): 1a3f242

initial commit

Browse files
Files changed (1) hide show
  1. app.py +55 -0
app.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
# Pattern matching a run of two or more full stops; post_process() drops
# everything from the first such run onward.
regex_pattern = r"[.]{2,}"
_MULTI_DOT_RE = re.compile(regex_pattern)


def post_process(text):
    """Strip surrounding whitespace and cut *text* at the first run of 2+ dots.

    Args:
        text: Decoded model output. Values that are not ``str`` (for example
            ``None``) are returned unchanged, which keeps the function
            best-effort without hiding unrelated errors behind a blanket
            ``except``.

    Returns:
        The part of the stripped text that precedes the first occurrence of
        two or more consecutive full stops, or the whole stripped text when
        no such run exists.
    """
    if not isinstance(text, str):
        return text
    # maxsplit=1: only the leading segment is needed, so stop at the first match.
    return _MULTI_DOT_RE.split(text.strip(), maxsplit=1)[0]
16
def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a poster image.

    Relies on the module-level ``feature_extractor``, ``model``, ``tokenizer``
    and ``device`` objects.

    Args:
        image: Input image handed over by the Gradio interface. ``None``
            (no image supplied, e.g. the input was cleared while the
            interface runs in live mode) yields an empty string instead of
            failing inside the feature extractor.
        max_length: Maximum length passed to ``model.generate``.
        num_beams: Number of beams passed to ``model.generate``.

    Returns:
        The first decoded sequence after ``post_process``, or ``""`` when
        there is no image or nothing was decoded.
    """
    if image is None:
        return ""

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    if not preds:
        return ""
    return post_process(preds[0])
29
model_name_or_path = "deepklarity/poster2plot"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the encoder-decoder model and move it to the selected device.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# The feature extractor and tokenizer are looked up from the names recorded
# on the model's encoder and decoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    # Reuse the end-of-sequence token as the padding token for gpt2.
    tokenizer.pad_token = tokenizer.eos_token
print("Loaded tokenizer")

title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
description = ""

# Named image_input rather than input so the builtin input() is not shadowed.
image_input = gr.inputs.Image(type="pil")

# Example posters are every *.jpg under ./examples, in sorted order.
example_images = sorted(f.as_posix() for f in Path("examples").glob("*.jpg"))
print(f"Loaded {len(example_images)} example images")

interface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs="textbox",
    title=title,
    description=description,
    examples=example_images,
    live=True,
)

interface.launch()