Spaces:
Sleeping
Sleeping
gitlost-murali
committed on
Commit
•
da59cbe
1
Parent(s):
33024b0
initial checkpoint inference push
Browse files- Dockerfile +24 -0
- app.py +116 -0
- requirements.txt +4 -0
- utils.py +144 -0
Dockerfile
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Read the docs: https://huggingface.co/docs/hub/spaces-sdks-docker
# They also include guides on how best to write a Dockerfile for a Space.

FROM ubuntu:22.04

# Install all system packages in a single layer. Running `apt-get install`
# in a later layer without its own `apt-get update` can hit stale package
# lists when the earlier layer comes from the build cache.
# The git-lfs packagecloud script registers the upstream git-lfs apt repo.
RUN apt-get update && \
    apt-get install -y curl git python3 python3-pip && \
    curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | bash && \
    apt-get install -y git-lfs && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /code

# Fetch the checkpoint repository (LFS objects included) into /code/model/.
# `git lfs clone` is deprecated; a plain `git clone` with LFS installed
# downloads the LFS files as well.
RUN git lfs install && \
    git clone https://huggingface.co/AskUI/pta-text-0.1 /code/model/

# Copy only the requirements first so the dependency layer stays cached
# while the application code changes.
COPY ./requirements.txt /code/requirements.txt

# --no-cache-dir keeps pip's download cache out of the image.
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

# Gradio listens on 127.0.0.1 unless told otherwise, which is unreachable
# from outside the container. This makes launch() bind to all interfaces
# even if the application ignores the --host/--port arguments below.
ENV GRADIO_SERVER_NAME=0.0.0.0

# NOTE(review): Spaces run the container as a non-root user; confirm the
# Hugging Face cache directory is writable at runtime (the app downloads a
# font via hf_hub_download on request).

CMD ["python3", "app.py", "--host", "0.0.0.0", "--port", "7860"]
|
app.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from PIL import Image, ImageDraw
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import torch
|
7 |
+
from transformers import Pix2StructProcessor, Pix2StructVisionModel
|
8 |
+
from utils import download_default_font, render_header
|
9 |
+
|
10 |
+
class Pix2StructForRegression(nn.Module):
    """Pix2Struct vision encoder followed by a three-layer regression head.

    The head maps the encoder's first-token hidden state (768 features)
    through 768 -> 1536 -> 768 -> 2 linear layers and squashes the result
    with a sigmoid, so every output value lies in (0, 1).

    Args:
        sourcemodel_path: Identifier or path handed to
            ``Pix2StructVisionModel.from_pretrained``.
        device: Device used as ``map_location`` when a checkpoint is loaded
            through ``load_state_dict_file``.
    """

    def __init__(self, sourcemodel_path, device):
        super().__init__()
        self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
        print("Pix2StructForRegression Model is Loaded...")

        # Regression head. The two dropout modules are registered here but
        # forward() never calls them.
        self.regression_layer1 = nn.Linear(768, 1536)
        self.dropout1 = nn.Dropout(0.1)
        self.regression_layer2 = nn.Linear(1536, 768)
        self.dropout2 = nn.Dropout(0.1)
        self.regression_layer3 = nn.Linear(768, 2)

        self.device = device
        print("Regression Layers are Loaded...")

    def forward(self, *args, **kwargs):
        """Run the encoder and regress two values from its first token.

        All positional and keyword arguments are forwarded untouched to the
        wrapped vision model.

        Returns:
            Tensor whose last dimension has size 2, with values in (0, 1).
        """
        hidden_states = self.model(*args, **kwargs).last_hidden_state
        pooled = hidden_states[:, 0, :]  # keep only the first token per sample

        hidden = F.relu(self.regression_layer1(pooled))
        hidden = F.relu(self.regression_layer2(hidden))
        return torch.sigmoid(self.regression_layer3(hidden))

    def load_state_dict_file(self, checkpoint_path, strict=True):
        """Load weights for the whole module from a ``torch.save``d file.

        Args:
            checkpoint_path: Path of the serialized state dict.
            strict: Passed through to ``load_state_dict``.
        """
        print("Loading Model Weights...")
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        weights = torch.load(checkpoint_path, map_location=self.device)
        self.load_state_dict(weights, strict=strict)
        print("Model Weights are Loaded...")
