dsr commited on
Commit
3e688ca
1 Parent(s): 7510b9d

Create poster2plot space

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .vscode
2
+ .ipynb_checkpoints
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import re
3
+ import gradio as gr
4
+ from pathlib import Path
5
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
6
+
7
+
8
+ # Pattern to ignore all the text after 2 or more full stops
9
+ regex_pattern = "[.]{2,}"
10
+
11
+
12
+ def post_process(text):
13
+ try:
14
+ text = text.strip()
15
+ text = re.split(regex_pattern, text)[0]
16
+ except Exception as e:
17
+ print(e)
18
+ pass
19
+ return text
20
+
21
+
22
+ def predict(image, max_length=64, num_beams=4):
23
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
24
+ pixel_values = pixel_values.to(device)
25
+
26
+ with torch.no_grad():
27
+ output_ids = model.generate(
28
+ pixel_values,
29
+ max_length=max_length,
30
+ num_beams=num_beams,
31
+ return_dict_in_generate=True,
32
+ ).sequences
33
+
34
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
35
+ pred = post_process(preds[0])
36
+
37
+ return pred
38
+
39
+
40
+ model_name_or_path = "deepklarity/poster2plot"
41
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
+
43
+ # Load model.
44
+
45
+ model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
46
+ model.to(device)
47
+ print("Loaded model")
48
+
49
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
50
+ print("Loaded feature_extractor")
51
+
52
+ tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
53
+ if model.decoder.name_or_path == "gpt2":
54
+ tokenizer.pad_token = tokenizer.eos_token
55
+
56
+ print("Loaded tokenizer")
57
+
58
+ title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
59
+ description = ""
60
+
61
+ input = gr.inputs.Image(type="pil")
62
+
63
+ example_images = sorted([f.as_posix() for f in Path("examples").glob("*.jpg")])
64
+ print(f"Loaded {len(example_images)} example images")
65
+
66
+ interface = gr.Interface(
67
+ fn=predict,
68
+ inputs=input,
69
+ outputs="textbox",
70
+ title=title,
71
+ description=description,
72
+ examples=example_images,
73
+ live=True,
74
+ )
75
+
76
+ interface.launch()
examples/tt0068646-the-godfather.jpg ADDED
examples/tt0076759-star-wars.jpg ADDED
examples/tt0108778-friends.jpg ADDED
examples/tt10062292-never-have-i-ever.jpg ADDED
examples/tt10919420-squid-games.jpg ADDED
examples/tt6468322-money-heist.jpg ADDED
examples/tt7991608-red-notice.jpg ADDED
examples/tt8366590-baaghi3.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ --find-links https://download.pytorch.org/whl/torch_stable.html
2
+ gradio==2.2.6
3
+ transformers==4.12.5
4
+ torch==1.10.0+cpu