magma / app.py
stellaathena's picture
Update app.py
09b6ca0
raw
history blame contribute delete
No virus
2.82 kB
import os
# Start-up dependency bootstrap. This must stay ABOVE the third-party imports
# that follow; presumably `magma` needs deepspeed at import time -- TODO confirm.
#
# `sys.executable -m pip` targets the interpreter that is running this script
# (a bare `pip` on PATH may belong to a different environment), and the
# argument-list form avoids going through a shell.
import subprocess
import sys

# check=False keeps the original best-effort behaviour: a failed install does
# not abort here; it will surface as an ImportError further down instead.
subprocess.run([sys.executable, "-m", "pip", "install", "deepspeed"], check=False)

# Print the resolved package set to the process log to aid debugging.
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput
from huggingface_hub import hf_hub_url, cached_download
# Resolve the Hub URL of the MAGMA weights, then fetch them through the
# huggingface_hub cache; the result is passed on as the checkpoint path.
# NOTE(review): `cached_download` is deprecated in newer huggingface_hub
# releases in favour of `hf_hub_download` -- confirm the pinned version
# before upgrading the dependency.
checkpoint_url = hf_hub_url(
    repo_id="osanseviero/magma",
    filename="model.pt",
)
checkpoint_path = cached_download(checkpoint_url)

# Build the model once at import time; `generate` below reuses this
# module-level instance for every request. Loading targets the first CUDA
# device, so a GPU is expected to be present.
model = Magma.from_checkpoint(
    config_path="configs/MAGMA_v1.yml",
    checkpoint_path=checkpoint_path,
    device="cuda:0",
)
def generate(image, context, length, temperature, top_k, rearrange):
    """Generate text with MAGMA from one image and one text prompt.

    Args:
        image: Image reference wrapped in ``ImageInput`` (the original author
            notes that URLs and local paths are supported).
        context: Text prompt.
        length: Forwarded to ``model.generate`` as ``max_steps``.
        temperature: Sampling temperature; a value of exactly 0 is replaced
            by 0.01 before being forwarded.
        top_k: Forwarded unchanged to ``model.generate``.
        rearrange: When truthy the text prompt is placed before the image;
            otherwise the image comes first.

    Returns:
        The first element of whatever ``model.generate`` returns.
    """
    # NOTE: an earlier revision (previously kept here as commented-out code)
    # split `context` into lines and turned every line that looked like a URL
    # into an ImageInput. The image now arrives as its own argument instead.
    picture = ImageInput(image)
    prompt_parts = [context, picture] if rearrange else [picture, context]

    # Original author's note: yields a tensor of shape (1, 149, 4096)
    # (presumably for their sample input -- not verified here).
    embeddings = model.preprocess_inputs(prompt_parts)

    # A temperature of exactly zero is swapped for a small positive value.
    if temperature == 0:
        sampling_temperature = 0.01
    else:
        sampling_temperature = temperature

    # Original author's note: the result is a list with one entry per batch
    # element (embeddings.shape[0]); only the first entry is returned.
    generated = model.generate(
        embeddings=embeddings,
        max_steps=length,
        temperature=sampling_temperature,
        top_k=top_k,
    )
    return generated[0]
# Example rows for the demo UI. Column order follows the parameters of
# `generate`: image, text prompt, output tokens, temperature, top-k, rearrange.
examples = [
    ["woods_hi.jpeg", "Describe the painting:", 15, 0.7, 0, False],
    [
        "E8EB3C7B-291C-400A-81F2-AE9229D9CE23.jpeg",
        "Q: Is the person in the image older than 35?\nA: ",
        15,
        0.7,
        0,
        False,
    ],
]

# Page heading.
title = "MAGMA"

# HTML blurb with the paper title, authors and links. Adjacent literals are
# concatenated at compile time, so the runtime value is a single string.
description = (
    "Gradio Demo for MAGMA -- Multimodal Augmentation of Generative Models "
    "through Adapter-based Finetuning by Constantin Eichenberg, Sid Black, "
    "Samuel Weinbach, Letitia Parcalabescu, and Anette Frank<br> <br>"
    "<a href='https://arxiv.org/abs/2112.05253' target='_blank'>arXiv</a> | "
    "<a href='https://github.com/Aleph-Alpha/magma' target='_blank'>Github Repo</a>"
)

# No article text is supplied.
article = ""
# Input widgets, listed in the positional order of `generate`'s parameters:
# image, context, length, temperature, top_k, rearrange.
_input_components = [
    # The image is handed to `generate` as a file path.
    gr.inputs.Image(type="filepath", label="Image Prompt"),
    gr.inputs.Textbox(
        lines=7,
        default="Describe the painting:",
        label="Text Prompt:",
    ),
    gr.inputs.Slider(minimum=1, maximum=100, step=1, default=15, label="Output tokens:"),
    # A temperature of 0 is selectable here; `generate` replaces it with 0.01.
    gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
    gr.inputs.Slider(minimum=0, maximum=100, step=1, default=0, label='Top K'),
    gr.inputs.Checkbox(default=False, label="Rearrange Prompt", optional=False),
]

# Assemble the demo: one textbox output fed by `generate`.
_demo = gr.Interface(
    fn=generate,
    inputs=_input_components,
    outputs=["textbox"],
    examples=examples,
    title=title,
    description=description,
    article=article,
)

# Start serving with request queueing and example caching switched on.
# Note that `iface` is bound to the return value of launch(), not to the
# Interface object itself.
iface = _demo.launch(enable_queue=True, cache_examples=True)