import asyncio
import threading

import gradio as gr
import torch
from transformers import AutoProcessor, CLIPModel

# Hugging Face Hub id of the CLIP checkpoint; both the processor and the model
# are loaded from it so their preprocessing and weights match.
clip_path = "openai/clip-vit-base-patch32"
processor = AutoProcessor.from_pretrained(clip_path)

# Put the model in evaluation mode (disables training-only layers such as
# dropout) since it is only used for inference in this app.
model = CLIPModel.from_pretrained(clip_path)
model.eval()


# Serializes model calls made from executor threads: `processor` and `model`
# are shared module-level objects and are not assumed to be safe for
# concurrent use.
_INFERENCE_LOCK = threading.Lock()


def _parse_labels(raw_labels):
    """Split a comma-separated string into trimmed, non-empty labels.

    ``None`` or ``""`` yields ``[]``. Surrounding whitespace is removed so that
    ``"cartoon, painting"`` produces ``"painting"`` rather than ``" painting"``,
    and empty entries (e.g. from a trailing comma) are dropped.
    """
    if not raw_labels:
        return []
    stripped = (part.strip() for part in raw_labels.split(","))
    return [label for label in stripped if label]


def _score_labels(image, labels):
    """Score ``image`` against every label with CLIP and return a text report.

    Blocking: runs the processor and a full model forward pass, so it is meant
    to be called off the event loop.

    Returns one ``label: score`` line (newline-terminated) per label, where the
    score is the raw value from ``logits_per_image`` (the image-text
    similarity score), not a softmax probability.
    """
    with _INFERENCE_LOCK:
        inputs = processor(
            text=labels, images=image, return_tensors="pt", padding=True
        )
        # Inference only: disable autograd so no computation graph is kept.
        # Grad mode is per-thread, so it must be set here in the worker thread.
        with torch.no_grad():
            outputs = model(**inputs)
    # Row 0 is the single input image's score against each label; convert to
    # plain floats so the report shows numbers, not tensor reprs.
    scores = outputs.logits_per_image[0].tolist()
    return "".join(f"{label}: {score:.4f}\n" for label, score in zip(labels, scores))


async def predict(init_image, labels_level1):
    """Gradio handler: classify ``init_image`` against comma-separated labels.

    Args:
        init_image: Image from the ``gr.Image`` input; ``None`` means there is
            no image to classify.
        labels_level1: Comma-separated candidate labels,
            e.g. ``"cartoon,painting,screenshot"``. May be ``None`` or empty.

    Returns:
        The same report string twice (one per output textbox), or ``("", "")``
        when there is no image or no usable label.
    """
    if init_image is None:
        return "", ""

    labels = _parse_labels(labels_level1)
    if not labels:
        return "", ""

    # The forward pass is compute-bound; running it directly in this coroutine
    # would block the server's event loop, so hand it to the default executor.
    loop = asyncio.get_running_loop()
    report = await loop.run_in_executor(None, _score_labels, init_image, labels)
    return report, report


# Stylesheet passed to gr.Blocks below: centers the element with
# elem_id="container" (capped at 80rem wide) and centers the text of the
# element with elem_id="intro" across the full width.
css = """
#container{
    margin: 0 auto;
    max-width: 80rem;
}
#intro{
    max-width: 100%;
    text-align: center;
    margin: 0 auto;
}
"""
with gr.Blocks(css=css) as demo:
    # Created in the Blocks context but not wired to any event below.
    init_image_state = gr.State()

    with gr.Column(elem_id="container"):
        gr.Markdown("# Clip Demo\n            ", elem_id="intro")

        # Row 1: editable label list ("set top-level categories"), the text
        # output, and the "click to start classification" button.
        with gr.Row():
            txt_input = gr.Textbox(
                label="设定大类别类别",
                value="cartoon,painting,screenshot",
                interactive=True,
                scale=5,
            )
            txt = gr.Textbox(label="Output:", value="", scale=5)
            generate_bt = gr.Button("点击开始分类", scale=1)

        # Row 2: image input, accepted from file upload or the clipboard.
        with gr.Row(), gr.Column():
            image_input = gr.Image(
                label="User Image",
                sources=["upload", "clipboard"],
                type="pil",
            )

        # Row 3: second output box ("level-1 classification").
        with gr.Row():
            prob_label = gr.Textbox(label="一级分类", value="")

        # Both triggers run the same handler with the same inputs/outputs.
        inputs = [image_input, txt_input]
        outputs = [txt, prob_label]
        generate_bt.click(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            show_progress=True,
        )
        # Re-classify whenever the image changes; this event bypasses the queue.
        image_input.change(
            fn=predict,
            inputs=inputs,
            outputs=outputs,
            queue=False,
            show_progress=True,
        )

# Enable Gradio's request queue for the events registered on `demo`.
demo.queue()

# Start the (blocking) server only when run as a script, so importing this
# module does not launch it. server_name="0.0.0.0" listens on all interfaces.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8081, share=False)