import os
# Install DeepSpeed at startup (presumably required by the `magma` package
# imported below -- NOTE(review): confirm), then print the installed
# package set for debugging.
# Invoking pip as `sys.executable -m pip` targets the interpreter that is
# running this script instead of whichever `pip` is first on PATH.
# check=False keeps this best-effort: a failing pip call does not abort here.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "deepspeed"], check=False)
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput
from huggingface_hub import hf_hub_url, cached_download
# Fetch the MAGMA checkpoint from the Hugging Face Hub (reusing the local
# cache when present) and load the model onto the first CUDA device.
# `cached_download` is deprecated and removed in recent huggingface_hub
# releases; `hf_hub_download` resolves the same repo file to a local path.
from huggingface_hub import hf_hub_download

checkpoint_url = hf_hub_url(repo_id="osanseviero/magma", filename="model.pt")
checkpoint_path = hf_hub_download(repo_id="osanseviero/magma", filename="model.pt")

model = Magma.from_checkpoint(
    config_path="configs/MAGMA_v1.yml",
    checkpoint_path=checkpoint_path,
    device="cuda:0",
)
def generate(image, context, length, temperature, top_k, rearrange):
    """Run MAGMA on an image + text prompt and return the generated text.

    Args:
        image: Image to condition on; wrapped in ``ImageInput``, which
            accepts a URL or a local path (the UI supplies a filepath).
        context: Text prompt.
        length: Maximum number of tokens to generate (``max_steps``).
        temperature: Sampling temperature. A value of exactly 0 is replaced
            by 0.01 -- presumably to avoid dividing by zero in the sampler
            (NOTE(review): confirm against magma's generate()).
        top_k: Top-k cutoff forwarded to ``model.generate``.
        rearrange: If true, the text precedes the image in the prompt;
            otherwise the image comes first.

    Returns:
        The first element of ``model.generate``'s output, i.e. the text for
        the single prompt in the batch.
    """
    # Default order is image first, then text; "rearrange" flips it.
    prompt = [ImageInput(image), context]
    if rearrange:
        prompt.reverse()

    # NOTE(review): an earlier comment here gave the embedding shape as
    # (1, 149, 4096); the middle dimension presumably depends on the prompt.
    embeddings = model.preprocess_inputs(prompt)

    # Slider widgets may deliver floats; step count and top-k are integer
    # quantities, so coerce them before handing them to the sampler.
    output = model.generate(
        embeddings=embeddings,
        max_steps=int(length),
        temperature=0.01 if temperature == 0 else temperature,
        top_k=int(top_k),
    )
    # Presumably one result per batch element; the batch holds one prompt.
    return output[0]
# Example rows; each value lines up with one input component of the UI:
# (image file, text prompt, output tokens, temperature, top-k, rearrange).
examples = [
    ["woods_hi.jpeg", "Describe the painting:", 15, 0.7, 0, False],
    [
        "E8EB3C7B-291C-400A-81F2-AE9229D9CE23.jpeg",
        "Q: Is the person in the image older than 35?\nA: ",
        15,
        0.7,
        0,
        False,
    ],
]
title = "MAGMA"
# Built from adjacent string literals so it can span several source lines;
# a raw newline inside a plain "..." literal is a SyntaxError.
description = (
    "Gradio Demo for MAGMA -- Multimodal Augmentation of Generative Models "
    "through Adapter-based Finetuning by Constantin Eichenberg, Sid Black, "
    "Samuel Weinbach, Letitia Parcalabescu, and Anette Frank\n"
    "arXiv | Github Repo"
)
article = ""
# Build the Gradio UI. The input components are listed in the same order as
# generate()'s parameters and as the columns of each `examples` row.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Image(type="filepath", label="Image Prompt"),
        gr.inputs.Textbox(
            label="Text Prompt:",
            default="Describe the painting:",
            lines=7,
        ),
        gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label="Temperature"),
        gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label="Top K"),
        gr.inputs.Checkbox(default=False, label="Rearrange Prompt", optional=False),
    ],
    outputs=["textbox"],
    examples=examples,
    title=title,
    description=description,
    article=article,
)
# Launch as a separate statement so `iface` refers to the Interface itself
# rather than to whatever launch() returns.
iface.launch(enable_queue=True, cache_examples=True)