open-source-models-hg / zero_shot_classification.py
kirill
Fixed tabs
01ab6e1
import time

import gradio as gr
import torch
from transformers import CLIPModel, CLIPProcessor
def get_zero_shot_classification_tab():
    """Build the "Zero-Shot Classification" Gradio tab.

    Loads two CLIP checkpoints ("openai/clip-vit-large-patch14" and
    "patrickjohncyh/fashion-clip") once, when the tab is built, and wires an
    image upload, a comma-separated label box and a model selector to a
    button that scores the image against every label.

    Presumably called by the app that assembles the tabs, from inside its
    ``gr.Blocks`` / ``gr.Tabs`` context — confirm against the caller.

    Returns:
        The ``gr.TabItem`` containing the tab's components.
    """
    openai_model_name = "openai/clip-vit-large-patch14"
    patrickjohncyh_model_name = "patrickjohncyh/fashion-clip"

    # name -> (model, processor). Loaded once here and shared by every click.
    model_map = {
        name: (
            CLIPModel.from_pretrained(name),
            CLIPProcessor.from_pretrained(name),
        )
        for name in (openai_model_name, patrickjohncyh_model_name)
    }

    def parse_labels(text):
        """Split comma-separated *text* into stripped, non-empty labels."""
        if not text:
            return []
        return [label.strip() for label in text.split(",") if label.strip()]

    def gradio_process(model_name, image, text):
        """Score *image* against each label in *text* with the chosen model.

        Args:
            model_name: key of ``model_map`` selected in the dropdown.
            image: PIL image from the upload component (may be ``None`` when
                nothing was uploaded).
            text: comma-separated candidate labels, e.g. "cat, dog,car".

        Returns:
            ``[result, time_spent]`` where ``result`` has one
            ``"<label> - <probability>"`` line per label and ``time_spent``
            is the seconds spent on preprocessing plus the forward pass.

        Raises:
            gr.Error: when no model, image or label is provided, so the UI
                shows a readable message instead of a stack trace.
        """
        if model_name not in model_map:
            raise gr.Error("Please select a model.")
        if image is None:
            raise gr.Error("Please upload an image.")
        labels = parse_labels(text)
        if not labels:
            raise gr.Error("Please enter at least one label.")

        model, processor = model_map[model_name]

        start = time.perf_counter()
        inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # One row per image, one column per label: softmax across the labels
        # and keep the row of the single uploaded image.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        time_spent = time.perf_counter() - start

        result = "\n".join(
            f"{label} - {prob:.4f}" for label, prob in zip(labels, probs.tolist())
        )
        return [result, time_spent]

    with gr.TabItem("Zero-Shot Classification") as zero_shot_image_classification_tab:
        gr.Markdown("# Zero-Shot Image Classification")
        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Labels (comma separated)")
                # Default to the first model so a click without an explicit
                # selection does not look up ``None`` in ``model_map``.
                model_selector = gr.Dropdown(
                    list(model_map),
                    value=openai_model_name,
                    label="Select Model",
                )
                # Process button
                process_btn = gr.Button("Classify")
            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Classification")

        # Connect the input components to the processing function
        process_btn.click(
            fn=gradio_process,
            inputs=[model_selector, input_image, input_text],
            outputs=[output_text, elapsed_result],
        )

    return zero_shot_image_classification_tab