File size: 1,686 Bytes
b66ad81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61b336c
 
b66ad81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7b62a7
 
30bb3fd
b66ad81
 
 
 
 
 
0221e35
b66ad81
 
30bb3fd
b66ad81
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import os

import cv2
import gradio as gr
import torch

import data
from models import imagebind_model
from models.imagebind_model import ModalityType

def read_dict_from_file(filename):
    """Parse a ``key: value`` text file into a lookup dict and an ordered key list.

    Blank lines are skipped. Each remaining line is split on its *first* ``:``,
    so values may themselves contain colons. Keys and values are stripped of
    surrounding whitespace.

    Args:
        filename: path of the text file to read (decoded as UTF-8).

    Returns:
        tuple[dict[str, str], list[str]]: ``(dictionary, text_list)`` where
        ``text_list`` holds the keys in file order (duplicates included) and
        every entry of ``text_list`` is a key of ``dictionary``; for a repeated
        key the last value wins.

    Raises:
        ValueError: if a non-blank line contains no ``:`` separator.
    """
    text_list = []
    dictionary = {}
    with open(filename, 'r', encoding='utf-8') as file:
        for line_number, raw_line in enumerate(file, start=1):
            line = raw_line.strip()
            if not line:
                continue
            key, separator, value = line.partition(':')
            if not separator:
                raise ValueError(
                    f"{filename}:{line_number}: expected 'key: value', got {line!r}"
                )
            key = key.strip()
            dictionary[key] = value.strip()
            # Store the *stripped* key: callers look entries of text_list up in
            # dictionary, so both must use the same spelling (an unstripped key
            # such as "apple " would not be found).
            text_list.append(key)
    return dictionary, text_list

# Candidate labels for zero-shot matching: each "key: value" line of the file
# contributes a label (key) and the text returned alongside it (value).
text_output_path = 'output.txt'
database, text_list = read_dict_from_file(text_output_path)
if not text_list:
    # Fail at start-up rather than on the first request inside run_model.
    raise ValueError(
        f"No labels found in {text_output_path!r}; expected 'key: value' lines."
    )
# NOTE: there is no catch-all label (e.g. "N/A"), so every image is matched to
# the closest label from the file, however poor that match is.

# Each uploaded frame is written to this path and then read back for inference.
image_paths = ['images/captured_image.jpg']

# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Build ImageBind-huge with pretrained=True, switch it to inference mode and
# place it on the chosen device. nn.Module.eval() and .to() both return the
# module itself, so the chain binds `model` to the same object.
model = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)

def run_model():
    """Match the image at ``image_paths[0]`` against every label in ``text_list``.

    The labels and the image are embedded together by the ImageBind ``model``.
    The label whose text embedding has the largest dot product with the image
    embedding is chosen.

    Returns:
        tuple[str, str]: the best-matching label and its ``database`` value.

    Raises:
        KeyError: if the chosen label is not a key of ``database``.
    """
    inputs = {
        ModalityType.TEXT: data.load_and_transform_text(text_list, device),
        ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
    }

    # NOTE(review): text_list is fixed after import, so the text embeddings could
    # be computed once and cached -- confirm the model embeds each modality
    # independently before doing so.
    with torch.no_grad():
        embeddings = model(inputs)

    # Assumes 2-D (count, dim) embeddings, so rows are images and columns are
    # labels. Softmax is strictly increasing along a row, so the argmax of the
    # raw similarities selects the same label without the extra softmax pass.
    similarity = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T
    best_index = int(torch.argmax(similarity, dim=-1)[0])

    label_product = text_list[best_index]
    return label_product, database[label_product]

def predict(image):
    """Gradio callback: save the uploaded image to disk and classify it.

    Args:
        image: the array delivered by the Gradio image input, in RGB channel
            order (it is converted RGB->BGR for OpenCV below). ``None`` when the
            user submits without providing an image.

    Returns:
        The ``(label, description)`` pair produced by ``run_model``.

    Raises:
        ValueError: if no image was supplied.
        OSError: if the image could not be written to ``image_paths[0]``.
    """
    if image is None:
        raise ValueError("No image was provided.")

    target_path = image_paths[0]
    target_dir = os.path.dirname(target_path)
    if target_dir:
        # cv2.imwrite does not create missing directories; it just returns False.
        os.makedirs(target_dir, exist_ok=True)

    # OpenCV writes BGR data, while the incoming array is RGB.
    if not cv2.imwrite(target_path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)):
        # Without this check a failed write would let run_model classify a
        # stale (or missing) file instead of the image the user just sent.
        raise OSError(f"Failed to write image to {target_path!r}")
    return run_model()

# Web UI: one image in, two text fields out (matched label, its description).
# gr.Image replaces the deprecated gr.inputs.Image (removed in Gradio 4); its
# defaults likewise hand predict an RGB numpy array.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=["text", "text"],
)

# Only start the server when run as a script, so importing this module (e.g.
# from tests or the `gradio` reload CLI, which looks up `demo`) has no side effect.
if __name__ == "__main__":
    demo.launch()