ImageAPI / app.py
nimnim's picture
Added more text to the database and added Unidentified element
3d9617f
raw
history blame
1.7 kB
import gradio as gr
import cv2
import torch
import data
from models import imagebind_model
from models.imagebind_model import ModalityType
def read_dict_from_file(filename):
    """Parse a ``key: value`` text file into a lookup table and a key list.

    Every non-blank line is split on its first ``:``. Both halves are
    stripped of surrounding whitespace before being stored.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(dictionary, text_list)`` tuple. ``dictionary`` maps each
        stripped key to its stripped value. ``text_list`` holds the same
        stripped keys in file order, one entry per non-blank line.

    Raises:
        ValueError: If a non-blank line contains no ``:`` separator.
    """
    text_list = []
    dictionary = {}
    with open(filename, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            line = line.strip()
            if not line:
                continue
            key, separator, value = line.partition(':')
            if not separator:
                raise ValueError(
                    f"{filename}:{line_number}: expected 'key: value', "
                    f"got {line!r}"
                )
            # Store the *stripped* key in both containers so that every
            # entry of text_list is guaranteed to be a key of dictionary.
            key = key.strip()
            dictionary[key] = value.strip()
            text_list.append(key)
    return dictionary, text_list
# Label database: keys are the candidate text labels, values the text
# returned alongside the chosen label.
text_output_path = 'output.txt'
database, text_list = read_dict_from_file(text_output_path)

# Extra catch-all candidate so that "Unidentified" can be chosen as a label.
database["Unidentified"] = "NA"
text_list.append("Unidentified")

# Location where each submitted image is stored before being embedded.
image_paths = ['images/captured_image.jpg']

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Instantiate the pretrained model, switch it to eval mode and move it to
# the selected device.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)
def run_model():
    """Pick the text label that best matches the stored image.

    Loads every entry of ``text_list`` and every path in ``image_paths``
    through the ``data`` helpers, runs them through the module-level
    ``model`` without gradient tracking, and selects the label with the
    highest softmax score for the first image.

    Returns:
        A ``(label, description)`` tuple: the selected entry of
        ``text_list`` and the value stored for it in ``database``.
    """
    text_batch = data.load_and_transform_text(text_list, device)
    vision_batch = data.load_and_transform_vision_data(image_paths, device)
    inputs = {
        ModalityType.TEXT: text_batch,
        ModalityType.VISION: vision_batch,
    }

    with torch.no_grad():
        embeddings = model(inputs)

    vision_embeddings = embeddings[ModalityType.VISION]
    text_embeddings = embeddings[ModalityType.TEXT]
    # One row per image, one column per candidate label.
    scores = torch.softmax(vision_embeddings @ text_embeddings.T, dim=-1)
    best_per_image = torch.argmax(scores, dim=1)

    # Only the first image's result is reported.
    label_product = text_list[best_per_image[0]]
    return label_product, database[label_product]
def predict(image):
    """Store the submitted image on disk and classify it.

    Args:
        image: Image array from the Gradio input; it is treated as RGB and
            converted to BGR before being written with OpenCV.

    Returns:
        Whatever ``run_model()`` returns for the stored image.

    Raises:
        ValueError: If no image was submitted (``image`` is ``None``).
        OSError: If the image could not be written to ``image_paths[0]``.
    """
    import os  # local import: only needed for the directory handling below

    if image is None:
        raise ValueError("No image was provided.")

    target_path = image_paths[0]
    target_dir = os.path.dirname(target_path)
    if target_dir:
        # cv2.imwrite does not create missing directories.
        os.makedirs(target_dir, exist_ok=True)

    # Save the image to the desired file path. cv2.imwrite signals failure
    # through its boolean return value, so it must be checked; otherwise the
    # model would silently run on a stale or missing file.
    # NOTE(review): every call writes to the same path, so overlapping
    # requests could overwrite each other's image -- confirm whether the
    # app is ever served with concurrency enabled.
    written = cv2.imwrite(target_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
    if not written:
        raise OSError(f"Could not write image to {target_path!r}")
    return run_model()
# Web UI: one image input, two text outputs (the values returned by predict).
# NOTE(review): gr.inputs is Gradio's legacy component namespace -- confirm
# the pinned Gradio version before switching to gr.Image().
image_input = gr.inputs.Image()
demo = gr.Interface(fn=predict, inputs=image_input, outputs=["text", "text"])

# Start the Gradio server for the interface.
demo.launch()