File size: 1,836 Bytes
27e891b
 
 
e4f5644
27e891b
bb5cd12
 
 
 
 
7b13977
be6293d
7b13977
f3b81e7
be6293d
bb5cd12
 
be6293d
bb5cd12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02bc168
bb5cd12
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65

import os
import subprocess
import sys

# Install DeepSpeed at startup. Invoking pip as `<current interpreter> -m pip`
# guarantees the package lands in the same environment that runs this app;
# a bare `pip` on PATH may belong to a different Python. check=False keeps the
# original best-effort behaviour: a failed install does not abort startup.
subprocess.run([sys.executable, "-m", "pip", "install", "deepspeed"], check=False)
# Print the installed package versions to stdout (handy when debugging the environment).
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)

import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput

from huggingface_hub import hf_hub_url, cached_download

# Hub location of the MAGMA checkpoint used by this demo.
checkpoint_url = hf_hub_url(repo_id="osanseviero/magma", filename="model.pt")

# Download the checkpoint (or reuse an already-cached copy) and get its local path.
# hf_hub_download supersedes the deprecated cached_download(hf_hub_url(...)) pattern.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(repo_id="osanseviero/magma", filename="model.pt")

# Build the model from the repo-local config plus the downloaded weights.
# NOTE: device is hard-coded to the first CUDA GPU, so a GPU is required.
model = Magma.from_checkpoint(
    config_path = "configs/MAGMA_v1.yml",
    checkpoint_path = checkpoint_path,
    device = 'cuda:0'
)

# Lines matching this pattern (anchored at the start) are treated as image URLs.
_URL_REGEX = re.compile(
  r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
)


def generate(context, length, temperature, top_k):
  """Run MAGMA on a mixed image/text prompt and return prompt + completion.

  Args:
    context: Prompt text. Each line that starts with an http(s) URL is loaded
      as an image via ImageInput; every other line is passed as text.
    length: Maximum number of tokens to generate (forwarded as max_steps).
    temperature: Sampling temperature; 0 is replaced by 0.01.
    top_k: Top-k value forwarded to model.generate.

  Returns:
    The stripped prompt concatenated with the first generated completion, or
    the error message string if an image line could not be loaded.
  """
  context = context.strip()

  inputs = []
  for line in context.split('\n'):
    # Ignore surrounding whitespace when deciding whether a line is an image
    # URL, and never hand that whitespace to ImageInput.
    candidate = line.strip()
    if _URL_REGEX.match(candidate):
      try:
        inputs.append(ImageInput(candidate))
      except Exception as e:
        # Show the loading error in the UI instead of crashing the request.
        return str(e)
    else:
      inputs.append(line)

  # NOTE(review): earlier note said this is a tensor of shape (1, 149, 4096)
  # for the default prompt; the sequence length depends on the prompt — confirm.
  embeddings = model.preprocess_inputs(inputs)

  # Returns one completion per batch entry; only one prompt is sent here.
  output = model.generate(
    embeddings = embeddings,
    # Gradio sliders may deliver numbers as floats; step counts and top-k are
    # presumably expected to be integers by model.generate.
    max_steps = int(length),
    # Presumably avoids dividing logits by zero when the slider is at 0.
    temperature = (0.01 if temperature == 0 else temperature),
    top_k = int(top_k)
  )

  return context + output[0]

# Web UI: a prompt textbox plus sliders for output length, temperature and
# top-k, passed positionally to generate(); the result is shown in a textbox.
iface = gr.Interface(
  fn=generate,
  inputs=[
    gr.inputs.Textbox(
      label="Prompt (image URLs need to be on their own lines):",
      default="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg\nDescribe the painting:",
      lines=7),
    gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
    gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K')
  ],
  outputs=["textbox"]
)
# Launch separately so `iface` stays bound to the Interface object itself
# rather than to whatever launch() returns.
iface.launch()