Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image, ImageDraw | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch | |
from transformers import Pix2StructProcessor, Pix2StructVisionModel | |
from utils import download_default_font, render_header | |
class Pix2StructForRegression(nn.Module):
    """Pix2Struct vision encoder followed by a small MLP regression head.

    The head maps the first token of the encoder's last hidden state through
    three linear layers (768 -> 1536 -> 768 -> 2) and squashes the result with
    a sigmoid, so every output value lies in (0, 1).
    """

    def __init__(self, sourcemodel_path, device):
        """Build the encoder and the regression head.

        Args:
            sourcemodel_path: Identifier or path handed to
                ``Pix2StructVisionModel.from_pretrained``.
            device: Device used as ``map_location`` when a checkpoint is
                loaded through ``load_state_dict_file``.
        """
        super().__init__()
        self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
        # Attribute names double as state-dict keys; keep them stable so
        # existing checkpoints continue to load.
        self.regression_layer1 = nn.Linear(768, 1536)
        # NOTE(review): dropout1/dropout2 are constructed but never applied
        # in forward() -- confirm whether that is intentional.
        self.dropout1 = nn.Dropout(0.1)
        self.regression_layer2 = nn.Linear(1536, 768)
        self.dropout2 = nn.Dropout(0.1)
        self.regression_layer3 = nn.Linear(768, 2)
        self.device = device

    def forward(self, *args, **kwargs):
        """Run the encoder and regress two values per sample.

        All positional and keyword arguments are forwarded unchanged to the
        wrapped vision model.

        Returns:
            Tensor whose last dimension has size 2, with sigmoid-activated
            values.
        """
        hidden_states = self.model(*args, **kwargs).last_hidden_state
        # Only the first token of each sequence feeds the regression head.
        features = hidden_states[:, 0, :]
        for hidden_layer in (self.regression_layer1, self.regression_layer2):
            features = F.relu(hidden_layer(features))
        return torch.sigmoid(self.regression_layer3(features))

    def load_state_dict_file(self, checkpoint_path, strict=True):
        """Load weights for the whole module from a saved state dict.

        Args:
            checkpoint_path: Location of the checkpoint passed to
                ``torch.load``.
            strict: Forwarded to ``load_state_dict``.
        """
        # NOTE(review): torch.load unpickles the file; only use checkpoints
        # from a trusted source.
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.load_state_dict(weights, strict=strict)
class Inference:
    """Loads the regression model once and marks its prediction on images."""

    def __init__(self) -> None:
        """Select a device (CUDA when available) and load model + processor."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.processor = self.load_model_and_processor(
            "google/matcha-base", "model/pta-text-v0.1.pt"
        )

    def load_model_and_processor(self, model_name, checkpoint_path):
        """Create the model in eval mode on ``self.device`` plus its processor.

        Args:
            model_name: Pretrained identifier used for both the vision
                backbone and the processor.
            checkpoint_path: State-dict file loaded into the regression model.

        Returns:
            A ``(model, processor)`` tuple.
        """
        regressor = Pix2StructForRegression(
            sourcemodel_path=model_name, device=self.device
        )
        regressor.load_state_dict_file(checkpoint_path=checkpoint_path)
        regressor.eval()
        regressor = regressor.to(self.device)
        image_processor = Pix2StructProcessor.from_pretrained(
            model_name, is_vqa=False
        )
        return regressor, image_processor

    def prepare_image(self, image, prompt, processor):
        """Resize the image, render the prompt onto it and encode the result.

        Args:
            image: PIL image; a 1920x1080 resized copy is what gets rendered.
            prompt: Text passed to ``render_header`` as the header.
            processor: Callable that turns the rendered image into model
                inputs.

        Returns:
            A ``(encoding, render_variables)`` tuple, where
            ``render_variables`` is the third value returned by
            ``render_header``.
        """
        resized = image.resize((1920, 1080))
        font_path = download_default_font()
        # presumably render_header draws the prompt as a header above the
        # image -- verify against utils.render_header.
        rendered, _, render_variables = render_header(
            image=resized,
            header=prompt,
            bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
            font_path=font_path,
        )
        model_inputs = processor(
            images=rendered,
            max_patches=2048,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return model_inputs, render_variables

    def predict_coordinates(self, encoding, model, render_variables):
        """Run the model and rescale the predicted y value.

        Args:
            encoding: Mapping providing ``"flattened_patches"`` and
                ``"attention_mask"``.
            model: Callable accepting those two keyword arguments and
                returning a tensor with two values per row.
            render_variables: Mapping providing ``"height"``,
                ``"header_height"`` and ``"total_height"``.

        Returns:
            The squeezed prediction converted to a Python list.
        """
        with torch.no_grad():
            prediction = model(
                flattened_patches=encoding["flattened_patches"],
                attention_mask=encoding["attention_mask"],
            )
            body_height = render_variables["height"]
            header_height = render_variables["header_height"]
            total_height = render_variables["total_height"]

            # Scale y by the total height, remove the header offset, then
            # normalise by the body height (presumably mapping from the
            # rendered frame back to the header-less image -- TODO confirm).
            y_scaled = prediction[:, 1] * total_height
            prediction[:, 1] = (y_scaled - header_height) / body_height

            return prediction.squeeze().tolist()

    def draw_circle_on_image(self, image, coordinates):
        """Draw a filled red dot of radius 5 on ``image`` in place.

        Args:
            image: PIL image to draw on; it is modified and returned.
            coordinates: Sequence whose first two items are multiplied by the
                image width and height to obtain the dot's centre.

        Returns:
            The same ``image`` object.
        """
        center_x = coordinates[0] * image.width
        center_y = coordinates[1] * image.height
        radius = 5
        bounding_box = (
            center_x - radius,
            center_y - radius,
            center_x + radius,
            center_y + radius,
        )
        ImageDraw.Draw(image).ellipse(bounding_box, fill="red")
        return image

    def process_image_and_draw_circle(self, image, prompt):
        """Full pipeline: encode image + prompt, predict, mark the prediction.

        Args:
            image: PIL image supplied by the caller.
            prompt: Text prompt for the model.

        Returns:
            The input image with the predicted point drawn on it.
        """
        encoding, render_variables = self.prepare_image(
            image, prompt, self.processor
        )
        encoding_on_device = encoding.to(self.device)
        coordinates = self.predict_coordinates(
            encoding_on_device, self.model, render_variables
        )
        return self.draw_circle_on_image(image, coordinates)
def main():
    """Build the Gradio demo around ``Inference`` and start the web server.

    The interface takes an uploaded PIL image and a text prompt, and shows the
    image returned by ``Inference.process_image_and_draw_circle``. The server
    listens on all interfaces (0.0.0.0) on port 7860.
    """
    inference = Inference()

    # Gradio Interface
    iface = gr.Interface(
        fn=inference.process_image_and_draw_circle,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(label="Prompt", placeholder="Enter prompt here..."),
        ],
        outputs=gr.Image(type="pil"),
        title="Pix2Struct Image Processing",
        description="Upload an image and enter a prompt to see the model's prediction.",
    )
    # launch() names its port argument `server_port`; passing `port=` raises
    # TypeError before the server starts.
    iface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    main()