martinsinnona committed
Commit 7b0ea0f
1 Parent(s): 0e2c012
Files changed (2):
  1. app.py +39 -9
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,32 +2,62 @@ import gradio as gr
 from transformers import AutoProcessor, Pix2StructForConditionalGeneration
 import torch
 from PIL import Image
+import json
+import vl_convert as vlc # Ensure you have this library installed (pip install vl-convert)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Load the processor and model
 processor = AutoProcessor.from_pretrained("google/matcha-base")
 processor.image_processor.is_vqa = False
-model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to("cuda" if torch.cuda.is_available() else "cpu")
+
+model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
 model.eval()
 
 def generate_caption(image):
 
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
     inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
     generated_ids = model.generate(flattened_patches=inputs.flattened_patches, attention_mask=inputs.attention_mask, max_length=600)
     generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+    # Generate the Vega image
+    vega = string_to_vega(generated_caption)
+    vega_image = draw_vega(vega)
+
+    return generated_caption, vega_image
+
+def draw_vega(vega, scale=3):
 
-    return generated_caption
+    spec = json.dumps(vega, indent=4)
+    png_data = vlc.vegalite_to_png(vl_spec=spec, scale=scale)
+
+    return Image.open(png_data)
+
+def string_to_vega(string):
+
+    string = string.replace("'", "\"")
+    vega = json.loads(string)
+
+    for axis in ["x", "y"]:
+        field = vega["encoding"][axis]["field"]
+        if field == "":
+            vega["encoding"][axis]["field"] = axis
+            vega["encoding"][axis]["title"] = ""
+        else:
+            for entry in vega["data"]["values"]:
+                entry[field] = entry.pop(axis)
+    return vega
 
 # Create the Gradio interface
-demo = gr.Interface(
+iface = gr.Interface(
+
     fn=generate_caption,
     inputs=gr.Image(type="pil"),
-    outputs="text",
-    title="Image to Text Generator",
-    description="Upload an image and get a generated caption."
+    outputs=[gr.Textbox(), gr.Image(type="pil")],
+    title="Image to Vega-Lite",
+    description="Upload an image to generate vega-lite"
 )
 
 # Launch the interface
 if __name__ == "__main__":
-    demo.launch(share=True)
+    iface.launch(share=True)
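Editor's note on the new post-processing step: string_to_vega assumes the model emits a single-quoted, Vega-Lite-style spec whose data rows use generic "x"/"y" keys, and it renames those keys to the encoded field names (falling back to "x"/"y" with blank titles when a field is empty). A minimal sketch of the round trip, using a hypothetical model output (the real output format depends on what visdecode_B was trained to emit):

    # Hypothetical model output: single quotes, generic "x"/"y" row keys.
    s = ("{'data': {'values': [{'x': 'A', 'y': 28}, {'x': 'B', 'y': 55}]}, "
         "'mark': 'bar', "
         "'encoding': {'x': {'field': 'category', 'type': 'nominal'}, "
         "'y': {'field': 'amount', 'type': 'quantitative'}}}")

    vega = string_to_vega(s)
    # Each data row is now {'category': 'A', 'amount': 28}, matching the
    # encoding, so the spec can be passed straight to draw_vega(vega).

Note the naive quote replacement: an apostrophe inside any data value would break json.loads.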
requirements.txt CHANGED
@@ -1,2 +1,3 @@
 transformers
-torch
+torch
+vl-convert-python
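One caveat on the draw_vega helper that this new vl-convert-python dependency enables: vl_convert.vegalite_to_png returns raw PNG bytes, while PIL's Image.open expects a filename or file-like object, so the return line as committed will likely raise. A minimal sketch of the same helper with the bytes wrapped in a buffer (behavior otherwise unchanged):

    import io
    import json

    import vl_convert as vlc
    from PIL import Image

    def draw_vega(vega, scale=3):
        spec = json.dumps(vega, indent=4)
        png_data = vlc.vegalite_to_png(vl_spec=spec, scale=scale)  # raw PNG bytes
        return Image.open(io.BytesIO(png_data))  # wrap bytes so PIL can read them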