Spaces:
Runtime error
Runtime error
upgrade to gradio blocks
Browse files
app.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
import
|
2 |
import re
|
|
|
|
|
3 |
import gradio as gr
|
4 |
-
from pathlib import Path
|
5 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
6 |
|
7 |
-
|
8 |
# Pattern to ignore all the text after 2 or more full stops
|
9 |
regex_pattern = "[.]{2,}"
|
10 |
|
@@ -19,6 +19,10 @@ def post_process(text):
|
|
19 |
return text
|
20 |
|
21 |
|
|
|
|
|
|
|
|
|
22 |
def predict(image, max_length=64, num_beams=4):
|
23 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
24 |
pixel_values = pixel_values.to(device)
|
@@ -52,29 +56,29 @@ print("Loaded feature_extractor")
|
|
52 |
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
53 |
if model.decoder.name_or_path == "gpt2":
|
54 |
tokenizer.pad_token = tokenizer.eos_token
|
55 |
-
|
56 |
print("Loaded tokenizer")
|
57 |
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
)
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
)
|
79 |
-
|
80 |
-
|
|
|
|
1 |
+
import os
|
2 |
import re
|
3 |
+
|
4 |
+
import torch
|
5 |
import gradio as gr
|
|
|
6 |
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
|
7 |
|
|
|
8 |
# Pattern to ignore all the text after 2 or more full stops
|
9 |
regex_pattern = "[.]{2,}"
|
10 |
|
|
|
19 |
return text
|
20 |
|
21 |
|
22 |
+
def set_example_image(example: list) -> dict:
|
23 |
+
return gr.Image.update(value=example[0])
|
24 |
+
|
25 |
+
|
26 |
def predict(image, max_length=64, num_beams=4):
|
27 |
pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
|
28 |
pixel_values = pixel_values.to(device)
|
|
|
56 |
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
|
57 |
if model.decoder.name_or_path == "gpt2":
|
58 |
tokenizer.pad_token = tokenizer.eos_token
|
|
|
59 |
print("Loaded tokenizer")
|
60 |
|
61 |
+
examples = [[f"examples/{filename}"] for filename in next(os.walk('examples'), (None, None, []))[2]]
|
62 |
+
print(f"Loaded {len(examples)} example images")
|
63 |
+
|
64 |
+
with gr.Blocks(css="#title { margin: 0 auto; padding: 25px 25px 25px 25px }") as poster2plot:
|
65 |
+
with gr.Column():
|
66 |
+
with gr.Row():
|
67 |
+
gr.Markdown("# Poster2Plot: Upload a Movie/T.V show poster to generate a plot", elem_id='title')
|
68 |
+
with gr.Row():
|
69 |
+
with gr.Column():
|
70 |
+
with gr.Row():
|
71 |
+
input_image = gr.Image(label='Input Image', type='numpy')
|
72 |
+
with gr.Row():
|
73 |
+
submit_button = gr.Button(value="Submit", variant='primary')
|
74 |
+
with gr.Column():
|
75 |
+
plot = gr.Textbox(label="Plot")
|
76 |
+
with gr.Row():
|
77 |
+
example_images = gr.Dataset(components=[input_image], samples=examples)
|
78 |
+
with gr.Row():
|
79 |
+
gr.Markdown("Made by: [dk-crazydiv](https://twitter.com/kartik_godawat) and [dsr](https://twitter.com/dsr_ai)")
|
80 |
+
|
81 |
+
submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
|
82 |
+
example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
|
83 |
+
|
84 |
+
poster2plot.launch()
|