File size: 3,186 Bytes
a4d40bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import torch
import re
import gradio as gr
from pathlib import Path
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel


# Matches a run of two or more consecutive full stops ("..", "...", ...).
# post_process() keeps only the text that comes before the first such run.
regex_pattern = r"[.]{2,}"
_FULL_STOP_RUN = re.compile(regex_pattern)


def post_process(text):
    """Trim a decoded plot and cut it off at the first run of full stops.

    Args:
        text: The decoded model output. A value that is not a ``str`` is
            returned unchanged instead of raising.

    Returns:
        The part of ``text`` that precedes the first run of two or more
        full stops, with leading and trailing whitespace removed. If
        ``text`` is not a string, ``text`` itself.
    """
    if not isinstance(text, str):
        # Best effort: hand the value back rather than failing the request.
        return text

    # maxsplit=1: only the piece before the first run is needed.
    head = _FULL_STOP_RUN.split(text.strip(), maxsplit=1)[0]
    # Drop whitespace that sat between the last word and the full-stop run.
    return head.rstrip()


def set_example_image(example: list) -> dict:
    """Build a ``gr.Image`` update that shows the first entry of ``example``.

    Args:
        example: The clicked sample; only its first element is used.

    Returns:
        Whatever ``gr.Image.update`` returns for ``value=example[0]``.
    """
    first_entry = example[0]
    return gr.Image.update(value=first_entry)


def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a poster image.

    Args:
        image: The poster image; it is handed as-is to ``feature_extractor``.
            ``None`` (no image supplied) produces an empty string.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The first decoded sequence after ``post_process``, or ``""`` when
        ``image`` is ``None``.
    """
    if image is None:
        # Nothing to caption (e.g. an empty image input): answer with an
        # empty plot instead of failing inside the feature extractor.
        return ""

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        generation = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        )
    output_ids = generation.sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # Only the first decoded sequence is post-processed and returned.
    return post_process(preds[0])


model_name_or_path = "deepklarity/poster2plot"

# Use the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the encoder-decoder checkpoint and move it to the chosen device.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path).to(device)
print("Loaded model")

# The feature extractor is looked up by the encoder's own checkpoint name.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# The tokenizer is looked up by the decoder's checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    # For a plain "gpt2" decoder the end-of-sequence token doubles as padding
    # (presumably because that tokenizer ships without a pad token - confirm).
    tokenizer.pad_token = tokenizer.eos_token

print("Loaded tokenizer")

# Static text for the gr.Interface page.
title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
description = ""

# Image input component for gr.Interface.
# NOTE(review): this name shadows the builtin `input`; it is read by the
# gr.Interface call further down, so rename both together if at all.
input = gr.inputs.Image(type="pil")

# Sorted POSIX-style paths of every *.jpg directly inside ./examples.
example_images = sorted(f.as_posix() for f in Path("examples").glob("*.jpg"))
print(f"Loaded {len(example_images)} example images")

demo = gr.Blocks()

# File names found at the top level of ./examples (any extension); the
# fallback tuple yields an empty list when os.walk produces nothing.
filenames = next(os.walk('examples'), (None, None, []))[2]
# One single-element row per file.
examples = [[f"examples/{filename}"] for filename in filenames]
print(examples)

# Blocks UI: image input with Clear/Submit buttons on the left, generated
# plot on the right, and a row of clickable example posters underneath.
with demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image()
                with gr.Row():
                    clear_button = gr.Button(value="Clear", variant='secondary')
                    submit_button = gr.Button(value="Submit", variant='primary')
            with gr.Column():
                plot = gr.Textbox()
        with gr.Row():
            # Kept under its own name so the module-level `example_images`
            # list of paths (still read by the gr.Interface call later in
            # the file) is not overwritten with a Dataset component.
            example_dataset = gr.Dataset(components=[input_image], samples=examples)

    # Clear: empty the image input and blank the generated plot.
    clear_button.click(fn=lambda: (None, ""), inputs=[], outputs=[input_image, plot])
    # Submit: caption the current image and show the result in the textbox.
    submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
    # Clicking an example loads that poster into the image input.
    example_dataset.click(fn=set_example_image, inputs=[example_dataset], outputs=example_dataset.components)

demo.launch()


# Alternative single-function UI built on the same `predict` callback.
# NOTE(review): `demo.launch()` is called earlier in this file; if that call
# blocks, execution never gets here - confirm which of the two UIs is meant
# to be served.
# NOTE(review): `examples` expects `example_images` to still be the list of
# example file paths at this point - verify nothing has rebound that name.
interface = gr.Interface(
    # What runs, and what goes in and out.
    fn=predict,
    inputs=input,
    outputs="textbox",
    live=True,
    # Example posters offered below the input.
    examples=example_images,
    examples_per_page=20,
    # Page text.
    title=title,
    description=description,
    article='<p>Made by: <a href="https://twitter.com/kartik_godawat" target="_blank" rel="noopener noreferrer">dk-crazydiv</a> and <a href="https://twitter.com/dsr_ai" target="_blank" rel="noopener noreferrer">dsr</a></p>',
)

interface.launch()