magma / app.py
stellaathena's picture
This should work
bb5cd12
raw
history blame
1.6 kB
import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput
# Location of the MAGMA model definition, its trained weights, and the
# device the model is placed on.
_MAGMA_CONFIG_PATH = "configs/MAGMA_v1.yml"
_MAGMA_CHECKPOINT_PATH = "./mp_rank_00_model_states.pt"
_MAGMA_DEVICE = 'cuda:0'

# Load the model once at import time; every request served by this app
# reuses this single instance.
model = Magma.from_checkpoint(
    config_path=_MAGMA_CONFIG_PATH,
    checkpoint_path=_MAGMA_CHECKPOINT_PATH,
    device=_MAGMA_DEVICE,
)
# A line that starts with an http(s) URL is treated as an image input; every
# other line is treated as text. Compiled once at import instead of on every
# request.
_IMAGE_URL_RE = re.compile(
    r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
)


def generate(context, length, temperature, top_k):
    """Run MAGMA on a mixed image/text prompt and return prompt + completion.

    The prompt is split on newlines. A line starting with an http(s) URL is
    loaded as an ``ImageInput``; any other line is passed to the model as text.

    Args:
        context: Prompt text from the UI; image URLs must be on their own lines.
        length: Maximum number of tokens to generate (``max_steps``).
        temperature: Sampling temperature; a value of exactly 0 is replaced
            by 0.01.
        top_k: Top-k cutoff forwarded to ``model.generate``.

    Returns:
        The stripped prompt followed by the first generated completion, or the
        exception message if an image line could not be loaded.
    """
    context = context.strip()

    inputs = []
    for line in context.split('\n'):
        # Strip per-line whitespace so that an indented URL, or one followed
        # by trailing spaces or a stray '\r', is still recognised as an image
        # and is not handed to ImageInput with whitespace attached.
        candidate = line.strip()
        if _IMAGE_URL_RE.match(candidate):
            try:
                inputs.append(ImageInput(candidate))
            except Exception as e:
                # UI boundary: report the failure (presumably a download or
                # decode error) as the output text instead of crashing.
                return str(e)
        else:
            # Text lines are forwarded unchanged.
            inputs.append(line)

    # Embeds the mixed image/text sequence. The original note recorded a shape
    # of (1, 149, 4096) for the default prompt; the sequence length presumably
    # varies with the input -- confirm.
    embeddings = model.preprocess_inputs(inputs)

    # Per the original note, returns a list of length embeddings.shape[0]
    # (batch size); only the first completion is used.
    output = model.generate(
        embeddings=embeddings,
        # Coerce to int in case the UI delivers slider values as floats; both
        # are integer quantities (step count and top-k cutoff).
        max_steps=int(length),
        # Exactly 0 is mapped to 0.01, presumably to avoid dividing logits by
        # zero -- confirm.
        temperature=(0.01 if temperature == 0 else temperature),
        top_k=int(top_k),
    )
    return context + output[0]
# Build the web UI: a multi-line prompt box plus three sliders, wired to
# `generate`. The inputs are passed to `generate` positionally in this order:
# context, length, temperature, top_k.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Textbox(
            label="Prompt (image URLs need to be on their own lines):",
            default="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg\nDescribe the painting:",
            lines=7,
        ),
        gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
        gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K'),
    ],
    outputs=["textbox"],
)
# Launch as a separate statement so that `iface` refers to the Interface
# object itself rather than to whatever `launch()` returns.
iface.launch(share=True)