mrahmed0499 committed on
Commit
ba513bb
1 Parent(s): 5df5ec6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +264 -0
app.py CHANGED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import spaces
4
+ import torch
5
+
6
+ import gradio as gr
7
+
8
+ from gradio_client.client import DEFAULT_TEMP_DIR
9
+ from playwright.sync_api import sync_playwright
10
+ from threading import Thread
11
+ from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
12
+ from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
13
+ from typing import List
14
+ from PIL import Image
15
+
16
+ from transformers.image_transforms import resize, to_channel_dimension_format
17
+
18
+
19
# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD presumably
# tells its build to skip compiling CUDA kernels (TODO confirm against the
# flash-attn setup script).
# The current environment is forwarded together with the override: `env=`
# REPLACES the child's environment, so passing only the override would drop
# PATH and every other variable, and the shell could fail to find `pip` or
# pick one belonging to a different interpreter.
subprocess.run(
    "pip install flash-attn --no-build-isolation",
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    shell=True,
)
20
+
21
# Device every tensor and the model are moved to.
DEVICE = torch.device("cuda")

# Processor and model are both loaded from the same checkpoint.
PROCESSOR = AutoProcessor.from_pretrained("HuggingFaceM4/VLM_WebSight_finetuned")
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

# Number of `<image>` placeholders inserted in the prompt: the resampler's
# latent count when a resampler is used, otherwise the number of patches in
# the (square) vision grid.
image_seq_len = (
    MODEL.config.perceiver_config.resampler_n_latents
    if MODEL.config.use_resampler
    else (MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size) ** 2
)

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Token ids passed to `generate` as `bad_words_ids` (see `model_inference`).
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids
38
+
39
+
40
+ ## Utils
41
+
42
def convert_to_rgb(image):
    """Return `image` in RGB mode, compositing transparent images onto white.

    A plain `image.convert("RGB")` would only work for .jpg images, as it
    creates a wrong background for transparent images; going through
    `Image.alpha_composite` over a (255, 255, 255) canvas handles that case.

    An image whose mode is already "RGB" is returned unchanged (same object).
    """
    if image.mode != "RGB":
        rgba = image.convert("RGBA")
        white_canvas = Image.new("RGBA", rgba.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, rgba).convert("RGB")
    return image
53
+
54
# The processor is the same as the Idefics processor except for the BICUBIC
# interpolation inside siglip, so this is a hack that redefines ONLY the
# transform method (the resize below uses BILINEAR).
def custom_transform(x):
    """Turn one input image into the tensor passed to the model as pixel values.

    Steps, in order: convert to RGB, convert to a numpy array, resize to
    960x960 with bilinear resampling, rescale by 1/255, normalize with the
    processor's own mean/std, move channels first, wrap in a `torch.Tensor`.
    """
    image_processor = PROCESSOR.image_processor

    pixels = to_numpy_array(convert_to_rgb(x))
    pixels = resize(pixels, (960, 960), resample=PILImageResampling.BILINEAR)
    pixels = image_processor.rescale(pixels, scale=1 / 255)
    pixels = image_processor.normalize(
        pixels,
        mean=image_processor.image_mean,
        std=image_processor.image_std,
    )
    pixels = to_channel_dimension_format(pixels, ChannelDimension.FIRST)
    return torch.tensor(pixels)
