# Pheye / app.py
# Author: miguelcarv
# Commit: 34f251f ("first commit")
# Size: 3.78 kB (Hugging Face file scan: no virus detected)
import gradio as gr
from huggingface_hub import InferenceClient
import json
from pheye_builder import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
import os
import requests
def get_config(hf_model_path, filename="config.json"):
    """Download a JSON config file from a Hugging Face Hub repo and parse it.

    Args:
        hf_model_path: Hub repository id, e.g. ``"miguelcarv/Pheye-x2-672"``.
        filename: name of the JSON file inside the repository
            (defaults to ``"config.json"``, the file this app reads).

    Returns:
        The parsed JSON content.
    """
    config_path = hf_hub_download(hf_model_path, filename)
    # JSON is UTF-8 by specification; without an explicit encoding ``open``
    # uses the platform locale codec (e.g. cp1252 on Windows) and can
    # mis-decode or reject non-ASCII content.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_model_path(hf_model_path):
    """Return the local filesystem path of the repo's ``checkpoint.pt``.

    Args:
        hf_model_path: Hub repository id holding the model checkpoint.

    Returns:
        Path string returned by ``hf_hub_download`` for ``checkpoint.pt``.
    """
    checkpoint_file = hf_hub_download(repo_id=hf_model_path, filename="checkpoint.pt")
    return checkpoint_file
# Hugging Face Hub repository providing both config.json and checkpoint.pt.
HF_MODEL = "miguelcarv/Pheye-x2-672"
config = get_config(HF_MODEL)
print("Got config")
# Build the architecture described by the downloaded config; the trained
# weights are loaded separately below from checkpoint.pt.
model, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path=config["encoder"],
    lang_decoder_path=config["decoder"],
    tokenizer_path=config["tokenizer"],
    cross_attn_every_n_layers=config["cross_interval"],
    level=config["level"],
    reduce_factor=config["reduce"],
    from_layer=config["from_layer"],
    # SECURITY NOTE(review): eval() executes whatever string the downloaded
    # config holds for the *_dtype fields (presumably values such as
    # "torch.float16" -- confirm). Resolving the name with getattr(torch, ...)
    # and checking isinstance(..., torch.dtype) would avoid running arbitrary
    # code from a remote file.
    encoder_dtype=eval(config["encoder_dtype"]),
    decoder_dtype=eval(config["decoder_dtype"]),
    dtype=eval(config["other_params_dtype"])
)
# Optional adapter enabled by the config. It is attached before
# load_state_dict below, presumably so the checkpoint's adapter parameters
# have matching modules to load into -- confirm against pheye_builder.
if config["first_level"]:
    model.vision_encoder.add_first_level_adapter()
print("Created model")
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = get_model_path(HF_MODEL)
# Deserialize onto the CPU first (map_location="cpu"), then move the whole
# model to DEVICE.
# SECURITY NOTE(review): torch.load unpickles the downloaded checkpoint; on
# torch versions where weights_only defaults to False this can execute code
# embedded in the file. Consider weights_only=True if it is a plain state dict.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.to(DEVICE)
print("Loaded model")
# Instruction preamble placed before every user question in generate_answer.
SYSTEM_PROMPT = "You are an AI visual assistant and you are seeing a single image. You will receive an instruction regarding that image. Your goal is to follow the instruction as faithfully as you can."


def _load_example_image(url, timeout=30):
    """Download an image from *url* and return it as an RGB ``PIL.Image``.

    Args:
        url: address of the image file.
        timeout: seconds to wait for the server before giving up; without a
            timeout ``requests`` may block app start-up indefinitely.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (instead of passing the error page to PIL as if it were an image).
        requests.Timeout: if the server does not respond within *timeout*.
    """
    import io  # local import: only this start-up helper needs it

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    # Use the fully read body (Content-Encoding already decoded) rather than
    # ``response.raw``, the undecoded stream; convert() forces PIL to read the
    # pixels while the buffer is still alive.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


# Example images offered in the Gradio "Examples" section.
whiteboard = _load_example_image("https://c1.staticflickr.com/7/6168/6207108414_a8833f410e_o.jpg")
# NOTE(review): this URL points at a temporary file of another Gradio demo
# and may stop resolving; consider shipping the image with the Space.
taxi_image = _load_example_image("https://llava.hliu.cc/file=/nobackup/haotian/tmp/gradio/ca10383cc943e99941ecffdc4d34c51afb2da472/extreme_ironing.jpg")
def generate_answer(img, question, max_new_tokens, num_beams):
    """Answer *question* about *img* with the loaded Pheye model.

    Used as the Gradio callback; arguments arrive in the order of the
    interface's input components.

    Args:
        img: the uploaded image as a ``PIL.Image`` (``gr.Image(type="pil")``),
            or ``None`` when the user submits without an image.
        question: instruction text for the model.
        max_new_tokens: upper bound on the number of generated tokens.
        num_beams: beam-search width.

    Returns:
        The decoded text after the prompt's ``Output:`` marker, with leading
        whitespace stripped.

    Raises:
        gr.Error: if no image was provided; Gradio shows the message to the user.
    """
    # Without an image the failure would surface deep inside the model with an
    # unhelpful message; gr.Error is displayed to the user as-is.
    if img is None:
        raise gr.Error("Please upload an image.")

    # Token budgets and beam widths are integer quantities; cast defensively
    # since the slider values are forwarded directly to model.generate.
    max_new_tokens = int(max_new_tokens)
    num_beams = int(num_beams)

    image = [img]  # batch containing a single image, passed as vision_x
    prompt = [f"{SYSTEM_PROMPT}\n\nInstruction: {question}\nOutput:"]
    inputs = tokenizer(prompt, padding='longest', return_tensors='pt')
    print("Generating a response with the following parameters:")
    print(f"""Question: {question}\nMax New Tokens: {max_new_tokens}\nNum Beams: {num_beams}""")
    model.eval()  # inference mode for layers such as dropout
    with torch.no_grad():
        outputs = model.generate(
            vision_x=image,
            lang_x=inputs.input_ids.to(DEVICE),
            device=DEVICE,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            attention_mask=inputs.attention_mask.to(DEVICE),
        )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # The decoded text presumably echoes the prompt (the "Output:" split relies
    # on that). Splitting at most as many times as the prompt itself contains
    # "Output:" skips markers the user typed in the question while keeping any
    # "Output:" the model writes inside its answer, which split(...)[-1] would
    # have cut off.
    answer = decoded.split("Output:", prompt[0].count("Output:"))[-1].lstrip()
    return answer
# Create the Gradio interface. The input components map positionally onto
# generate_answer(img, question, max_new_tokens, num_beams).
iface = gr.Interface(
    fn=generate_answer,
    inputs=[
        gr.Image(type="pil", label="Image"),
        gr.Textbox(label="Question"),
        gr.Slider(minimum=5, maximum=500, step=1, value=50, label="Max New Tokens"),
        gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Num Beams"),
    ],
    outputs=gr.Textbox(label="Answer"),
    # Plain text on purpose: Gradio already renders the Interface title as a
    # centered <h1>, and also uses it as the browser tab title, where embedded
    # HTML markup would be shown literally.
    title="Pheye-x2 672x672 pixels",
    # Each example supplies only the image and the question.
    examples=[
        [taxi_image, "What is unusual about this image?"],
        [whiteboard, "What is the main topic of the whiteboard?"],
    ],
)

if __name__ == "__main__":
    # Launch the Gradio app
    iface.launch()