|
39 |
+
|
40 |
+
class Inference:
    """Loads the regression model once and serves point predictions.

    The public entry point is ``process_image_and_draw_circle``, which is
    what the Gradio interface calls for every request.
    """

    def __init__(self) -> None:
        """Pick CUDA when available (otherwise CPU) and load model + processor."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, self.processor = self.load_model_and_processor(
            "matcha-base", "model/pta-text-v0.1.pt"
        )
        print("Model and Processor are Loaded...")

    def load_model_and_processor(self, model_name, checkpoint_path):
        """Build the model, load its checkpoint and create the processor.

        Args:
            model_name: Passed to both ``Pix2StructForRegression`` (as the
                source model) and ``Pix2StructProcessor.from_pretrained``.
            checkpoint_path: State-dict file loaded into the model.

        Returns:
            ``(model, processor)`` with the model in eval mode on
            ``self.device``.
        """
        model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
        model.load_state_dict_file(checkpoint_path=checkpoint_path)
        model.eval()
        model = model.to(self.device)
        processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
        return model, processor

    def prepare_image(self, image, prompt, processor):
        """Render the prompt as a header above the image and encode the result.

        The image is first resized to 1920x1080 (aspect ratio is not
        preserved), then ``render_header`` stacks the prompt text on top.

        Returns:
            ``(encoding, render_variables)`` where ``render_variables`` is
            the geometry dict produced by ``render_header``.
        """
        image = image.resize((1920, 1080))
        download_default_font_path = download_default_font()
        # The bbox is a dummy: only the rendered image and geometry are used.
        rendered_image, _, render_variables = render_header(
            image=image,
            header=prompt,
            bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
            font_path=download_default_font_path,
        )
        encoding = processor(
            images=rendered_image,
            max_patches=2048,
            add_special_tokens=True,
            return_tensors="pt",
        )
        return encoding, render_variables

    def predict_coordinates(self, encoding, model, render_variables):
        """Run the model and convert its output to image-relative coordinates.

        The model sees the header and the image stacked vertically, so its y
        output is relative to the total rendered height. This undoes the
        header offset so that y becomes relative to the image alone; x is
        left unchanged.

        Args:
            encoding: Mapping providing ``"flattened_patches"`` and
                ``"attention_mask"``.
            model: Callable accepting those two keyword arguments.
            render_variables: Mapping with ``"height"``, ``"header_height"``
                and ``"total_height"``.

        Returns:
            The squeezed prediction converted with ``tolist()``; ``[x, y]``
            for a single-sample batch.
        """
        with torch.no_grad():
            pred_regression_outs = model(
                flattened_patches=encoding["flattened_patches"],
                attention_mask=encoding["attention_mask"],
            )
            new_height = render_variables["height"]
            new_header_height = render_variables["header_height"]
            new_total_height = render_variables["total_height"]

            # Inverse of the y shift applied when the header was added.
            pred_regression_outs[:, 1] = (
                (pred_regression_outs[:, 1] * new_total_height) - new_header_height
            ) / new_height

            pred_coordinates = pred_regression_outs.squeeze().tolist()
        return pred_coordinates

    def draw_circle_on_image(self, image, coordinates):
        """Draw a red dot (radius 5 px) at the relative ``coordinates``.

        The dot is drawn on ``image`` itself, which is also returned.
        """
        x, y = coordinates[0] * image.width, coordinates[1] * image.height
        print(coordinates)
        draw = ImageDraw.Draw(image)
        radius = 5
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill="red")
        return image

    def process_image_and_draw_circle(self, image, prompt):
        """Predict the point for ``prompt`` and mark it on ``image``.

        Args:
            image: PIL image supplied by the UI.
            prompt: Text rendered as the header; ``None`` is treated as an
                empty string.

        Returns:
            The input image with the predicted point drawn on it.

        Raises:
            ValueError: If no image was supplied.
        """
        # The UI hands over None when the form is submitted without an
        # upload; fail with a clear message instead of an AttributeError
        # deep inside prepare_image.
        if image is None:
            raise ValueError("No image was provided. Please upload an image first.")
        prompt = prompt or ""

        encoding, render_variables = self.prepare_image(image, prompt, self.processor)
        pred_coordinates = self.predict_coordinates(
            encoding.to(self.device), self.model, render_variables
        )
        result_image = self.draw_circle_on_image(image, pred_coordinates)
        return result_image
|
98 |
+
|
99 |
+
|
100 |
+
def main():
    """Load the model and serve the Gradio demo.

    Command-line options:
        --host: Address to bind to. When omitted, Gradio's own default
            applies.
        --port: Port to listen on. When omitted, Gradio's own default
            applies.
    """
    import argparse  # local import: only the script entry point needs it

    parser = argparse.ArgumentParser(
        description="Serve the Pix2Struct point-prediction demo."
    )
    parser.add_argument("--host", default=None, help="Address to bind the server to.")
    parser.add_argument("--port", type=int, default=None, help="Port to listen on.")
    # The script used to ignore its command line entirely, so unknown
    # arguments must stay non-fatal.
    args, _ = parser.parse_known_args()

    inference = Inference()
    print("Model and Processor are Loaded...")
    # Gradio Interface
    iface = gr.Interface(
        fn=inference.process_image_and_draw_circle,
        inputs=[
            gr.Image(type="pil", label="Upload Image"),
            gr.Textbox(label="Prompt", placeholder="Enter prompt here..."),
        ],
        outputs=gr.Image(type="pil"),
        title="Pix2Struct Image Processing",
        description="Upload an image and enter a prompt to see the model's prediction.",
    )

    # Passing None keeps Gradio's defaults. The container start command
    # supplies --host 0.0.0.0 so the server is reachable from outside.
    iface.launch(server_name=args.host, server_port=args.port)


if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
gradio
|
4 |
+
Pillow
|
utils.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import os
|
3 |
+
import textwrap
|
4 |
+
from typing import Dict, Optional, Tuple
|
5 |
+
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from PIL import Image, ImageDraw, ImageFont
|
8 |
+
|
9 |
+
# Hugging Face Hub repository that hosts the fallback font file.
DEFAULT_FONT_PATH = "ybelkada/fonts"


def download_default_font():
    """Fetch ``Arial.TTF`` from the ``DEFAULT_FONT_PATH`` Hub repository.

    Returns:
        Whatever ``hf_hub_download`` returns for that file (used by callers
        as a font path).
    """
    return hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
|
15 |
+
|
16 |
+
|
17 |
+
def render_text(
    text: str,
    text_size: int = 36,
    text_color: str = "black",
    background_color: str = "white",
    left_padding: int = 5,
    right_padding: int = 5,
    top_padding: int = 5,
    bottom_padding: int = 5,
    font_bytes: Optional[bytes] = None,
    font_path: Optional[str] = None,
) -> Image.Image:
    """
    Render text. This script is entirely adapted from the original script that can be found here:
    https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py

    Args:
        text (`str`):
            Text to render. It is wrapped so that no line exceeds 80 characters.
        text_size (`int`, *optional*, defaults to 36):
            Size of the text.
        text_color (`str`, *optional*, defaults to `"black"`):
            Color of the text.
        background_color (`str`, *optional*, defaults to `"white"`):
            Color of the background.
        left_padding (`int`, *optional*, defaults to 5):
            Padding on the left.
        right_padding (`int`, *optional*, defaults to 5):
            Padding on the right.
        top_padding (`int`, *optional*, defaults to 5):
            Padding on the top.
        bottom_padding (`int`, *optional*, defaults to 5):
            Padding on the bottom.
        font_bytes (`bytes`, *optional*):
            Bytes of the font to use. Ignored when `font_path` is given. If both are `None`, the default font
            will be used.
        font_path (`str`, *optional*):
            Path to the font to use. Takes precedence over `font_bytes`. If both are `None`, the default font
            will be used.

    Returns:
        `Image.Image`: An RGB image containing the wrapped text surrounded by the requested padding.
    """
    # Add new lines so that each line is no more than 80 characters.
    wrapped_text = "\n".join(textwrap.wrap(text, width=80))

    if font_path is not None:
        font_source = font_path
    elif font_bytes is not None:
        font_source = io.BytesIO(font_bytes)
    else:
        # Documented fallback: use the default font. This branch previously
        # downloaded the font and then raised unconditionally, which
        # contradicted both the docstring and its own message.
        font_source = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
    font = ImageFont.truetype(font_source, encoding="UTF-8", size=text_size)

    # Use a temporary canvas to determine the width and height in pixels when
    # rendering the text.
    temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
    _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)

    # Create the actual image with a bit of padding around the text.
    image_width = text_width + left_padding + right_padding
    image_height = text_height + top_padding + bottom_padding
    image = Image.new("RGB", (image_width, image_height), background_color)
    draw = ImageDraw.Draw(image)
    draw.text(
        xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font
    )
    return image
|
87 |
+
|
88 |
+
|
89 |
+
# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
|
90 |
+
def render_header(
    image: Image.Image, header: str, bbox: Dict[str, float], font_path: str, **kwargs
) -> Tuple[Image.Image, Dict[str, float], Dict[str, int]]:
    """
    Renders the input text as a header on the input image and updates the bounding box.

    The header and the image are both scaled to a common width (the larger of the two) and stacked vertically,
    header on top.

    Args:
        image (Image.Image):
            The image to render the header on.
        header (str):
            The header text.
        bbox (Dict[str, float]):
            The bounding box in relative position (0-1), with keys
            "xmin", "ymin", "xmax" and "ymax".
        font_path (str):
            Path to the font file used to render the header. Must exist.
        **kwargs:
            Forwarded to `render_text`.

    Returns:
        Tuple[Image.Image, Dict[str, float], Dict[str, int]]:
            The image with the header rendered, the bounding box expressed relative to that new image, and the
            pixel geometry of the result under the keys "width", "height" (scaled image height),
            "header_height" and "total_height".

    Raises:
        FileNotFoundError: If `font_path` does not exist.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(font_path):
        raise FileNotFoundError(f"Font path {font_path} does not exist.")
    header_image = render_text(text=header, font_path=font_path, **kwargs)
    new_width = max(header_image.width, image.width)

    # Scale both parts to the common width while keeping their aspect ratios.
    new_height = int(image.height * (new_width / image.width))
    new_header_height = int(header_image.height * (new_width / header_image.width))

    new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
    new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
    new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))

    new_total_height = new_image.height

    new_bbox = {
        "xmin": bbox["xmin"],
        "ymin": ((bbox["ymin"] * new_height) + new_header_height)
        / new_total_height,  # shift ymin down by the header's relative height
        "xmax": bbox["xmax"],
        "ymax": ((bbox["ymax"] * new_height) + new_header_height)
        / new_total_height,  # shift ymax down by the header's relative height
    }

    return (
        new_image,
        new_bbox,
        {
            "width": new_width,
            "height": new_height,
            "header_height": new_header_height,
            "total_height": new_total_height,
        },
    )
|