Spaces:
Running
Running
File size: 7,040 Bytes
8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 aebdae7 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c aebdae7 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c aebdae7 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 f77c97c 8ff3c52 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import argparse
import base64
import os
import pickle
import time
from typing import Dict, List
import cv2
import numpy as np
import requests
# Base URL of the image-matching service. Override it with the
# REMOTE_URL_RAILWAY environment variable. An empty value falls back to the
# local default, and a trailing slash is stripped so the joined URLs below
# never contain "//" (which many servers route as a different path).
ENDPOINT = os.environ.get("REMOTE_URL_RAILWAY") or "http://127.0.0.1:8000"
ENDPOINT = ENDPOINT.rstrip("/")
print(f"API ENDPOINT: {ENDPOINT}")
API_VERSION = f"{ENDPOINT}/version"
API_URL_MATCH = f"{ENDPOINT}/v1/match"
API_URL_EXTRACT = f"{ENDPOINT}/v1/extract"
def read_image(path: str) -> str:
    """
    Read an image from a file as grayscale, encode it as PNG and return the
    PNG bytes as a base64 string.

    Args:
        path (str): The path to the image to read.

    Returns:
        str: The base64-encoded PNG image, as UTF-8 text.

    Raises:
        ValueError: If the file cannot be read as an image, or the decoded
            image cannot be re-encoded as PNG.
    """
    # cv2.imread does not raise on failure; it returns None (missing file,
    # unreadable/unsupported format), so check explicitly instead of letting
    # imencode fail later with an opaque error.
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not read image: {path}")
    # PNG is lossless, so the server receives the exact pixel values.
    ok, buffer = cv2.imencode(".png", img)
    if not ok:
        raise ValueError(f"Could not encode image as PNG: {path}")
    return base64.b64encode(buffer).decode("utf-8")
def do_api_requests(url=API_URL_EXTRACT, *, timeout=None, **kwargs):
    """
    Send a JSON POST request to the image matching service.

    A default request body is built first, then every key in ``kwargs``
    overrides the matching default field.

    Args:
        url (str): The URL of the API endpoint to use. Defaults to the
            feature extraction endpoint.
        timeout (float | tuple | None): Forwarded to ``requests.post``.
            ``None`` (the default) waits indefinitely, as before. Note that a
            ``timeout`` keyword is therefore no longer copied into the body.
        **kwargs: Request-body fields that override the defaults (e.g.
            ``data`` with a list of base64-encoded images).

    Returns:
        The decoded JSON response when the server answers HTTP 200. The values
        are plain JSON (lists/dicts), not ndarrays. Presumably this is a list
        of per-image dicts with "keypoints", "descriptors" and "scores". Verify
        this against the server.
        Returns ``None`` on a non-200 status or if the request/decoding
        fails; the error is printed.
    """
    # Set up the request body
    reqbody = {
        # List of image data base64 encoded
        "data": [],
        # List of maximum number of keypoints to extract from each image
        "max_keypoints": [100, 100],
        # List of timestamps for each image (not used?)
        "timestamps": ["0", "1"],
        # Whether to convert the images to grayscale
        "grayscale": 0,
        # List of image height and width
        "image_hw": [[640, 480], [320, 240]],
        # Type of feature to extract
        "feature_type": 0,
        # List of rotation angles for each image
        "rotates": [0.0, 0.0],
        # List of scale factors for each image
        "scales": [1.0, 1.0],
        # List of reference points for each image (not used)
        "reference_points": [[640, 480], [320, 240]],
        # Whether to binarize the descriptors
        "binarize": True,
    }
    # Caller-supplied fields take precedence over the defaults above.
    reqbody.update(kwargs)
    try:
        r = requests.post(url, json=reqbody, timeout=timeout)
        if r.status_code == 200:
            return r.json()
        print(f"Error: Response code {r.status_code} - {r.text}")
    except Exception as e:
        # Best-effort client: report the failure and return None rather
        # than raising to the caller.
        print(f"An error occurred: {e}")
    return None
def send_request_match(
    path0: str, path1: str, *, timeout=None
) -> Dict[str, np.ndarray]:
    """
    Send two image files to the API and request matches between them.

    Args:
        path0 (str): The path to the first image.
        path1 (str): The path to the second image.
        timeout (float | tuple | None): Forwarded to ``requests.post``;
            ``None`` (the default) waits indefinitely.

    Returns:
        Dict[str, np.ndarray]: The JSON response with every value converted
        via ``np.array``. The exact keys are defined by the server. Callers
        elsewhere in this script reference keys such as "mmkeypoints0_orig";
        confirm the full set against the server. Returns an empty dict on a
        non-200 response (the error is printed).

    Raises:
        OSError: If either image file cannot be opened.
        requests.RequestException: If the HTTP request itself fails.
    """
    pred = {}
    # A single ``with`` guarantees both handles are closed even if the
    # second open() or the request raises (the original leaked path0's
    # handle when opening path1 failed).
    with open(path0, "rb") as f0, open(path1, "rb") as f1:
        # TODO: replace files with post json
        response = requests.post(
            API_URL_MATCH,
            files={"image0": f0, "image1": f1},
            timeout=timeout,
        )
    if response.status_code == 200:
        pred = {key: np.array(value) for key, value in response.json().items()}
    else:
        print(
            f"Error: Response code {response.status_code} - {response.text}"
        )
    return pred
