import tempfile
from share_btn import community_icon_html, loading_icon_html, share_js, save_js
import huggingface_hub
import gradio as gr
from gill import utils
from gill import models
import matplotlib.pyplot as plt
from PIL import Image
import torch
import numpy as np
import os
# Keep the optional hf_transfer download backend turned off.
# NOTE(review): huggingface_hub is imported earlier in this file and may read
# this variable at import time, in which case this assignment has no effect on
# it — confirm, and move above the import if it matters.
os.environ.update({"HF_HUB_ENABLE_HF_TRANSFER": "False"})
# Custom stylesheet passed to the Gradio app. Gallery images zoom on hover;
# on devices without hover support (touch screens) the zoom and its extra
# padding are reset.
css = """
#chatbot { min-height: 300px; }
#save-btn {
background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#save-btn:hover {
background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#share-btn {
background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#share-btn:hover {
background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#gallery { z-index: 999999; }
#gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
#gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
@media (hover: none) {
#gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
}
"""
# Paths (relative to the working directory) of the bundled example images.
examples = [
    f'examples/{stem}.png'
    for stem in ('sparrow', 'beaver', 'couch', 'guac', 'scraped_knee')
]
# Download model from HF Hub: the pretrained checkpoint, the decision model,
# and the JSON model args, in that order (hf_hub_download returns a local
# path). Then build the model once at import time.
ckpt_path, decision_model_path, args_path = [
    huggingface_hub.hf_hub_download(repo_id='jykoh/gill', filename=artifact)
    for artifact in (
        'pretrained_ckpt.pth.tar',
        'decision_model.pth.tar',
        'model_args.json',
    )
]
model = models.load_gill('./', args_path, ckpt_path, decision_model_path)
def upload_image(state, image_input):
    """Downscale an uploaded image and add it to the chat state.

    The upload is resized to 224x224, converted to RGB, and written back over
    the uploaded file. It is then appended to the model history (followed by
    an empty-string entry) and shown as a user turn in the conversation.

    Args:
        state: Two-element list ``[conversation, chat_history]``.
        image_input: Uploaded file object; only its ``.name`` path is used.

    Returns:
        ``([conversation, chat_history], conversation)`` with the image added.
        The caller's ``state`` lists are not mutated.
    """
    conversation = state[0]
    chat_history = state[1]
    # Use a context manager so the handle on the upload is released
    # deterministically before the same path is overwritten below (PIL may
    # otherwise keep it open, e.g. for multi-frame images).
    with Image.open(image_input.name) as uploaded:
        input_image = uploaded.resize((224, 224)).convert('RGB')
    input_image.save(image_input.name)  # Overwrite with smaller image.
    # NOTE(review): in this copy of the file the message was an empty f-string
    # (its HTML was apparently lost); reconstructed to display the saved
    # upload. Confirm the ./file= route matches the Gradio version in use.
    conversation = conversation + [
        (f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
    return [conversation, chat_history + [input_image, ""]], conversation
def reset():
    """Clear the whole chat.

    Returns:
        A fresh ``[conversation, chat_history]`` state with both lists empty,
        and a separate empty list for the conversation display.
    """
    fresh_state = [list(), list()]
    return fresh_state, list()
def reset_last(state):
    """Undo the most recent chat turn.

    Drops the last conversation entry and the last two chat-history entries.
    Each turn added by ``upload_image`` and ``generate_for_prompt`` contributes
    one conversation entry and two history entries. New lists are returned;
    ``state`` itself is left untouched.

    Returns:
        ``([conversation, chat_history], conversation)`` after trimming.
    """
    trimmed = [state[0][:-1], state[1][:-2]]
    return trimmed, trimmed[0]
def save_image_to_local(image: 'Image.Image'):
    """Save *image* as a uniquely named PNG in the current working directory.

    Args:
        image: Object with a PIL-style ``save(filename)`` method.

    Returns:
        The new file's name, relative to the current working directory.

    Raises:
        Whatever ``image.save`` raises; the reserved file is removed first.
    """
    # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
    # mkstemp creates the file atomically (O_EXCL), so two concurrent requests
    # can never receive the same name. The private
    # tempfile._get_candidate_names() only proposes names without reserving
    # them.
    fd, path = tempfile.mkstemp(suffix='.png', prefix='', dir='.')
    os.close(fd)
    # Return a bare relative name, as before (mkstemp may return an absolute
    # path depending on the Python version).
    filename = os.path.basename(path)
    try:
        image.save(filename)
    except BaseException:
        # Don't leave the empty placeholder file behind.
        os.remove(filename)
        raise
    return filename
def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
# Ignore empty inputs.
if len(input_text) == 0:
return state, state[0], gr.update(visible=True)
input_prompt = 'Q: ' + input_text + '\nA:'
conversation = state[0]
chat_history = state[1]
print('Generating for', chat_history, flush=True)
# If an image was uploaded, prepend it to the model.
model_inputs = chat_history
model_inputs.append(input_prompt)
top_p = 1.0
if temperature != 0.0:
top_p = 0.95
print('Running model.generate_for_images_and_texts with',
model_inputs, flush=True)
model_outputs = model.generate_for_images_and_texts(model_inputs,
num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
temperature=temperature, max_num_rets=1,
num_inference_steps=1)
print('model_outputs', model_outputs, ret_scale_factor, flush=True)
im_names = []
response = ''
text_outputs = []
for output_i, p in enumerate(model_outputs):
if type(p) == str:
if output_i > 0:
response += '
'
# Remove the image tokens for output.
text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
response += p
if len(model_outputs) > 1:
response += '
'
elif type(p) == dict:
# Decide whether to generate or retrieve.
if p['decision'] is not None and p['decision'][0] == 'gen':
image = p['gen'][0][0].resize((512, 512))
filename = save_image_to_local(image)
response += f'
(Generated)
' else: image = p['ret'][0][0].resize((512, 512)) filename = save_image_to_local(image) response += f'(Retrieved)
' chat_history = model_inputs + \ [' '.join([s for s in model_outputs if type(s) == str]) + '\n'] # Remove [RET] from outputs. conversation.append((input_text, response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))) # Set input image to None. print('state', state, flush=True) print('updated state', [conversation, chat_history], flush=True) return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True) with gr.Blocks(css=css) as demo: gr.HTML("""This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.
Paper: Grounding Language Models to Images for Multimodal Generation