Commit
·
1420df1
1
Parent(s):
e1e20e4
Added linear reuse and price layers
Browse files- app.py +48 -2
- files/price_linear.pth +3 -0
- files/reuse_linear.pth +3 -0
app.py
CHANGED
@@ -67,7 +67,53 @@ def predict_brand(img):
|
|
67 |
|
68 |
|
69 |
def estimate_price_and_usage(img):
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
72 |
|
73 |
def retrieve(query):
|
@@ -141,7 +187,7 @@ with gr.Blocks(
|
|
141 |
predicted_brand = gr.Textbox(label="Brand", show_label=False)
|
142 |
|
143 |
with gr.Column(variant="compact"):
|
144 |
-
btn_estimate = gr.Button("Estimate Price
|
145 |
text_box = gr.Textbox(label="Estimates:", show_label=False)
|
146 |
with gr.Tab("Image Retrieval"):
|
147 |
with gr.Row(variant="compact"):
|
|
|
67 |
|
68 |
|
69 |
def estimate_price_and_usage(img):
|
70 |
+
query_features = model.encode_image(preprocess(img).unsqueeze(0).to(device))
|
71 |
+
|
72 |
+
# Estimate usage
|
73 |
+
num_classes = 2
|
74 |
+
probe = torch.nn.Linear(
|
75 |
+
query_features.shape[-1],
|
76 |
+
num_classes,
|
77 |
+
dtype=torch.float16,
|
78 |
+
bias=False
|
79 |
+
)
|
80 |
+
# Load weights for the linear layer as a tensor
|
81 |
+
linear_data = torch.load("files/reuse_linear.pth")
|
82 |
+
probe.weight.data = linear_data["weight"]
|
83 |
+
|
84 |
+
# Do inference
|
85 |
+
probe.eval()
|
86 |
+
probe = probe.to(device)
|
87 |
+
output = probe(query_features)
|
88 |
+
print(output)
|
89 |
+
output = torch.softmax(output, dim=-1)
|
90 |
+
output = output.cpu().detach().numpy().astype("float32")
|
91 |
+
reuse = output.argmax(axis=-1)[0]
|
92 |
+
reuse_classes = ["Reuse", "Export"]
|
93 |
+
|
94 |
+
# Estimate price
|
95 |
+
num_classes = 4
|
96 |
+
probe = torch.nn.Linear(
|
97 |
+
query_features.shape[-1],
|
98 |
+
num_classes,
|
99 |
+
dtype=torch.float16,
|
100 |
+
bias=False
|
101 |
+
)
|
102 |
+
# Print output shape for the linear layer
|
103 |
+
# Load weights for the linear layer as a tensor
|
104 |
+
linear_data = torch.load("files/price_linear.pth")
|
105 |
+
probe.weight.data = linear_data["weight"]
|
106 |
+
|
107 |
+
# Do inference
|
108 |
+
probe.eval()
|
109 |
+
probe = probe.to(device)
|
110 |
+
output = probe(query_features)
|
111 |
+
output = torch.softmax(output, dim=-1)
|
112 |
+
output = output.cpu().detach().numpy().astype("float32")
|
113 |
+
price = output.argmax(axis=-1)[0]
|
114 |
+
price_classes = ["<50", "50-100", "100-150", ">150"]
|
115 |
+
|
116 |
+
return f"Estimated price: {price_classes[price]} SEK - Usage: {reuse_classes[reuse]}"
|
117 |
|
118 |
|
119 |
def retrieve(query):
|
|
|
187 |
predicted_brand = gr.Textbox(label="Brand", show_label=False)
|
188 |
|
189 |
with gr.Column(variant="compact"):
|
190 |
+
btn_estimate = gr.Button("Estimate Price and Reuse").style(size="sm")
|
191 |
text_box = gr.Textbox(label="Estimates:", show_label=False)
|
192 |
with gr.Tab("Image Retrieval"):
|
193 |
with gr.Row(variant="compact"):
|
files/price_linear.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a00e3c8fa9f78af43ab1208e14b818bea5028443e4c5260c6743e05f14b378f
|
3 |
+
size 5115
|
files/reuse_linear.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8998d9093128cab4269649e9dc70541940f0d8a9a92c6e02fe774a6153f5e29b
|
3 |
+
size 3067
|