import os
import re
from pathlib import Path

import gradio as gr
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel

# Used by post_process to cut generated text at a run of two or more dots.
regex_pattern = "[.]{2,}"


def post_process(text):
    """Strip whitespace and drop everything after a run of two or more dots."""
    try:
        text = text.strip()
        text = re.split(regex_pattern, text)[0]
    except Exception as e:
        print(e)
    return text


def set_example_image(example: list) -> dict:
    # Fill the input image component with the clicked example poster.
    return gr.Image.update(value=example[0])


def predict(image, max_length=64, num_beams=4):
    # Preprocess the poster and generate a plot with beam search, then decode
    # and post-process the first (best) sequence.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    pred = post_process(preds[0])

    return pred


model_name_or_path = "deepklarity/poster2plot"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

print("Loaded tokenizer")

title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
description = ""

# Unused below (the Blocks UI builds its own components); note that `gr.inputs`
# is deprecated in recent Gradio releases.
input = gr.inputs.Image(type="pil")

example_images = sorted([f.as_posix() for f in Path("examples").glob("*.jpg")])
print(f"Loaded {len(example_images)} example images")

demo = gr.Blocks()

# Build the sample list for the gr.Dataset component from everything in examples/.
filenames = next(os.walk("examples"), (None, None, []))[2]
examples = [[f"examples/{filename}"] for filename in filenames]
print(examples)

with demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image()
                with gr.Row():
                    clear_button = gr.Button(value="Clear", variant='secondary')
                    submit_button = gr.Button(value="Submit", variant='primary')
            with gr.Column():
                plot = gr.Textbox()
        with gr.Row():
            example_images = gr.Dataset(components=[input_image], samples=examples)

    submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
    example_images.click(fn=set_example_image, inputs=[example_images], outputs=example_images.components)
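    # The Clear button is not wired to any handler elsewhere in this script.
    # A minimal sketch, assuming the intent is to reset both the image input and
    # the plot textbox; the lambda returning (None, "") is an assumption, not
    # part of the original app.
    clear_button.click(fn=lambda: (None, ""), inputs=[], outputs=[input_image, plot])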

demo.launch()