Deepak Singh Rawat commited on
Commit
47962d5
1 Parent(s): a8567af

Add model card

Browse files
Files changed (1) hide show
  1. README.md +96 -0
README.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: en
3
+ tags:
4
+ - image-classification
5
+ - image-captioning
6
+
7
+ ---
8
+
9
+ # Poster2Plot
10
+
11
+ An image captioning model to generate movie/t.v show plot from poster. It generates decent plots but is no way perfect. We are still working on improving the model.
12
+
13
+ # Model Details
14
+
15
+ The base model uses a Vision Transformer (ViT) model as an image encoder and GPT-2 as a decoder.
16
+
17
+ We used the following models:
18
+
19
+ * Encoder: [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k)
20
+ * Decoder: [gpt2](https://huggingface.co/gpt2)
21
+
22
+ # Datasets
23
+
24
+ Publicly available IMDb datasets were used to train the model.
25
+
26
+ # How to use
27
+
28
+ ## In PyTorch
29
+
30
+ ```python
31
+ import torch
32
+ import re
33
+ import requests
34
+ from PIL import Image
35
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
36
+
37
+ # Pattern to ignore all the text after 2 or more full stops
38
+ regex_pattern = "[.]{2,}"
39
+
40
+
41
+ def post_process(text):
42
+ try:
43
+ text = text.strip()
44
+ text = re.split(regex_pattern, text)[0]
45
+ except Exception as e:
46
+ print(e)
47
+ pass
48
+ return text
49
+
50
+
51
+ def predict(image, max_length=64, num_beams=4):
52
+ pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
53
+ pixel_values = pixel_values.to(device)
54
+
55
+ with torch.no_grad():
56
+ output_ids = model.generate(
57
+ pixel_values,
58
+ max_length=max_length,
59
+ num_beams=num_beams,
60
+ return_dict_in_generate=True,
61
+ ).sequences
62
+
63
+ preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
64
+ pred = post_process(preds[0])
65
+
66
+ return pred
67
+
68
+
69
+ model_name_or_path = "deepklarity/poster2plot"
70
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
71
+
72
+ # Load model.
73
+
74
+ model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
75
+ model.to(device)
76
+ print("Loaded model")
77
+
78
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
79
+ print("Loaded feature_extractor")
80
+
81
+ tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
82
+ if model.decoder.name_or_path == "gpt2":
83
+ tokenizer.pad_token = tokenizer.eos_token
84
+
85
+ print("Loaded tokenizer")
86
+
87
+ url = "https://upload.wikimedia.org/wikipedia/en/2/26/Moana_Teaser_Poster.jpg"
88
+ with Image.open(requests.get(url, stream=True).raw) as image:
89
+ pred = predict(image)
90
+
91
+ print(pred)
92
+
93
+ ```
94
+
95
+
96
+