File size: 2,748 Bytes
27e891b
 
 
e4f5644
27e891b
bb5cd12
 
 
 
 
7b13977
be6293d
7b13977
f3b81e7
be6293d
bb5cd12
 
be6293d
bb5cd12
 
 
0484642
be6090c
bb5cd12
be6090c
 
 
 
 
 
 
 
 
 
 
0484642
 
 
 
 
 
 
 
 
 
 
 
bb5cd12
 
 
 
 
 
 
 
 
 
 
 
11c80e1
10e850e
0484642
bb5cd12
10e850e
7ae7e2a
 
bb5cd12
 
 
0484642
 
be6090c
0601286
bb5cd12
 
0484642
 
bb5cd12
10e850e
 
 
 
 
11c80e1
bb5cd12
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

import os
import subprocess
import sys

# Install DeepSpeed at startup into the environment of the interpreter that is
# running this script. `sys.executable -m pip` guarantees the package lands
# where this process imports from, unlike a bare `pip` on PATH, which may
# belong to a different Python installation.
# The step stays best-effort, like the original `os.system` call: a failure is
# reported on stderr but does not abort startup.
_pip_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "deepspeed"],
    check=False,
)
if _pip_install.returncode != 0:
    print(
        f"WARNING: 'pip install deepspeed' exited with code {_pip_install.returncode}",
        file=sys.stderr,
    )

# Print the installed packages of this interpreter (helps debug dependency
# issues on the host).
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)

import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput

from huggingface_hub import hf_hub_url, cached_download

# `hf_hub_download` is the supported replacement for the deprecated
# `cached_download(hf_hub_url(...))` pattern. It downloads the file into the
# local Hugging Face cache and returns the local path.
from huggingface_hub import hf_hub_download

_CHECKPOINT_REPO_ID = "osanseviero/magma"
_CHECKPOINT_FILENAME = "model.pt"

# Retained so the module-level name still exists; no longer used for the download.
checkpoint_url = hf_hub_url(repo_id=_CHECKPOINT_REPO_ID, filename=_CHECKPOINT_FILENAME)
checkpoint_path = hf_hub_download(repo_id=_CHECKPOINT_REPO_ID, filename=_CHECKPOINT_FILENAME)

# Load MAGMA once at import time so every request reuses the same model.
# NOTE(review): device is hard-coded to the first CUDA GPU; the host needs one.
model = Magma.from_checkpoint(
    config_path="configs/MAGMA_v1.yml",
    checkpoint_path=checkpoint_path,
    device="cuda:0",
)

def generate(image, context, length, temperature, top_k, rearrange):
    """Generate a text completion for an image + text prompt with MAGMA.

    Args:
        image: Path of the uploaded image (the Gradio Image input is declared
            with ``type="filepath"``), or ``None`` if nothing was uploaded.
        context: The text prompt.
        length: Maximum number of generation steps (``max_steps``).
        temperature: Sampling temperature; ``0`` is replaced by ``0.01``
            before being passed to the model.
        top_k: Top-k value forwarded to ``model.generate``.
        rearrange: If true, the text prompt is placed before the image;
            otherwise the image comes first.

    Returns:
        The first element of the model's generated outputs, or a short
        message asking for an image when none was supplied.
    """
    # Without an upload Gradio passes None; fail with a readable message
    # instead of an opaque error from ImageInput.
    if image is None:
        return "Please provide an image prompt."

    # ImageInput supports URLs and local paths; the Gradio input supplies a path.
    image_input = ImageInput(image)
    inputs = [context, image_input] if rearrange else [image_input, context]

    # The original note recorded an embedding shape like (1, 149, 4096);
    # presumably the middle (sequence) dimension varies with the prompt — TODO confirm.
    embeddings = model.preprocess_inputs(inputs)

    # Slider values may arrive as floats; step counts and k are integers.
    outputs = model.generate(
        embeddings=embeddings,
        max_steps=int(length),
        temperature=0.01 if temperature == 0 else temperature,
        top_k=int(top_k),
    )

    # One output per batch element; a single prompt was submitted.
    return outputs[0]
  
# Each example row follows the order of `inputs` below:
# [image, text prompt, output tokens, temperature, top k, rearrange prompt].
examples = [["woods_hi.jpeg", "Describe the painting:", 15, 0.7, 0, False]]

title = "MAGMA"
description = "Gradio Demo for MAGMA -- Multimodal Augmentation of Generative Models through Adapter-based Finetuning. Read more at the links below."
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.05253' target='_blank'>MAGMA -- Multimodal Augmentation of Generative Models through Adapter-based Finetuning</a> | <a href='https://github.com/Aleph-Alpha/magma' target='_blank'>Github Repo</a></p>"

# Build the UI; component order must match the parameters of `generate`.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Image(type="filepath", label="Image Prompt"),
        gr.inputs.Textbox(
            label="Text Prompt:",
            default="Describe the painting:",
            lines=7,
        ),
        gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
        gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K'),
        gr.inputs.Checkbox(default=False, label="Rearrange Prompt", optional=False),
    ],
    outputs=["textbox"],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# `launch()` returns app/URL information rather than the Interface, so keep
# the Interface bound to `iface` and launch it as a separate statement.
iface.launch(enable_queue=True, cache_examples=True)