File size: 3,778 Bytes
7dedb9e
 
34f251f
 
 
 
 
 
 
7dedb9e
34f251f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7dedb9e
34f251f
 
 
7dedb9e
 
 
34f251f
 
7dedb9e
34f251f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gradio as gr
from huggingface_hub import InferenceClient
import json
from pheye_builder import create_model_and_transforms
from huggingface_hub import hf_hub_download
import torch
from PIL import Image
import os
import requests


def get_config(hf_model_path):
    """Download and parse the ``config.json`` of a Hugging Face Hub repository.

    Args:
        hf_model_path: Repository id on the Hub (for example
            ``"miguelcarv/Pheye-x2-672"``).

    Returns:
        The parsed JSON content of the repository's ``config.json``.
    """
    config_path = hf_hub_download(hf_model_path, "config.json")

    # JSON is UTF-8 by specification; pin the encoding instead of relying on
    # the platform's default locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_model_path(hf_model_path):
    """Fetch ``checkpoint.pt`` from the given Hub repository.

    Args:
        hf_model_path: Repository id on the Hugging Face Hub.

    Returns:
        Whatever ``hf_hub_download`` returns for the checkpoint file (the
        caller passes it to ``torch.load`` as a file location).
    """
    checkpoint_location = hf_hub_download(
        repo_id=hf_model_path,
        filename="checkpoint.pt",
    )
    return checkpoint_location


# Hub repository from which both the config and the checkpoint are fetched.
HF_MODEL = "miguelcarv/Pheye-x2-672"
config = get_config(HF_MODEL)

print("Got config")

# Build the model and its tokenizer from the hyperparameters in the config.
# NOTE(review): the three *_dtype entries go through eval(), i.e. text from the
# downloaded config is executed as Python. Presumably they hold strings such
# as "torch.float16" -- confirm, and consider an explicit name-to-dtype lookup,
# because eval() on downloaded data is a code-execution risk.
model, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path=config["encoder"],
            lang_decoder_path=config["decoder"],
            tokenizer_path=config["tokenizer"],
            cross_attn_every_n_layers=config["cross_interval"],
            level=config["level"],
            reduce_factor=config["reduce"],
            from_layer=config["from_layer"],
            encoder_dtype=eval(config["encoder_dtype"]),
            decoder_dtype=eval(config["decoder_dtype"]),
            dtype=eval(config["other_params_dtype"])
        ) 

# The adapter is added only when the config enables it, and before the
# checkpoint is loaded below (presumably so its weights exist when
# load_state_dict runs -- TODO confirm).
if config["first_level"]:
    model.vision_encoder.add_first_level_adapter()

print("Created model")

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = get_model_path(HF_MODEL)
# The checkpoint is deserialized onto the CPU first; the model is moved to
# DEVICE afterwards.
# NOTE(review): torch.load unpickles the file. If the checkpoint is a plain
# state dict and the installed torch supports it, consider weights_only=True.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model = model.to(DEVICE)

print("Loaded model")

# Instruction preamble that generate_answer places at the start of every prompt.
SYSTEM_PROMPT = "You are an AI visual assistant and you are seeing a single image. You will receive an instruction regarding that image. Your goal is to follow the instruction as faithfully as you can."


def _load_example_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image file.
        timeout: Seconds to wait for the server before giving up, so a dead
            host cannot block application start-up indefinitely.

    Returns:
        The downloaded image converted to RGB mode.

    Raises:
        requests.RequestException: On connection problems, timeouts, or a
            non-success HTTP status.
    """
    # The context manager releases the connection once the image is decoded.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Fail with a clear HTTP error instead of handing an error page to PIL.
        response.raise_for_status()
        # convert() decodes the pixel data, so the result stays usable after
        # the response has been closed.
        return Image.open(response.raw).convert('RGB')


# Example inputs shown in the demo interface.
whiteboard = _load_example_image("https://c1.staticflickr.com/7/6168/6207108414_a8833f410e_o.jpg")
taxi_image = _load_example_image("https://llava.hliu.cc/file=/nobackup/haotian/tmp/gradio/ca10383cc943e99941ecffdc4d34c51afb2da472/extreme_ironing.jpg")


def generate_answer(img, question, max_new_tokens, num_beams):
    """Answer an instruction about a single image using the loaded model.

    Args:
        img: The image to answer about (the interface supplies a PIL image).
        question: Instruction text inserted into the prompt template.
        max_new_tokens: Generation length limit forwarded to ``model.generate``.
        num_beams: Beam count forwarded to ``model.generate``.

    Returns:
        The decoded text following the last ``"Output:"`` marker, with leading
        whitespace removed.

    Raises:
        gr.Error: If no image was supplied.
    """
    # Without an image the model call fails with an obscure error; report the
    # problem to the user instead.
    if img is None:
        raise gr.Error("Please provide an image before submitting.")

    # Slider values may arrive as floats (TODO confirm for the Gradio version
    # in use); the generation settings are integer counts.
    max_new_tokens = int(max_new_tokens)
    num_beams = int(num_beams)

    image = [img]
    prompt = [f"{SYSTEM_PROMPT}\n\nInstruction: {question}\nOutput:"]
    inputs = tokenizer(prompt, padding='longest', return_tensors='pt')
    print("Generating a response with the following parameters:")
    print(f"""Question: {question}\nMax New Tokens: {max_new_tokens}\nNum Beams: {num_beams}""")

    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            vision_x=image,
            lang_x=inputs.input_ids.to(DEVICE),
            device=DEVICE,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            attention_mask=inputs.attention_mask.to(DEVICE),
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Keep only what follows the last "Output:" marker. Note that an answer
    # which itself contains "Output:" is cut at that point.
    answer = decoded.split("Output:")[-1].lstrip()

    return answer


# Input widgets, listed in the positional order of generate_answer's parameters.
_demo_inputs = [
    gr.Image(type="pil", label="Image"),
    gr.Textbox(label="Question"),
    gr.Slider(minimum=5, maximum=500, step=1, value=50, label="Max New Tokens"),
    gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Num Beams"),
]

# Single text field that displays the generated answer.
_demo_output = gr.Textbox(label="Answer")

# Example rows: an image plus the question that goes with it.
_demo_examples = [
    [taxi_image, "What is unusual about this image?"],
    [whiteboard, "What is the main topic of the whiteboard?"],
]

# Assemble the Gradio interface around generate_answer.
iface = gr.Interface(
    fn=generate_answer,
    inputs=_demo_inputs,
    outputs=_demo_output,
    title="<h1 style='text-align: center; display: block;'>Pheye-x2 672x672 pixels</h1>",
    examples=_demo_examples,
)


if __name__ == "__main__":
    # Start the web app only when this file is run as a script.
    iface.launch()