def send_request_extract(
    input_images: str, viz: bool = False
) -> List[Dict[str, np.ndarray]]:
    """
    Send a request to the API to extract features from a single image.

    Args:
        input_images (str): The path to the image (exactly one path,
            despite the plural name).
        viz (bool): If True and the response contains the server-side
            image ("image_orig"), draw the returned "keypoints_orig" on it
            and write the figure to ``demo_match.jpg``. Debug only; requires
            the ``hloc`` and ``ui`` packages.

    Returns:
        The decoded JSON response from ``do_api_requests``, or ``None`` if
        the request failed. Values are plain JSON lists, not ndarrays.
        Presumably this is a list of per-image dicts with "keypoints",
        "descriptors" and "scores". Verify this against the server.
    """
    image_data = read_image(input_images)
    response = do_api_requests(url=API_URL_EXTRACT, data=[image_data])

    # Debug visualization. It needs a successful response: the original
    # indexed response[0] even when the request had failed (None).
    if viz and response:
        features = response[0]
        if "image_orig" in features:
            from hloc.utils.viz import plot_keypoints
            from ui.viz import fig2im, plot_images

            # Use the image returned by the server (the original wrapped the
            # literal string "image_orig" instead of looking it up).
            img_orig = np.array(features["image_orig"])
            kpts = np.array(features["keypoints_orig"])
            # NOTE(review): titles is a plain string; if plot_images indexes
            # titles per image this shows "t" — confirm intended.
            fig = plot_images([img_orig], titles="titles", dpi=300)
            plot_keypoints([kpts])
            output_keypoints = fig2im(fig)
            cv2.imwrite(
                "demo_match.jpg",
                output_keypoints[:, :, ::-1].copy(),  # RGB -> BGR
            )
        else:
            # Previously this path crashed with NameError on an unbound
            # figure; now it is skipped explicitly.
            print("Visualization skipped: response has no 'image_orig'.")
    return response
def get_api_version():
    """
    Query the service's version endpoint and print the reported version.

    Failures (connection errors, a non-JSON body, a missing "version" key)
    are printed instead of raised. Nothing is returned.
    """
    try:
        payload = requests.get(API_VERSION).json()
        reported = payload["version"]
        print("API VERSION: {}".format(reported))
    except Exception as err:
        print(f"An error occurred: {err}")
if __name__ == "__main__":
    # Command-line benchmark: print the API version, then time repeated
    # requests against the service and pickle the last extraction result.
    parser = argparse.ArgumentParser(
        description=(
            "Benchmark the image matching API: query its version, then "
            "repeatedly request feature extraction (and optionally "
            "matching) and report the time taken by each request."
        )
    )
    parser.add_argument(
        "--image0",
        required=False,
        help="Path to the image used for feature extraction (and the first "
        "image for --match)",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot45.jpg",
    )
    parser.add_argument(
        "--image1",
        required=False,
        help="Path to the second image (only used with --match)",
        default="datasets/sacre_coeur/mapping_rot/02928139_3448003521_rot90.jpg",
    )
    parser.add_argument(
        "--num-requests",
        type=int,
        default=1000,
        help="Number of requests to send per benchmark (default: 1000)",
    )
    parser.add_argument(
        "--match",
        action="store_true",
        help="Also benchmark match requests on image0/image1 before the "
        "extraction requests",
    )
    args = parser.parse_args()

    get_api_version()

    if args.match:
        for _ in range(args.num_requests):
            t1 = time.perf_counter()
            match_preds = send_request_match(args.image0, args.image1)
            t2 = time.perf_counter()
            # .get() because a failed request returns an empty dict.
            num_matches = len(match_preds.get("mmkeypoints0_orig", []))
            print(f"Time cost1: {(t2 - t1)} seconds, matched: {num_matches}")

    # perf_counter is the monotonic clock intended for measuring durations.
    preds = None  # stays None if --num-requests is 0
    for _ in range(args.num_requests):
        t1 = time.perf_counter()
        preds = send_request_extract(args.image0)
        t2 = time.perf_counter()
        print(f"Time cost2: {(t2 - t1)} seconds")

    # Only the last extraction response is kept and dumped.
    with open("preds.pkl", "wb") as f:
        pickle.dump(preds, f)
|