import streamlit as st
from transformers import CLIPModel, CLIPProcessor, pipeline
import torch
from PIL import Image

#################################
#### FUNCTIONS

# Hugging Face checkpoints selectable through ``load_clip``.
CLIP_MODEL_NAMES = {
    'base': 'openai/clip-vit-base-patch32',
    'large': 'openai/clip-vit-large-patch14',
}


def load_clip(model_size='large'):
    """Load a CLIP processor/model pair from the Hugging Face hub.

    Args:
        model_size: ``'base'`` or ``'large'`` (the default), selecting one of
            the checkpoints in ``CLIP_MODEL_NAMES``.

    Returns:
        A ``(processor, model)`` tuple. The processor is a ``CLIPProcessor``
        because ``inference_clip`` feeds it both text and images; an
        image-only processor cannot tokenize the candidate labels.

    Raises:
        ValueError: if ``model_size`` is not a known key.
    """
    try:
        model_name = CLIP_MODEL_NAMES[model_size]
    except KeyError:
        raise ValueError(
            f"Unknown model_size {model_size!r}; "
            f"expected one of {sorted(CLIP_MODEL_NAMES)}"
        ) from None
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return processor, model


def inference_clip(options, image, processor, model):
    """Return the entry of ``options`` that CLIP scores as best matching ``image``.

    Args:
        options: list of candidate text labels.
        image: a single PIL image.
        processor: processor returned by ``load_clip``.
        model: model returned by ``load_clip``.

    Returns:
        The element of ``options`` with the highest softmax probability.
    """
    inputs = processor(text=options, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image holds image-text similarity scores with one row per
    # image; only one image is passed, so row 0 covers every option.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    best_idx = int(torch.argmax(probs))
    return options[best_idx]


#################################
#### LAYOUT

col_l, col_r = st.columns(2)

# Zero-shot classifier checkpoint used by the app.
model_name = "openai/clip-vit-large-patch14-336"


@st.cache_resource
def load_classifier(name):
    """Build the zero-shot image-classification pipeline for checkpoint ``name``.

    Streamlit re-executes this script on every widget interaction; caching
    keeps the model from being reloaded on each rerun.
    """
    return pipeline("zero-shot-image-classification", model=name)


classifier = load_classifier(model_name)

image = None  # stays None until the user uploads a picture

#### Loading picture
with col_l:
    picture_file = st.file_uploader("Picture :", type=["jpg", "jpeg", "png"])
    if picture_file is not None:
        image = Image.open(picture_file)
        st.image(image, caption='Please upload an image of the damage')

with col_l:
    default_options = 'There is a car, There is no car'
    raw_options = st.text_input(label="Please enter the classes", value=default_options)
    # Split on commas, trim the space that follows each comma and drop empty
    # entries, so "a, b," yields ["a", "b"] rather than ["a", " b", ""].
    options = [opt.strip() for opt in raw_options.split(',') if opt.strip()]

# button to launch compute
if st.button("Compute"):
    if image is None:
        st.warning("Please upload a picture before computing.")
    elif not options:
        st.warning("Please enter at least one class.")
    else:
        scores = classifier(image, candidate_labels=options)
        with col_r:
            st.dataframe(scores)