# foodDetectionDemo/clip_component.py
import cv2
import torch
from PIL import Image

import clip
# Minimum score (cosine similarity scaled by 100) for a label to count as detected.
similarity_threshold = 22.00

def get_token_from_clip(image):
    """Return a comma-separated string of food labels whose CLIP similarity
    to `image` (an RGB numpy array) meets `similarity_threshold`."""
    text_inputs = ["apple", "banana", "cereal", "milk", "lemon", "orange", "salad", "juice", "chicken", "bread"]
    device = "cpu"
    print("device: ", device)
    # Loading the model on every call is expensive; caching it at module level
    # would avoid reloading the weights for each frame.
    model, preprocess = clip.load("ViT-B/32", device=device)

    # Encode and L2-normalize the candidate text labels.
    text_tokens = clip.tokenize(text_inputs).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Encode and L2-normalize the image.
    image_pil = Image.fromarray(image.astype("uint8"))
    image_input = preprocess(image_pil).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        image_feature = model.encode_image(image_input)
    image_feature /= image_feature.norm(dim=-1, keepdim=True)

    # Cosine similarity between every text embedding and the image embedding.
    similarity = text_features.cpu().numpy() @ image_feature.cpu().numpy().T

    results = []
    for i in range(similarity.shape[0]):
        similarity_num = 100.0 * similarity[i][0]
        results.append({"text_input": text_inputs[i], "similarity": similarity_num})
    results.sort(key=lambda x: x["similarity"], reverse=True)

    # Print every label with its score and keep those above the threshold.
    detected = []
    for result in results:
        print(f"Text input: {result['text_input']}, Similarity: {result['similarity']:.2f}")
        if result["similarity"] >= similarity_threshold:
            detected.append(result["text_input"])
    return ", ".join(detected)
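

# Usage sketch: assumes a hypothetical test image "sample.jpg" on disk (not
# part of the original file). OpenCV loads frames in BGR order, while PIL and
# CLIP expect RGB, so the frame is converted before calling the function.
if __name__ == "__main__":
    frame = cv2.imread("sample.jpg")  # hypothetical test image
    if frame is None:
        raise FileNotFoundError("sample.jpg not found")
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    print("detected:", get_token_from_clip(rgb))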