69
+
70
+ ## End of Utils
71
+
72
+
73
def list_example_images(directory="example_images"):
    """Return the paths of the files in `directory`, as `"<directory>/<name>"`.

    Entries are sorted by name because `os.listdir` returns them in arbitrary
    order, which would make the gallery order vary between runs. A missing
    directory yields an empty list so the app can still start (with an empty
    gallery) instead of crashing at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [f"{directory}/{ex_image}" for ex_image in sorted(os.listdir(directory))]


# Example screenshots offered in the templates gallery.
IMAGE_GALLERY_PATHS = list_example_images()
77
+
78
+
79
def install_playwright():
    """Run `playwright install`, reporting the outcome on stdout.

    Best-effort: a failure is printed rather than raised, so the caller
    continues either way. Returns None.
    """
    try:
        subprocess.run(["playwright", "install"], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # CalledProcessError: the command ran but exited with a non-zero status.
        # OSError (e.g. FileNotFoundError): the `playwright` executable could
        # not be started at all; without this the app would crash at startup.
        print(f"Error during Playwright installation: {e}")
    else:
        print("Playwright installation successful.")


# Executed once at import time, before the UI is built.
install_playwright()
87
+
88
+
89
def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list: List[str]
):
    """Open the gallery entry the user clicked and return it as a PIL image.

    `selected_state.index` picks the entry out of `gallery_list.root`; the
    file at that entry's `image.path` is opened with `Image.open`.

    NOTE(review): the `List[str]` annotation does not match how the value is
    used here (it is accessed through `.root[...].image.path`) - confirm the
    actual type Gradio passes in.
    """
    chosen_entry = gallery_list.root[selected_state.index]
    return Image.open(chosen_entry.image.path)
94
+
95
+
96
def render_webpage(
    html_css_code,
):
    """Render an HTML/CSS string in headless Chromium and return a screenshot.

    Args:
        html_css_code: Markup loaded into a fresh page with `page.set_content`.

    Returns:
        A PIL image of the full-page screenshot, taken once the page reaches
        the "networkidle" load state.
    """
    import io

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            context = browser.new_context(
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                    " Safari/537.36"
                )
            )
            page = context.new_page()
            page.set_content(html_css_code)
            page.wait_for_load_state("networkidle")
            # Without `path`, the screenshot is returned as PNG bytes and
            # nothing is written to disk, so no temp file is left behind per
            # call and concurrent calls cannot overwrite each other's output.
            png_bytes = page.screenshot(full_page=True)
            context.close()
        finally:
            # Close the browser even when rendering raised.
            browser.close()

    return Image.open(io.BytesIO(png_bytes))
117
+
118
+
119
@spaces.GPU(duration=180)
def model_inference(
    image,
):
    """Stream the HTML generated for a screenshot, then its rendering.

    This is a generator. After every text chunk produced by the streamer it
    yields `(generated_text, rendered_image)`, where `generated_text` is all
    text generated so far (with `</s>` removed). `rendered_image` is `None`
    for every chunk except the one that contains `</s>`; for that chunk it is
    the image returned by `render_webpage(generated_text)`.

    Args:
        image: The screenshot to convert; the error message below states it
            should be a PIL image.

    Raises:
        ValueError: If `image` is None.
    """
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    # Prompt: BOS token followed by `image_seq_len` `<image>` placeholders
    # wrapped between two `<fake_token_around_image>` tokens.
    inputs = PROCESSOR.tokenizer(
        f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image],
        transform=custom_transform,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
    )
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=4096,
        streamer=streamer,
    )

    # `generate` runs in a worker thread so this generator can consume the
    # streamer (and update the UI) while tokens are still being produced.
    thread = Thread(
        target=MODEL.generate,
        kwargs=generation_kwargs,
    )
    thread.start()

    generated_text = ""
    for new_text in streamer:
        is_last_chunk = "</s>" in new_text
        # Append BEFORE rendering: text that arrives in the same chunk as
        # `</s>` must be part of the page that gets rendered.
        generated_text += new_text.replace("</s>", "")
        rendered_image = render_webpage(generated_text) if is_last_chunk else None
        yield generated_text, rendered_image
174
+
175
+
176
# Output components are instantiated here, ahead of the layout, and placed
# into it later through their `.render()` method.
generated_html = gr.Code(label="Extracted HTML", elem_id="generated_html")
rendered_html = gr.Image(label="Rendered HTML", show_download_button=False, show_share_button=False)
# rendered_html = gr.HTML(label="Rendered HTML")


# Custom stylesheet handed to `gr.Blocks(css=...)`.
css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""
195
+
196
+
197
# UI layout and event wiring.
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        # Left column: screenshot input and the action buttons.
        with gr.Column(scale=4, min_width=250) as upload_area:
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        # Right column: screenshot of the rendered generated page.
        with gr.Column(scale=4):
            rendered_html.render()

    with gr.Row():
        generated_html.render()

    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )

    # Upload, Submit and Regenerate all run the same inference. The Regenerate
    # click is wired ONLY here: registering an additional
    # `regenerate_btn.click(...)` handler would run `model_inference` twice
    # for a single click, with both runs writing to the same outputs.
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    # Selecting a template first loads it into the image box, then (only if
    # that succeeded) runs inference on it.
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    demo.load()
262
+
263
# Enable the request queue (at most 40 pending requests, queue API closed),
# then start the server with up to 400 worker threads.
demo.queue(max_size=40, api_open=False).launch(max_threads=400)