|
import os, time, copy |
|
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False" |
|
|
|
from PIL import Image |
|
|
|
import gradio as gr |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import logging |
|
logging.set_verbosity_error() |
|
|
|
from fromage import models |
|
from fromage import utils |
|
|
|
# NOTE(review): BASE_WIDTH is not referenced anywhere else in this module as shown — confirm it is still needed.
BASE_WIDTH = 512

# Local model directory handed to models.load_fromage() together with the downloaded checkpoint.
MODEL_DIR = './fromage_model/fromage_vis4'
|
|
|
|
|
class ChatBotCheese:
    """Gradio chat front-end around a FROMAGe model.

    The model is conditioned on an optional uploaded image plus a running
    ``Q:/A:`` text transcript. Its answer may include retrieved images; these
    are written as PNG files to the working directory so the chat window can
    reference them through Gradio's ``/file=`` route.

    NOTE(review): ``curr_image`` lives on this single instance, and that one
    instance backs every session of the app built in ``main``. Concurrent
    users therefore appear to share (and overwrite) each other's uploaded
    image. Moving it into a ``gr.State`` would fix that, but it changes the
    handler signatures.

    NOTE(review): saved retrieval images are never deleted.
    """

    def __init__(self):
        """Download the pretrained checkpoint and load the model from ``MODEL_DIR``."""
        from huggingface_hub import hf_hub_download

        model_ckpt_path = hf_hub_download("alvanlii/fromage", "pretrained_ckpt.pth.tar")
        self.model = models.load_fromage(MODEL_DIR, model_ckpt_path)
        # Image prepended to every subsequent model prompt; None means text-only.
        self.curr_image = None

    def add_image(self, state, image_in):
        """Remember an uploaded image and echo it into the chat log.

        Args:
            state: Chat log, a list of ``(user_message, bot_message)`` pairs.
            image_in: Uploaded file object; only its ``.name`` path is used.

        Returns:
            The extended chat log twice (for the state and the chatbot outputs).
        """
        state = state + [(f"![](/file={image_in.name})", "Ok, now type your message")]
        # Image.open is lazy and keeps the file open. convert() returns an
        # in-memory copy, so the handle can be closed immediately.
        with Image.open(image_in.name) as img:
            self.curr_image = img.convert('RGB')
        return state, state

    def save_im(self, image_pil):
        """Save ``image_pil`` as a uniquely named PNG in the working directory.

        Returns:
            The relative file name that was written.
        """
        import uuid

        # A UUID suffix keeps names unique when several images are saved in the
        # same second. The previous 0-99 random suffix could repeat and overwrite
        # an image that was already being displayed.
        file_name = f"{int(time.time())}_{uuid.uuid4().hex}.png"
        image_pil.save(file_name)
        return file_name

    def chat(self, input_text, state, ret_scale_factor, num_ims, num_words, temp, chat_state):
        """Run one chat turn through the model.

        Args:
            input_text: The user's message.
            state: Chat log of ``(user_message, bot_html)`` pairs; appended to in place.
            ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
            num_ims: Maximum number of images to retrieve (coerced to ``int``).
            num_words: Maximum number of words to generate (coerced to ``int``).
            temp: Sampling temperature, forwarded as ``temperature``.
            chat_state: Plain-text ``Q:/A:`` transcript fed back to the model;
                appended to in place.

        Returns:
            ``(state, state, chat_state)`` for the state, chatbot and transcript outputs.
        """
        chat_state.append(f'Q: {input_text} \nA:')
        chat_history = " ".join(chat_state)
        print(chat_history)

        if self.curr_image is not None:
            model_input = [self.curr_image, chat_history]
        else:
            model_input = [chat_history]

        # Number widgets can yield floats (e.g. 3.0); the model gets integer counts.
        model_outputs = self.model.generate_for_images_and_texts(
            model_input,
            max_num_rets=int(num_ims),
            num_words=int(num_words),
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )

        # The transcript keeps only the text parts, including any [RET] tokens.
        chat_state.append(' '.join(s for s in model_outputs if isinstance(s, str)) + '\n')

        # Assumes outputs are [text] or [text, [images]] — TODO confirm other layouts.
        im_names = []
        if len(model_outputs) > 1:
            im_names = [self.save_im(im) for im in model_outputs[1]]

        response = model_outputs[0]
        response += ''.join(f'<img src="/file={im_name}">' for im_name in im_names)
        # [RET] is stripped only from the displayed message, not from the transcript.
        state.append((input_text, response.replace("[RET]", "")))
        return state, state, chat_state

    def reset(self):
        """Forget the uploaded image and return empty values for the three chat outputs."""
        self.curr_image = None
        return [], [], []

    def main(self):
        """Build the Gradio Blocks UI, wire up the event handlers and launch the server."""
        with gr.Blocks(css="#chatbot {height:600px; overflow-y:auto;}") as demo:
            gr.Markdown(
                """
                ### FROMAGe: Grounding Language Models to Images for Multimodal Generation
                Jing Yu Koh, Ruslan Salakhutdinov, Daniel Fried <br/>
                [Paper](https://arxiv.org/abs/2301.13823) [Github](https://github.com/kohjingyu/fromage) [Official Demo](https://huggingface.co/spaces/jykoh/fromage) <br/>
                This is an unofficial Gradio demo for the paper FROMAGe <br/>
                - Instructions (in order):
                    - [Optional] Upload an image (the button with a photo emoji)
                    - [Optional] Change the parameters
                    - Send a message by typing into the box and pressing Enter on your keyboard
                    - Ask about the image! Tell it to find similar images, or ones with different styles.
                    - Check out the examples at the bottom!
                ##### Notes
                - Please be kind to it!
                - It retrieves images from a database, and does not edit images
                - If it returns nothing, try resetting and refreshing the page
                """
            )

            chatbot = gr.Chatbot(elem_id="chatbot")
            gr_state = gr.State([])
            gr_chat_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    txt = gr.Textbox(show_label=False, placeholder="Upload an image first [Optional]. Then enter text and press enter,").style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    btn = gr.UploadButton("🖼️", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    reset_btn = gr.Button("Reset Messages")
                gr_ret_scale_factor = gr.Number(value=1.0, label="Increased prob of returning images", interactive=True)
                # precision=0 makes Gradio round to an int; these are counts, and
                # precision=1 allowed fractional values such as 2.5.
                gr_num_ims = gr.Number(value=3, precision=0, label="Max # of Images returned", interactive=True)
                gr_num_words = gr.Number(value=32, precision=0, label="Max # of words returned", interactive=True)
                gr_temp = gr.Number(value=0.0, label="Temperature", interactive=True)

            with gr.Row():
                gr.Image("example_1.png", label="Example 1")
                gr.Image("example_2.png", label="Example 2")
                gr.Image("example_3.png", label="Example 3")

            txt.submit(self.chat, [txt, gr_state, gr_ret_scale_factor, gr_num_ims, gr_num_words, gr_temp, gr_chat_state], [gr_state, chatbot, gr_chat_state])
            # A second listener on the same submit event clears the textbox.
            txt.submit(lambda: "", None, txt)
            btn.upload(self.add_image, [gr_state, btn], [gr_state, chatbot])
            reset_btn.click(self.reset, [], [gr_state, chatbot, gr_chat_state])

        demo.launch(share=False, server_name="0.0.0.0")
|
|
|
def main():
    """Entry point: build the chatbot (downloading and loading the model) and launch its Gradio UI."""
    ChatBotCheese().main()


if __name__ == "__main__":
    # Start the app only when run as a script, not when imported.
    main()