|
import time

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor
|
|
|
|
|
def get_zero_shot_classification_tab():
    """Build and return the Gradio "Zero-Shot Classification" tab.

    Loads two CLIP checkpoints once, at tab-construction time (OpenAI
    CLIP ViT-L/14 and FashionCLIP), and wires an image + comma-separated
    label UI to a classifier that reports per-label probabilities and
    the seconds spent on inference.

    Returns:
        gr.TabItem: the assembled tab, ready to be placed in a Blocks/Tabs layout.
    """
    openai_model_name = "openai/clip-vit-large-patch14"
    openai_model = CLIPModel.from_pretrained(openai_model_name)
    openai_processor = CLIPProcessor.from_pretrained(openai_model_name)

    patrickjohncyh_model_name = "patrickjohncyh/fashion-clip"
    patrickjohncyh_model = CLIPModel.from_pretrained(patrickjohncyh_model_name)
    patrickjohncyh_processor = CLIPProcessor.from_pretrained(patrickjohncyh_model_name)

    # Dropdown value -> (model, processor) pair; keys double as the UI choices.
    model_map = {
        openai_model_name: (openai_model, openai_processor),
        patrickjohncyh_model_name: (patrickjohncyh_model, patrickjohncyh_processor),
    }

    def gradio_process(model_name, image, text):
        """Classify *image* against the comma-separated labels in *text*.

        Args:
            model_name: key into ``model_map`` (set by the dropdown).
            image: PIL image from the upload widget.
            text: comma-separated candidate labels.

        Returns:
            list: ``[formatted per-label probabilities, seconds elapsed]``.
        """
        model, processor = model_map[model_name]
        # Split on "," (not ", ") so "cat,dog" parses like "cat, dog";
        # strip whitespace and drop empty entries.
        labels = [label.strip() for label in text.split(",") if label.strip()]
        if not labels:
            # The processor raises on an empty text list; fail gracefully instead.
            return ["No labels provided.", 0.0]

        start = time.time()
        inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
        # Inference only — disable autograd to avoid building a graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # One image in the batch: take row 0 of the image->text logits.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        time_spent = time.time() - start

        result = "\n".join(
            f"{label} - {prob.item():.4f}" for label, prob in zip(labels, probs)
        )
        return [result, time_spent]

    with gr.TabItem("Zero-Shot Classification") as zero_shot_image_classification_tab:
        gr.Markdown("# Zero-Shot Image Classification")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Labels (comma separated)")
                model_selector = gr.Dropdown(
                    [openai_model_name, patrickjohncyh_model_name],
                    label="Select Model",
                )
                # "Classify" fixes the non-word "Classificate" on the button.
                process_btn = gr.Button("Classify")

            with gr.Column():
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Classification")

        process_btn.click(
            fn=gradio_process,
            inputs=[model_selector, input_image, input_text],
            outputs=[output_text, elapsed_result],
        )

    return zero_shot_image_classification_tab
|
|