# magma / app.py
# Author: osanseviero (Hugging Face staff)
# Commit: e4f5644 ("Update app.py")
# File size: 1.79 kB; Hub security scan: no virus detected.
import os
# Install DeepSpeed at startup, before the `magma` imports further down, then
# print the installed package list for debugging. pip is invoked through the
# running interpreter (`sys.executable -m pip`) so packages land in the same
# environment this app runs in, rather than whichever `pip` is first on PATH.
# As with the original os.system calls, a failed install is not fatal here.
import subprocess
import sys

subprocess.run([sys.executable, "-m", "pip", "install", "deepspeed"], check=False)
subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
import gradio as gr
import re
from magma import Magma
from magma.image_input import ImageInput
from huggingface_hub import hf_hub_download
# Fetch the MAGMA checkpoint file from the Hugging Face Hub repo
# "osanseviero/magma", then build the model from the v1 config on the first
# CUDA device. `model` is the module-level instance used by generate().
checkpoint_path = hf_hub_download(
    repo_id="osanseviero/magma",
    filename="model.pt",
)
model = Magma.from_checkpoint(
    checkpoint_path=checkpoint_path,
    config_path="configs/MAGMA_v1.yml",
    device="cuda:0",
)
# Recognises a line that begins with an http(s) URL; such lines are loaded as
# images by generate(). Compiled once at import instead of on every request.
URL_REGEX = re.compile(
    r'https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)'
)


def generate(context, length, temperature, top_k):
    """Run MAGMA on a mixed image/text prompt and return prompt + completion.

    The prompt is split on newlines. A line that (after stripping surrounding
    whitespace) starts with an http(s) URL is loaded as an image with
    ``ImageInput``; every other line is passed to the model as text.

    Args:
        context: Prompt from the UI; image URLs must be on their own lines.
        length: Maximum number of tokens to generate (``max_steps``).
        temperature: Sampling temperature; 0 is replaced by 0.01.
        top_k: Top-k value forwarded to ``model.generate``.

    Returns:
        The stripped prompt followed by the first generated sequence, or the
        exception message if an image URL could not be loaded.
    """
    context = context.strip()
    inputs = []
    for line in context.split('\n'):
        # Strip before matching so a URL with stray leading spaces is still
        # recognised, and no trailing space / '\r' is passed to ImageInput.
        candidate = line.strip()
        if URL_REGEX.match(candidate):
            try:
                inputs.append(ImageInput(candidate))
            except Exception as e:
                # Show the failure in the output textbox instead of crashing.
                return str(e)
        else:
            inputs.append(line)

    # e.g. a tensor of shape (1, 149, 4096)
    embeddings = model.preprocess_inputs(inputs)
    # int() guards against sliders delivering floats such as 15.0
    # (NOTE(review): confirm for the deployed Gradio version).
    output = model.generate(
        embeddings=embeddings,
        max_steps=int(length),
        temperature=(0.01 if temperature == 0 else temperature),
        top_k=int(top_k),
    )
    # output is a list of length embeddings.shape[0] (the batch size); a single
    # prompt is sent, so take the first completion.
    return context + output[0]
# Gradio UI: the inputs map, in order, to generate(context, length,
# temperature, top_k); the generated text is shown in a single textbox.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Textbox(
            label="Prompt (image URLs need to be on their own lines):",
            default="https://www.art-prints-on-demand.com/kunst/thomas_cole/woods_hi.jpg\nDescribe the painting:",
            lines=7,
        ),
        gr.inputs.Slider(minimum=1, maximum=100, default=15, step=1, label="Output tokens:"),
        gr.inputs.Slider(minimum=0.0, maximum=1.0, default=0.7, label='Temperature'),
        gr.inputs.Slider(minimum=0, maximum=100, default=0, step=1, label='Top K'),
    ],
    outputs=["textbox"],
)
# Launch separately so `iface` refers to the Interface object itself rather
# than to whatever launch() returns.
iface.launch(share=True)