WebashalarForML commited on
Commit
4b4aedc
·
verified ·
1 Parent(s): 93416d9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import base64
4
+ import numpy as np
5
+ from flask import Flask, request, jsonify
6
+ from langchain_experimental.open_clip.open_clip import OpenCLIPEmbeddings
7
+ from sklearn.metrics.pairwise import cosine_similarity
8
+ from io import BytesIO
9
+ from PIL import Image
10
+ # from matplotlib.offsetbox import OffsetImage, AnnotationBbox
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+
14
+ # ============================== #
15
+ # INITIALIZE APP #
16
+ # ============================== #
17
+ app = Flask(__name__)
18
+ clip_embd = OpenCLIPEmbeddings()
19
+
20
+ BASE_DIR = Path("/app")
21
+ BLOCKS_DIR = BASE_DIR / "blocks"
22
+ # STATIC_DIR = BASE_DIR / "static"
23
+ # GEN_PROJECT_DIR = BASE_DIR / "generated_projects"
24
+ BACKDROP_DIR = BLOCKS_DIR / "Backdrops"
25
+ SPRITE_DIR = BLOCKS_DIR / "sprites"
26
+ CODE_BLOCKS_DIR = BLOCKS_DIR / "code_blocks"
27
+ # === new: outputs rooted under BASE_DIR ===
28
+ OUTPUT_DIR = BASE_DIR / "outputs"
29
+
30
+ # ============================== #
31
+ # LOAD PRE-COMPUTED EMBEDS #
32
+ # ============================== #
33
+ with open(f"{BLOCKS_DIR}/embeddings.json", "r") as f:
34
+ embedding_json = json.load(f)
35
+
36
+ image_paths = [item["file-path"] for item in embedding_json]
37
+ image_embeds = np.array([item["embeddings"] for item in embedding_json])
38
+
39
+
40
+ # ============================== #
41
+ # HELPER: Decode Base64 Image #
42
+ # ============================== #
43
+ def decode_base64_image(b64_string):
44
+ img_data = base64.b64decode(b64_string)
45
+ img = Image.open(BytesIO(img_data)).convert("RGB")
46
+ return img
47
+
48
+
49
+ # ============================== #
50
+ # API ROUTE #
51
+ # ============================== #
52
+ @app.route("/match", methods=["POST"])
53
+ def match_image():
54
+ """
55
+ Input: JSON { "images": ["<base64_img1>", "<base64_img2>", ...] }
56
+ Output: Best match path + score for each input image
57
+ """
58
+ data = request.get_json()
59
+ if "images" not in data:
60
+ return jsonify({"error": "No images provided"}), 400
61
+
62
+ results = []
63
+ for b64_img in data["images"]:
64
+ try:
65
+ # Decode and embed input image
66
+ img = decode_base64_image(b64_img)
67
+ query_embed = np.array(clip_embd.embed_image([img])) # embed_image can take PIL images
68
+
69
+ # Cosine similarity with stored embeddings
70
+ sims = cosine_similarity(query_embed, image_embeds)[0]
71
+ best_idx = np.argmax(sims)
72
+
73
+ results.append({
74
+ "input": b64_img[:50] + "...", # partial preview
75
+ "best_match": {
76
+ "name": os.path.basename(image_paths[best_idx]),
77
+ "path": image_paths[best_idx],
78
+ "similarity": float(sims[best_idx])
79
+ }
80
+ })
81
+ except Exception as e:
82
+ results.append({"error": str(e)})
83
+
84
+ return jsonify(results)
85
+
86
+
87
+ # ============================== #
88
+ # MAIN ENTRY #
89
+ # ============================== #
90
+ if __name__ == "__main__":
91
+ app.run(debug=True, port=7860)