import random
import os
import torch
import pandas as pd
from transformers import CLIPProcessor, CLIPModel, DetrFeatureExtractor, DetrForObjectDetection
from PIL import Image
CLIPmodel_import = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
CLIPprocessor_import = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
DetrFeatureExtractor_import = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
DetrModel_import = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
# Import the list of COCO object labels
script_path = os.path.dirname(__file__)
with open(script_path + "/coco-labels-paper.txt", "r") as coco_file:
    coco_objects = [line for line in coco_file.read().splitlines() if line.strip()]
# Example image
#test_image = Image.open('pages/Functions/test_image.png')
#test_image = Image.open('pages/Functions/test_imageIV.png')
###### Helper functions
def Coco_object_set(included_object, set_length=6):
    '''
    Creates a set of objects drawn from the COCO labels, always including the
    currently correct object (included_object).
    '''
curr_object_set = set([included_object])
while len(curr_object_set)<set_length:
temp_object = random.choice(coco_objects)
curr_object_set.add(temp_object)
return list(curr_object_set)
def Object_set_creator(included_object, list_of_all_objects = coco_objects, excluded_objects_list = [], set_length=6):
    '''
    Creates a set of objects drawn from list_of_all_objects.
    The included_object will always be in the set.
    Optionally takes a list of objects to be excluded from the set.
    '''
curr_object_set = set([included_object])
    # Check that the included object is not contained in the excluded objects
    if included_object in excluded_objects_list:
        raise ValueError('The included_object cannot be part of the excluded_objects_list.')
while len(curr_object_set)<set_length:
temp_object = random.choice(list_of_all_objects)
if temp_object not in excluded_objects_list:
curr_object_set.add(temp_object)
return list(curr_object_set)
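# Illustrative usage of the helper functions (not run at import; the object names
# below are placeholders, any COCO labels would behave the same way):
# Object_set_creator(included_object='cat', excluded_objects_list=['dog'], set_length=6)
# -> e.g. ['cat', 'bus', 'kite', 'zebra', 'bench', 'spoon'], always containing 'cat'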
###### Single object recognition
def CLIP_single_object_classifier(img, object_class, task_specific_label=None):
    '''
    Tests for the presence of object_class in the image by scoring it with CLIP
    against a random set of COCO distractor labels ("red herring strategy").
    Note that the task_specific_label is not used for this classifier.
    '''
    # Build the distractor word list and prepare the model inputs
word_list = Coco_object_set(object_class)
inputs = CLIPprocessor_import(text=word_list, images=img, return_tensors="pt", padding=True)
# Run inference
outputs = CLIPmodel_import(**inputs)
# Get image-text similarity score
logits_per_image = outputs.logits_per_image
# Get probabilities
probs = logits_per_image.softmax(dim=1)
    # Return True if the highest-probability class is the tested object
    return word_list[probs.argmax().item()] == object_class
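# Illustrative usage (image path reused from the commented-out examples above;
# the object name is a placeholder):
# test_image = Image.open('pages/Functions/test_image.png')
# CLIP_single_object_classifier(test_image, 'cat')  # True if CLIP ranks 'cat' highest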
def CLIP_object_recognition(img, object_class, tested_classes):
    '''
    More general CLIP object recognition implementation.
    The object_class is scored against an arbitrary list of tested_classes.
    '''
if object_class not in tested_classes:
raise ValueError('The object_class has to be part of the tested_classes list.')
    # Prepare the model inputs
inputs = CLIPprocessor_import(text=tested_classes, images=img, return_tensors="pt", padding=True)
# Run inference
outputs = CLIPmodel_import(**inputs)
# Get image-text similarity score
logits_per_image = outputs.logits_per_image
# Get probabilities
probs = logits_per_image.softmax(dim=1)
    # Return True if the highest-probability class is the tested object
    return tested_classes[probs.argmax().item()] == object_class
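# Illustrative usage (the tested_classes list is a placeholder and must contain object_class):
# test_image = Image.open('pages/Functions/test_image.png')
# CLIP_object_recognition(test_image, 'cat', ['cat', 'dog', 'car', 'tree'])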
###### Multi object recognition
#list_of_objects = ['cat','apple','cow']
def CLIP_multi_object_recognition(img, list_of_objects):
    '''
    Algorithm based on CLIP to test for the presence of multiple objects.
    Currently contains a debugging print call.
    '''
    # Loop over the list of objects and test for each individually, making sure
    # that none of the other objects is part of the test set
    for i_object in list_of_objects:
        # Create a list of objects excluded from the test set (all objects which aren't i_object)
        untested_objects = [x for x in list_of_objects if x != i_object]
# Create set going into clip object recogniser and test this set using standard recognition function
CLIP_test_classes = Object_set_creator(included_object=i_object, excluded_objects_list=untested_objects)
i_object_present = CLIP_object_recognition(img, i_object, CLIP_test_classes)
print(i_object+str(i_object_present))
        # Stop the loop and return False if one of the objects is not recognised by CLIP
        if not i_object_present:
            return False
# Return true if all objects were recognised
return True
def CLIP_multi_object_recognition_DSwrapper(img, representations, task_specific_label=None):
    '''
    Dashboard wrapper of CLIP_multi_object_recognition.
    representations is expected to be a comma-separated string of object names.
    Note that the task_specific_label is not used for this classifier.
    '''
list_of_objects = representations.split(', ')
return CLIP_multi_object_recognition(img,list_of_objects)
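# Illustrative usage (object list taken from the commented-out example above;
# the wrapper expects the same objects as a comma-separated string):
# test_image = Image.open('pages/Functions/test_image.png')
# CLIP_multi_object_recognition(test_image, ['cat', 'apple', 'cow'])
# CLIP_multi_object_recognition_DSwrapper(test_image, 'cat, apple, cow')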
###### Negation
def CLIP_object_negation(img, present_object, absent_object):
    '''
    Algorithm based on CLIP to test negation prompts: returns True only if
    present_object is detected and absent_object is not.
    '''
# Create sets of objects for present and absent object
tested_classes_present = Object_set_creator(
included_object=present_object, excluded_objects_list=[absent_object])
tested_classes_absent = Object_set_creator(
included_object=absent_object, excluded_objects_list=[present_object],set_length=10)
# Use CLIP object recognition to test for objects.
presence_test = CLIP_object_recognition(img, present_object, tested_classes_present)
absence_test = CLIP_object_recognition(img, absent_object, tested_classes_absent)
    return presence_test and not absence_test
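# Illustrative usage (object names are placeholders):
# test_image = Image.open('pages/Functions/test_image.png')
# CLIP_object_negation(test_image, present_object='cat', absent_object='dog')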
###### Counting / arithmetic
'''
test_image = Image.open('pages/Functions/test_imageIII.jpeg')
object_classes = ['cat','remote']
object_counts = [2,2]
'''
def DETR_multi_object_counting(img, object_classes, object_counts, confidence_treshold=0.5):
    '''
    Algorithm based on DETR to test counting prompts.
    Returns True only if every class in object_classes is detected exactly
    object_counts times above the confidence threshold.
    '''
# Apply Detr to image
inputs = DetrFeatureExtractor_import(images=img, return_tensors="pt")
outputs = DetrModel_import(**inputs)
# Convert outputs (bounding boxes and class logits) to COCO API
target_sizes = torch.tensor([img.size[::-1]])
results = DetrFeatureExtractor_import.post_process_object_detection(
outputs, threshold=confidence_treshold, target_sizes=target_sizes)[0]
    # Map detected label ids to class names and count detections per class
    count_dict = pd.Series(results['labels'].numpy())
    count_dict = count_dict.map(DetrModel_import.config.id2label)
    count_dict = count_dict.value_counts().to_dict()
# Create dict for correct response
label_dict = dict(zip(object_classes, object_counts))
    # Return False if the count for a given label does not match
    for i_item in label_dict.items():
        # Check whether the current label exists in count_dict, else return False
        if i_item[0] not in count_dict:
            return False
        # The label is present, so check that the count matches
        # (cast both sides to int since the expected counts may be read in as strings)
        if int(count_dict[i_item[0]]) == int(i_item[1]):
            print(str(i_item) + '_true')
        else:
            print(str(i_item) + '_false')
            print("observed: " + str(count_dict[i_item[0]]))
            return False
# If all match, return true
return True
def DETR_multi_object_counting_DSwrapper(img, representations, Task_specific_label):
    '''
    Dashboard wrapper of DETR_multi_object_counting.
    representations is a comma-separated string of object names and
    Task_specific_label the corresponding comma-separated counts.
    '''
list_of_objects = representations.split(', ')
object_counts = Task_specific_label.split(', ')
return DETR_multi_object_counting(img,list_of_objects, object_counts, confidence_treshold=0.5)
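# Illustrative usage (values taken from the commented-out example above the counting function;
# the wrapper receives objects and counts as comma-separated strings):
# test_image = Image.open('pages/Functions/test_imageIII.jpeg')
# DETR_multi_object_counting(test_image, ['cat', 'remote'], [2, 2])
# DETR_multi_object_counting_DSwrapper(test_image, 'cat, remote', '2, 2')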