thanhnt-cf commited on
Commit
8ba64a4
·
1 Parent(s): 0a2ea2e

initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ env
2
+ .env
3
+ app.log
4
+ gradio_temp/
5
+
6
+ __pycache__/
app.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.environ["HUGGINGFACE_DEMO"] = "1" # set before import from app
4
+
5
+ from dotenv import load_dotenv
6
+ load_dotenv()
7
+ ################################################################################################
8
+
9
+ import gradio as gr
10
+ import uuid
11
+ import shutil
12
+
13
+ from app.config import get_settings
14
+ from app.schemas.requests import Attribute
15
+ from app.request_handler import handle_extract
16
+ from app.services.factory import AIServiceFactory
17
+
18
+
19
+ settings = get_settings()
20
+ IMAGE_MAX_SIZE = 1536
21
+
22
+
23
async def forward_request(attributes, product_taxonomy, product_data, ai_model, pil_images):
    """Gradio callback: extract product attributes from uploaded images.

    Args mirror the UI inputs: `attributes` is the schema textbox content
    (dict body without surrounding braces), `product_taxonomy` a free-text
    taxonomy hint, `product_data` an optional dict literal of known
    attributes, `ai_model` the selected model name, and `pil_images` the
    gallery value (list of (PIL.Image, caption) tuples).

    Returns the extracted attributes as a JSON-serializable dict.
    """
    import ast  # local import: only this handler needs it

    # Per-request scratch directory for resized / re-encoded images.
    request_id = str(uuid.uuid4())
    request_temp_folder = os.path.join('gradio_temp', request_id)
    os.makedirs(request_temp_folder, exist_ok=True)

    try:
        # Parse the schema with ast.literal_eval instead of exec(): the
        # textbox is user-controlled, and exec() would let visitors run
        # arbitrary code on the server. literal_eval accepts only literals.
        # Using a local variable (instead of an exec-injected global) also
        # avoids cross-request races when Gradio handles concurrent users.
        try:
            attributes_object = ast.literal_eval("{" + attributes + "}")
        except (ValueError, SyntaxError, MemoryError, RecursionError):
            raise gr.Error("Invalid `Attribute Schema`. Please insert valid schema following the example.")
        attributes_object = {key: Attribute(**value) for key, value in attributes_object.items()}

        if product_data == "":
            product_data = "{}"
        try:
            product_data_object = ast.literal_eval(product_data)
        except (ValueError, SyntaxError, MemoryError, RecursionError):
            raise gr.Error('Invalid `Product Data`. Please insert valid dictionary or leave it empty.')

        if pil_images is None:
            raise gr.Error('Please upload image(s) of the product')
        pil_images = [pil_image[0] for pil_image in pil_images]

        img_paths = []
        for i, pil_image in enumerate(pil_images):
            # Downscale oversized images to keep request payloads small.
            if max(pil_image.size) > IMAGE_MAX_SIZE:
                ratio = IMAGE_MAX_SIZE / max(pil_image.size)
                pil_image = pil_image.resize((int(pil_image.width * ratio), int(pil_image.height * ratio)))
            img_path = os.path.join(request_temp_folder, f'{i}.jpg')
            # Keep PNG only when real transparency is present; JPEG otherwise.
            if pil_image.mode in ('RGBA', 'LA') or (pil_image.mode == 'P' and 'transparency' in pil_image.info):
                pil_image = pil_image.convert("RGBA")
                if pil_image.getchannel("A").getextrema() == (255, 255):  # fully opaque, save as JPEG
                    pil_image = pil_image.convert("RGB")
                    image_format = 'JPEG'
                else:
                    image_format = 'PNG'
            else:
                image_format = 'JPEG'
            pil_image.save(img_path, image_format, quality=100, subsampling=0)
            img_paths.append(img_path)

        # Map model name to vendor. Fail fast on unknown models instead of
        # leaving `ai_vendor` unbound (the original could raise NameError).
        if ai_model in settings.OPENAI_MODELS:
            ai_vendor = 'openai'
        elif ai_model in settings.ANTHROPIC_MODELS:
            ai_vendor = 'anthropic'
        else:
            raise gr.Error(f'Unsupported AI model: {ai_model}')
        service = AIServiceFactory.get_service(ai_vendor)

        try:
            json_attributes = await service.extract_attributes_with_validation(
                attributes_object,
                ai_model,
                None,
                product_taxonomy,
                product_data_object,
                img_paths=img_paths,
            )
        except Exception:
            raise gr.Error('Failed to extract attributes. Something went wrong.')
    finally:
        # Remove the per-request scratch directory no matter what happened.
        shutil.rmtree(request_temp_folder, ignore_errors=True)

    gr.Info('Process completed!')
    return json_attributes
93
+
94
+
95
def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
    """Append one attribute entry to the schema textbox content.

    Builds a JSON-ish fragment from (`attr_name`, `attr_desc`, `attr_type`,
    comma-separated `allowed_values`) and appends it to `attributes`.

    Returns the updated schema text plus four empty strings, which clear
    the corresponding input widgets in the UI.
    """
    # An empty "allowed values" field means "no restriction": emit an empty
    # list instead of the bogus single entry '""' the naive split produced.
    values = [v.strip() for v in allowed_values.split(',') if v.strip()]
    rendered_values = ', '.join(f'"{v}"' for v in values)
    schema = f"""
    "{attr_name}": {{
        "description": "{attr_desc}",
        "data_type": "{attr_type}",
        "allowed_values": [
            {rendered_values}
        ]
    }},
    """
    return attributes + schema, "", "", "", ""
106
+
107
+
108
# ---------------------------------------------------------------------------
# Static UI text
# ---------------------------------------------------------------------------

sample_schema = """"category": {
"description": "Category of the garment",
"data_type": "list[string]",
"allowed_values": [
"upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
]
},

"color": {
"description": "Color of the garment",
"data_type": "list[string]",
"allowed_values": [
"black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
]
},

"pattern": {
"description": "Pattern of the garment",
"data_type": "list[string]",
"allowed_values": [
"plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
]
},

"material": {
"description": "Material of the garment",
"data_type": "string",
"allowed_values": []
}
"""

description = """
This is a simple demo for Attribution. Follow the steps below:

1. Upload image(s) of a product.
2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
3. Select the AI model to use.
4. Enter known attributes (optional).
5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
6. Click "Extract Attributes" to get the extracted attributes.
"""

product_data_placeholder = """Example:
{
"brand": "Leaf",
"size": "M",
"product_name": "Leaf T-shirt",
"color": "red"
}
"""

product_data_value = """
{
"data1": "",
"data2": ""
}
"""

# ---------------------------------------------------------------------------
# UI layout
# ---------------------------------------------------------------------------

with gr.Blocks(title="Internal Demo for Attribution") as demo:
    # Page header with title and usage instructions.
    with gr.Row():
        with gr.Column(scale=12):
            gr.Markdown(
                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
            )
            gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=12):
            with gr.Row():
                # Left column: product images and metadata inputs.
                with gr.Column():
                    gallery = gr.Gallery(
                        label="Upload images of your product here", type="pil"
                    )
                    product_taxnomy = gr.Textbox(
                        label="Product Taxonomy",
                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
                        lines=1,
                        max_lines=1,
                    )
                    ai_model = gr.Dropdown(
                        label="AI Model",
                        choices=settings.SUPPORTED_MODELS,
                        interactive=True,
                    )
                    product_data = gr.TextArea(
                        label="Product Data (Optional)",
                        placeholder=product_data_placeholder,
                        value=product_data_value.strip(),
                        interactive=True,
                        lines=10,
                        max_lines=10,
                    )

                # Right column: schema editor plus a small form to append
                # attribute entries without hand-writing JSON.
                with gr.Column():
                    attributes = gr.TextArea(
                        label="Attribute Schema",
                        value=sample_schema,
                        placeholder="Enter schema here or use Add Attributes below",
                        interactive=True,
                        lines=30,
                        max_lines=30,
                    )

                    with gr.Accordion("Add Attributes", open=False):
                        attr_name = gr.Textbox(
                            label="Attribute name", placeholder="Enter attribute name"
                        )
                        attr_desc = gr.Textbox(
                            label="Description", placeholder="Enter description"
                        )
                        attr_type = gr.Dropdown(
                            label="Type",
                            choices=[
                                "string",
                                "list[string]",
                                "int",
                                "list[int]",
                                "float",
                                "list[float]",
                                "bool",
                                "list[bool]",
                            ],
                            interactive=True,
                        )
                        allowed_values = gr.Textbox(
                            label="Allowed values (separated by comma)",
                            placeholder="yellow, red, blue",
                        )
                        add_btn = gr.Button("Add Attribute")

            with gr.Row():
                submit_btn = gr.Button("Extract Attributes")

        with gr.Column(scale=6):
            output_json = gr.Json(
                label="Extracted Attributes", value={}, show_indices=False
            )

    # Wire the buttons to their handlers.
    add_btn.click(
        add_attribute_schema,
        inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
        outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
    )

    submit_btn.click(
        forward_request,
        inputs=[attributes, product_taxnomy, product_data, ai_model, gallery],
        outputs=output_json,
    )

demo.launch()
app/__init__.py ADDED
File without changes
app/config.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from functools import lru_cache
3
+ from typing import Optional
4
+
5
+ from pydantic_settings import BaseSettings
6
+
7
+
8
# Resolve vendor credentials once at import time.
# - Hugging Face demo: keys come straight from the environment (.env).
# - Deployed service: keys come from AWS Secrets Manager, and Weights &
#   Biases is configured as a side effect.
if os.getenv("HUGGINGFACE_DEMO"):
    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
else:
    # Imported lazily so the demo build does not need AWS dependencies.
    from app.aws.secrets import get_secret

    secrets = get_secret()
    OPENAI_API_KEY = secrets["OPENAI_API_KEY"]
    ANTHROPIC_API_KEY = secrets["ANTHROPIC_API_KEY"]
    os.environ["WANDB_API_KEY"] = secrets["WANDB_API_KEY"]
    os.environ["WANDB_BASE_URL"] = "https://api.wandb.ai"
19
+
20
+
21
class Settings(BaseSettings):
    """Application configuration, created once via get_settings()."""

    # Supported OpenAI models; the first entry is the vendor default.
    OPENAI_MODELS: list = [
        "gpt-4o",
        "gpt-4o-2024-11-20",
        "gpt-4o-mini",
    ]

    # Supported Anthropic models; the first entry is the vendor default.
    ANTHROPIC_MODELS: list = [
        "claude-3-5-sonnet-latest"
    ]

    # All models accepted by the API.
    SUPPORTED_MODELS: list = OPENAI_MODELS + ANTHROPIC_MODELS

    # API keys, resolved above from the environment or AWS secrets.
    # NOTE(review): the original declared these fields twice (once without
    # a default, once with); only the last declaration takes effect, so the
    # duplicates were removed.
    OPENAI_API_KEY: str = OPENAI_API_KEY
    ANTHROPIC_API_KEY: str = ANTHROPIC_API_KEY

    DEFAULT_MAX_ATTEMPTS: int = 1

    # AI service configuration
    DEFAULT_MODEL: str = OPENAI_MODELS[0]
    MAX_TOKENS: int = 2000
    TEMPERATURE: float = 0.0

    # CORS configuration
    CORS_ALLOW_ORIGINS: bool = True

    # API configuration
    API_V1_PREFIX: str = "/api/v1"
    PROJECT_NAME: str = "Dreem Attribution"
    DEBUG: bool = False

    # Rate limiting
    RATE_LIMIT_CALLS: int = 100
    RATE_LIMIT_PERIOD: int = 60

    # Cache configuration
    REDIS_URL: Optional[str] = None
    CACHE_TTL: int = 3600  # 1 hour

    # Logging
    LOG_LEVEL: str = "INFO"
    LOG_FORMAT: str = "json"

    # Per-vendor request timeouts (seconds)
    OPENAI_TIMEOUT: float = 30.0
    ANTHROPIC_TIMEOUT: float = 30.0

    def validate_api_keys(self):
        """Validate that required API keys are present."""
        if not self.OPENAI_API_KEY:
            raise ValueError("OPENAI_API_KEY is required")
        if not self.ANTHROPIC_API_KEY:
            raise ValueError("ANTHROPIC_API_KEY is required")


# Create a cached instance of settings
@lru_cache
def get_settings() -> Settings:
    """
    Create and cache a Settings instance.
    Returns the same instance for subsequent calls.
    """
    settings = Settings()
    settings.validate_api_keys()
    return settings
app/core/__init__.py ADDED
File without changes
app/core/errors.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Message templates shared by the error paths.
VENDOR_ERROR_INVALID_JSON = "Vendor Error: Invalid JSON data"
VENDOR_THROW_ERROR = "Vendor Error: {error_message}"


class VendorError(Exception):
    """Raised when an AI vendor call fails or returns unusable output."""

    def __init__(self, message: str):
        super().__init__(message)


class BadRequestError(Exception):
    """Raised when an incoming request is malformed or unsupported."""

    def __init__(self, message: str):
        super().__init__(message)
app/core/prompts.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ from typing import Optional
3
+
4
+ from pydantic_settings import BaseSettings
5
+
6
# Prompt templates used by the AI services. The *_HUMAN templates contain
# {placeholders} filled in at call time.

EXTRACT_INFO_SYSTEM = "You are an expert at structured data extraction. You will be given an image of a product and should output the its properties into the given structure."

EXTRACT_INFO_HUMAN = (
    """Output properties of the {product_taxonomy} product in the images. You should use the following attributes to help you if it exists:

{product_data}

If an attribute is both in the image and the attributes, use the one in the attribute."""
).replace(" ", "")

FOLLOW_SCHEMA_SYSTEM = "You are an expert at structured data extraction. You will be given an dictionary of attributes of a product and should output the its properties into the given structure."

FOLLOW_SCHEMA_HUMAN = """Convert following attributes to structured schema. Keep all the keys and number of values. Only replace the values themselves. :

{json_info}"""


class Prompts(BaseSettings):
    """Prompt templates exposed as settings-style fields."""

    EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
    EXTRACT_INFO_HUMAN_MESSAGE: str = EXTRACT_INFO_HUMAN
    FOLLOW_SCHEMA_SYSTEM_MESSAGE: str = FOLLOW_SCHEMA_SYSTEM
    FOLLOW_SCHEMA_HUMAN_MESSAGE: str = FOLLOW_SCHEMA_HUMAN


@lru_cache
def get_prompts() -> Prompts:
    """Create and cache the single Prompts instance."""
    return Prompts()
app/core/security.py ADDED
File without changes
app/request_handler/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from app.request_handler.extract_handler import handle_extract
2
+ from app.request_handler.follow_handler import handle_follow
app/request_handler/extract_handler.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from io import BytesIO
2
+
3
+ import requests
4
+ from fastapi import HTTPException
5
+ from PIL import Image
6
+
7
+ from app.config import get_settings
8
+ from app.core.errors import BadRequestError, VendorError
9
+ from app.schemas.requests import ExtractionRequest
10
+ from app.schemas.responses import APIResponse
11
+ from app.services.factory import AIServiceFactory
12
+ from app.utils.logger import setup_logger
13
+
14
+ logger = setup_logger(__name__)
15
+ settings = get_settings()
16
+
17
+
18
async def handle_extract(request: ExtractionRequest):
    """Handle an attribute-extraction request with retry and vendor failover.

    Tries up to request.max_attempts times (clamped to [1, 5]). On vendor
    errors or Anthropic overloads it switches to the other vendor's default
    model before retrying.

    Returns:
        (json_attributes, attempt) — the extracted data and the attempt
        number that succeeded.

    Raises:
        HTTPException: 400 for bad requests / invalid models / bad images,
            500 when all attempts fail. The "attempt" header records the
            attempt on which the failure occurred.
    """
    # Clamp the retry budget to a sane range.
    request.max_attempts = min(max(request.max_attempts, 1), 5)

    for attempt in range(1, request.max_attempts + 1):
        try:
            logger.info("Attempt: %s", attempt)
            if request.ai_model in settings.OPENAI_MODELS:
                ai_vendor = "openai"
            elif request.ai_model in settings.ANTHROPIC_MODELS:
                ai_vendor = "anthropic"
            else:
                raise ValueError(
                    f"Invalid AI model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
                )
            service = AIServiceFactory.get_service(ai_vendor)

            pil_images = []
            for url in request.img_urls:
                try:
                    # Timeout so a dead image host cannot hang the request.
                    response = requests.get(url, timeout=30)
                    response.raise_for_status()
                    pil_images.append(Image.open(BytesIO(response.content)))
                except Exception as e:
                    # %-style lazy args: the original passed `e` with no
                    # placeholder, which breaks the log formatting.
                    logger.error("Failed to download or process image from %s: %s", url, e)
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to process image from {url}",
                        # Header values must be strings for the response encoder.
                        headers={"attempt": str(attempt)},
                    )

            json_attributes = await service.extract_attributes_with_validation(
                request.attributes,
                request.ai_model,
                request.img_urls,
                request.product_taxonomy,
                request.product_data,
                pil_images=pil_images,
            )
            break
        except BadRequestError as e:
            logger.error("Bad request error: %s", e)
            raise HTTPException(
                status_code=400, detail=str(e), headers={"attempt": str(attempt)}
            )
        except ValueError as e:
            logger.error("Value error: %s", e)
            raise HTTPException(
                status_code=400, detail=str(e), headers={"attempt": str(attempt)}
            )
        except VendorError as e:
            logger.error("Vendor error: %s", e)
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500, detail=str(e), headers={"attempt": str(attempt)}
                )
            # Fail over to the other vendor's default model and retry.
            if request.ai_model in settings.ANTHROPIC_MODELS:
                request.ai_model = settings.OPENAI_MODELS[0]
                logger.info(
                    "Switching from anthropic to %s for attempt %s",
                    request.ai_model, attempt + 1,
                )
            elif request.ai_model in settings.OPENAI_MODELS:
                request.ai_model = settings.ANTHROPIC_MODELS[0]
                logger.info(
                    "Switching from OpenAI to %s for attempt %s",
                    request.ai_model, attempt + 1,
                )
        except HTTPException:
            # Already shaped for the client; propagate unchanged.
            raise
        except Exception as e:
            logger.exception("Unexpected error during extraction: %s", e)
            # Anthropic overload: retry on OpenAI's default model.
            if (
                "overload" in str(e).lower()
                and request.ai_model in settings.ANTHROPIC_MODELS
            ):
                request.ai_model = settings.OPENAI_MODELS[0]
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500,
                    detail="Internal server error",
                    headers={"attempt": str(attempt)},
                )

    return json_attributes, attempt
app/request_handler/follow_handler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+
3
+ from app.config import get_settings
4
+ from app.core.errors import VendorError
5
+ from app.schemas.requests import FollowSchemaRequest
6
+ from app.services.factory import AIServiceFactory
7
+ from app.utils.logger import setup_logger
8
+
9
+ logger = setup_logger(__name__)
10
+ settings = get_settings()
11
+
12
+
13
async def handle_follow(request: FollowSchemaRequest):
    """Handle a schema-following request with bounded retries.

    Tries up to request.max_attempts times (clamped to [1, 5]).

    Returns:
        (json_attributes, attempt) — the converted data and the attempt
        number that succeeded.

    Raises:
        HTTPException: 400 for validation errors, 500 for vendor or
            unexpected errors, once the final attempt has failed. The
            "attempt" header records the failing attempt.
    """
    request.max_attempts = min(max(request.max_attempts, 1), 5)

    for attempt in range(1, request.max_attempts + 1):
        try:
            logger.info("Attempt: %s", attempt)
            if request.ai_model in settings.OPENAI_MODELS:
                ai_vendor = "openai"
            elif request.ai_model in settings.ANTHROPIC_MODELS:
                ai_vendor = "anthropic"
            else:
                raise ValueError(
                    f"Invalid AI model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
                )
            service = AIServiceFactory.get_service(ai_vendor)
            json_attributes = await service.follow_schema_with_validation(
                request.data_schema, request.data
            )
            break
        except ValueError as e:
            # Log every failed attempt (the original retried silently).
            logger.error("Value error on attempt %s: %s", attempt, e)
            if attempt == request.max_attempts:
                raise HTTPException(
                    # Header values must be strings for the response encoder.
                    status_code=400, detail=str(e), headers={"attempt": str(attempt)}
                )
        except VendorError as e:
            logger.error("Vendor error on attempt %s: %s", attempt, e)
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500, detail=str(e), headers={"attempt": str(attempt)}
                )
        except Exception as e:
            logger.exception("Unexpected error on attempt %s: %s", attempt, e)
            if attempt == request.max_attempts:
                raise HTTPException(
                    status_code=500,
                    detail="Internal server error",
                    headers={"attempt": str(attempt)},
                )

    return json_attributes, attempt
app/request_handler/validate.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.config import get_settings
2
+ from app.schemas.requests import ExtractionRequest, FollowSchemaRequest
3
+ from app.schemas.schema_tools import validate_json_schema
4
+ from app.utils.logger import setup_logger
5
+
6
+ logger = setup_logger(__name__)
7
+ settings = get_settings()
8
+
9
+
10
def validate_extract_request(request: ExtractionRequest):
    """Validate the request to extract attributes."""
    # Clamp the retry budget into [1, 5].
    request.max_attempts = min(5, max(1, request.max_attempts))

    # Cap the number of images at 10 rather than rejecting the request.
    if len(request.img_urls) > 10:
        logger.warning(
            f"Number of images exceeds 10: {len(request.img_urls)}. Limiting to 10."
        )
        request.img_urls = request.img_urls[:10]

    # Every image URL must be an http(s) URL.
    bad_urls = [url for url in request.img_urls if not url.startswith("http")]
    if bad_urls:
        raise ValueError(f"Invalid URL: {bad_urls[0]}")

    if request.ai_model.lower() not in settings.SUPPORTED_MODELS:
        raise ValueError(
            f"Invalid ai_model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
        )
32
+
33
+
34
def validate_follow_request(request: FollowSchemaRequest):
    """Validate the request to follow a schema."""
    # Clamp the retry budget into [1, 5].
    request.max_attempts = min(5, max(1, request.max_attempts))

    # Reject schemas that JSF cannot generate data for.
    validate_json_schema(request.data_schema)

    if request.ai_model.lower() not in settings.SUPPORTED_MODELS:
        raise ValueError(
            f"Invalid ai_model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
        )
app/schemas/__init__.py ADDED
File without changes
app/schemas/requests.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, List, Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from app.config import get_settings
6
+
7
+ settings = get_settings()
8
+
9
+
10
class Attribute(BaseModel):
    """One attribute definition in an extraction schema: a description,
    a data type (see schema_tools.SUPPORTED_DATA_TYPE), and an optional
    whitelist of allowed values."""

    description: str
    data_type: str
    allowed_values: Optional[List[str]] = []


class ExtractionRequest(BaseModel):
    """Payload for the attribute-extraction endpoint."""

    attributes: Dict[str, Attribute]
    img_urls: Optional[List[str]] = None
    product_taxonomy: str
    request_meta: Optional[Dict[str, str]] = None
    product_data: Optional[Dict[str, str]] = None
    ai_model: str = settings.DEFAULT_MODEL  # type: ignore
    max_attempts: int = settings.DEFAULT_MAX_ATTEMPTS  # type: ignore


class FollowSchemaRequest(BaseModel):
    """Payload for the schema-following endpoint: free-form `data` to be
    reshaped into `data_schema`."""

    data_schema: Dict[str, Any]
    data: Dict[str, Any]
    request_meta: Optional[Dict[str, str]] = None
    ai_model: str = settings.DEFAULT_MODEL
    max_attempts: int = settings.DEFAULT_MAX_ATTEMPTS  # type: ignore


class ResultRequest(BaseModel):
    """Lookup payload for polling a task result by id."""

    task_id: str
app/schemas/responses.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+
6
class SubmitResponse(BaseModel):
    """Returned when a task is accepted; carries the task id."""

    task_id: str


class ResultResponse(BaseModel):
    """Final result of a task, including status and attempt count."""

    request_meta: Optional[Dict[str, str]] = None
    task_id: str
    result: dict
    status_code: int
    detail: str
    attempt: int


class HealthCheckResponse(BaseModel):
    """Simple liveness payload."""

    status: str


class APIResponse(BaseModel):
    """Successful API envelope."""

    detail: str
    data: Dict[str, Any]
    attempts: int


class APIErrorResponse(BaseModel):
    """Error envelope used in the OpenAPI response declarations below."""

    detail: str


# OpenAPI "responses" maps. The 400/500 pair is shared by the submit
# endpoints; result endpoints additionally document 404. Each public name
# gets its own dict copy so mutating one cannot affect another (the
# original repeated the same literals four times).
_ERROR_400_500 = {
    400: {"model": APIErrorResponse},
    500: {"model": APIErrorResponse},
}
_ERROR_400_404_500 = {
    400: {"model": APIErrorResponse},
    404: {"model": APIErrorResponse},
    500: {"model": APIErrorResponse},
}

HEALTH_CHECK_RESPONSES = {}

SUBMIT_EXTRACT_RESPONSES = dict(_ERROR_400_500)

SUBMIT_FOLLOW_RESPONSES = dict(_ERROR_400_500)

RESULT_RESPONSES = dict(_ERROR_400_404_500)

RESPONSES = dict(_ERROR_400_404_500)
app/schemas/schema_tools.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum # do not remove this import for exec
2
+ from typing import List # do not remove this import for exec
3
+ from typing import Any, Dict
4
+
5
+ import jsonschema
6
+ from jsf import JSF
7
+ from pydantic import BaseModel, Field # do not remove this import for exec
8
+
9
+ from app.core.errors import VendorError
10
+ from app.schemas.requests import Attribute
11
+
12
+
13
def validate_json_data(data: Dict[str, Any], schema: Dict[str, Any]):
    """
    Validate `data` against `schema`; wrap any violation in VendorError
    since the data is model-generated.
    """
    try:
        jsonschema.validate(instance=data, schema=schema)
    except jsonschema.ValidationError as e:
        raise VendorError(f"Vendor generated invalid data: {e}")
21
+
22
+
23
def validate_json_schema(schema: Dict[str, Any]):
    """
    Check that `schema` is a usable JSON schema.

    An empty schema is rejected outright; otherwise JSF must be able to
    generate a fake document from it. Any failure raises ValueError.
    """
    if schema == {}:
        raise ValueError("JSON Schema validation failed")

    try:
        faker = JSF(schema)
        _ = faker.generate()
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit
        # are no longer swallowed and re-labelled as schema errors.
        raise ValueError("JSON Schema validation failed")
35
+
36
+
37
# Data types accepted in Attribute.data_type.
SUPPORTED_DATA_TYPE = [
    "string",
    "int",
    "float",
    "bool",
    "list[string]",
    "list[int]",
    "list[float]",
    "list[bool]",
]


def convert_attribute_to_model(attributes: Dict[str, Attribute]) -> Dict[str, Any]:
    """Build a pydantic model class ("Product") from an attribute schema.

    Each attribute becomes a field; attributes with allowed_values get a
    generated str-Enum constraining the field. Returns the generated class.

    NOTE(review): the class source is assembled as text and exec()'d into
    module globals, mirroring the original design. Keys that share a
    capitalized form would overwrite each other's enums, and enum member
    names are not fully sanitized (values starting with a digit would
    produce invalid identifiers) — confirm inputs are trusted/controlled
    before exposing this to untrusted callers.
    """
    enum_code_list = []
    master_class_code = "class Product(BaseModel):\n"
    for key, value in attributes.items():
        description = value.description
        data_type = value.data_type
        # allowed_values is Optional on the schema; treat None the same as
        # "no restriction" instead of crashing on len(None).
        allowed_values = value.allowed_values or []

        if data_type not in SUPPORTED_DATA_TYPE:
            raise ValueError(f"Data type {data_type} is not supported")

        is_list = "list" in data_type

        # Map the schema type names onto Python type names.
        if "int" in data_type:
            data_type = "int"
        elif "float" in data_type:
            data_type = "float"
        elif "bool" in data_type:
            data_type = "bool"
        elif "string" in data_type:
            data_type = "str"

        if allowed_values:
            # Constrain the field via a generated str-Enum.
            enum_code = f"class {key.capitalize()}Enum(str, Enum):\n"
            for allowed_value in allowed_values:
                member = allowed_value.replace(' ', '_').replace('-', '_').upper()
                enum_code += f"    {member} = '{allowed_value}'\n"
            enum_code_list.append(enum_code)
            data_type = f"{key.capitalize()}Enum"

        if is_list:
            data_type = f"List[{data_type}]"

        master_class_code += (
            f"    {key}: {data_type} = Field(..., description='{description}')\n"
        )

    entire_code = "\n".join(enum_code_list) + "\n" + master_class_code
    exec(entire_code, globals())

    return Product  # type: ignore
app/services/__init__.py ADDED
File without changes
app/services/base.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Dict, List, Type
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from app.schemas.schema_tools import (
7
+ convert_attribute_to_model,
8
+ validate_json_data,
9
+ validate_json_schema,
10
+ )
11
+
12
+
13
class BaseAttributionService(ABC):
    """Common interface and validation wrappers for vendor AI services."""

    @abstractmethod
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        # product_data and img_paths added to the abstract signature: the
        # wrapper below (and the Anthropic implementation) already pass
        # them, so the original declaration did not match actual usage.
        product_data: Dict[str, str] = None,
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Vendor-specific extraction; returns data matching the model."""

    @abstractmethod
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Vendor-specific conversion of `data` into `schema`'s shape."""

    async def extract_attributes_with_validation(
        self,
        attributes: Dict[str, Any],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, str],
        pil_images: List[Any] = None,
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Build the pydantic model from `attributes`, run extraction, then
        validate the vendor output against the model's JSON schema."""
        attributes_model = convert_attribute_to_model(attributes)
        schema = attributes_model.model_json_schema()
        data = await self.extract_attributes(
            attributes_model,
            ai_model,
            img_urls,
            product_taxonomy,
            product_data,
            # pil_images=pil_images,  # temporarily removed for save cost
            img_paths=img_paths,
        )
        validate_json_data(data, schema)
        return data

    async def follow_schema_with_validation(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Validate the schema, convert the data, validate the output."""
        validate_json_schema(schema)
        data = await self.follow_schema(schema, data)
        validate_json_data(data, schema)
        return data
app/services/factory.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type
2
+
3
+ from ..config import get_settings
4
+ from .base import BaseAttributionService
5
+ from .service_anthropic import AnthropicService
6
+ from .service_openai import OpenAIService
7
+
8
+ settings = get_settings()
9
+
10
+
11
class AIServiceFactory:
    """Maps vendor names to attribution-service implementations."""

    _services = {"openai": OpenAIService, "anthropic": AnthropicService}

    @classmethod
    def get_service(cls, ai_vendor: str = None) -> BaseAttributionService:
        """Return a new service instance for `ai_vendor`.

        Falls back to "openai" when no vendor is given: the original read
        settings.DEFAULT_VENDOR unconditionally, but Settings does not
        define that field, so a missing vendor raised AttributeError
        instead of selecting a default.
        """
        ai_vendor = ai_vendor or getattr(settings, "DEFAULT_VENDOR", "openai")
        service_class = cls._services.get(ai_vendor.lower())
        if not service_class:
            raise ValueError(f"Unsupported ai_vendor: {ai_vendor}")
        return service_class()
app/services/service_anthropic.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Type
4
+
5
+ import anthropic
6
+ import weave
7
+ from anthropic import APIStatusError, AsyncAnthropic
8
+ from pydantic import BaseModel
9
+
10
+ from app.config import get_settings
11
+ from app.core import errors
12
+ from app.core.errors import BadRequestError, VendorError
13
+ from app.core.prompts import get_prompts
14
+ from app.services.base import BaseAttributionService
15
+ from app.utils.converter import product_data_to_str
16
+ from app.utils.image_processing import get_data_format, get_image_data
17
+ from app.utils.logger import setup_logger
18
+
19
deployment = os.getenv("DEPLOYMENT", "LOCAL")
# Map the deployment stage to its weave project.  Default to the experiment
# project so an unrecognized DEPLOYMENT value cannot leave weave_project_name
# undefined (the original if/elif chain raised NameError at weave.init then).
if deployment == "DEV":
    weave_project_name = "cfai/attribution-dev"
elif deployment == "PROD":
    weave_project_name = "cfai/attribution-prod"
else:  # "LOCAL" (local or demo) and any unexpected value
    weave_project_name = "cfai/attribution-exp"

weave.init(project_name=weave_project_name)
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
31
+
32
+
33
class AnthropicService(BaseAttributionService):
    """Attribute extraction backed by the Anthropic Messages API.

    Structured output is obtained via "tool use": the attribute schema is
    exposed as a tool's input schema (with prompt caching) and the model's
    tool-call arguments are returned as the extracted data.
    """

    def __init__(self):
        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, str],
        pil_images: List[Any] = None,  # do not remove, this is for weave
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images plus product metadata.

        Exactly one of *img_urls* (remote images, passed by URL) or
        *img_paths* (local files, inlined as base64) must be provided.

        Raises:
            BadRequestError: neither image source was given, or Anthropic
                rejected the request.
            VendorError: any other vendor failure, or a response that did
                not include a tool call.
        """
        logger.info("Extracting info via Anthropic...")
        tools = [
            {
                "name": "extract_garment_info",
                "description": "Extracts key information from the image.",
                "input_schema": attributes_model.model_json_schema(),
                # cache the (potentially large) schema across requests
                "cache_control": {"type": "ephemeral"},
            }
        ]

        if img_urls is not None:
            image_messages = [
                {
                    "type": "image",
                    "source": {"type": "url", "url": img_url},
                }
                for img_url in img_urls
            ]
        elif img_paths is not None:
            image_messages = [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": f"image/{get_data_format(img_path)}",
                        "data": get_image_data(img_path),
                    },
                }
                for img_path in img_paths
            ]
        else:
            # Previously this branch only contained `pass`, leaving
            # image_messages undefined and crashing below with NameError;
            # fail fast with a clear error instead.
            raise BadRequestError("Either img_urls or img_paths must be provided.")

        system_message = [{"type": "text", "text": prompts.EXTRACT_INFO_SYSTEM_MESSAGE}]

        text_messages = [
            {
                "type": "text",
                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            }
        ]

        messages = [{"role": "user", "content": text_messages + image_messages}]

        try:
            response = await self.client.messages.create(
                model=ai_model,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except anthropic.BadRequestError as e:
            raise BadRequestError(e.message)
        except Exception as e:
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))

        for content in response.content:
            if content.type == "tool_use":
                return content.input

        # The model answered without calling the tool.  The original fell off
        # the loop and returned None, which only failed later during schema
        # validation with a confusing message; surface the real cause here.
        raise VendorError(
            errors.VENDOR_THROW_ERROR.format(
                error_message="no tool_use block in Anthropic response"
            )
        )

    @weave.op
    async def follow_schema(self, schema, data):
        """Reshape *data* so it conforms to *schema*, again via tool use.

        Returns the reshaped payload, or ``{"status": "ERROR: ..."}`` when
        the model does not produce a tool call.

        Raises:
            VendorError: the Anthropic call itself failed.
        """
        logger.info("Following structure via Anthropic...")
        tools = [
            {
                "name": "extract_garment_info",
                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
                "input_schema": schema,
                "cache_control": {"type": "ephemeral"},
            }
        ]

        text_messages = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            }
        ]

        system_message = [
            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
        ]

        messages = [{"role": "user", "content": text_messages}]
        try:
            response = await self.client.messages.create(
                model=settings.ANTHROPIC_DEFAULT_MODEL,
                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
                max_tokens=2048,
                system=system_message,
                tools=tools,
                messages=messages,
            )
        except Exception as e:
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))

        for content in response.content:
            if content.type == "tool_use":
                return content.input["json_info"]

        return {"status": "ERROR: no tool_use found"}
app/services/service_openai.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Any, Dict, List, Type
4
+
5
+ import openai
6
+ import weave
7
+ from openai import AsyncOpenAI
8
+ from pydantic import BaseModel
9
+
10
+ from app.utils.converter import product_data_to_str
11
+ from app.utils.image_processing import get_data_format, get_image_data
12
+ from app.utils.logger import setup_logger
13
+
14
+ from ..config import get_settings
15
+ from ..core import errors
16
+ from ..core.errors import BadRequestError, VendorError
17
+ from ..core.prompts import get_prompts
18
+ from .base import BaseAttributionService
19
+
20
deployment = os.getenv("DEPLOYMENT", "LOCAL")
# Map the deployment stage to its weave project.  Default to the experiment
# project so an unrecognized DEPLOYMENT value cannot leave weave_project_name
# undefined (the original if/elif chain raised NameError at weave.init then).
if deployment == "DEV":
    weave_project_name = "cfai/attribution-dev"
elif deployment == "PROD":
    weave_project_name = "cfai/attribution-prod"
else:  # "LOCAL" (local or demo) and any unexpected value
    weave_project_name = "cfai/attribution-exp"

weave.init(project_name=weave_project_name)
settings = get_settings()
prompts = get_prompts()
logger = setup_logger(__name__)
32
+
33
+
34
def get_response_format(json_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a JSON schema into OpenAI's strict ``response_format`` payload.

    OpenAI strict mode requires ``additionalProperties: False`` on the root
    schema and on every definition under ``$defs``; those flags are set in
    place on *json_schema* (the mutated schema is embedded in the result).

    Args:
    - json_schema: a JSON schema dict (e.g. from ``model_json_schema()``)

    Returns:
    - a ``response_format`` dict for the chat completions API

    Fixed: annotations previously used the builtin ``any`` instead of
    ``typing.Any``.
    """
    json_schema["additionalProperties"] = False

    # $defs holds nested model definitions when the pydantic model contains
    # sub-models; each one needs the same strictness flag.
    for definition in json_schema.get("$defs", {}).values():
        definition["additionalProperties"] = False

    return {
        "type": "json_schema",
        "json_schema": {"strict": True, "name": "GarmentSchema", "schema": json_schema},
    }
48
+
49
+
50
class OpenAIService(BaseAttributionService):
    """Attribute extraction backed by OpenAI structured outputs
    (``beta.chat.completions.parse``)."""

    def __init__(self):
        self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

    @weave.op
    async def extract_attributes(
        self,
        attributes_model: Type[BaseModel],
        ai_model: str,
        img_urls: List[str],
        product_taxonomy: str,
        product_data: Dict[str, str],
        pil_images: List[Any] = None,  # do not remove, this is for weave
        img_paths: List[str] = None,
    ) -> Dict[str, Any]:
        """Extract product attributes from images plus product metadata.

        Exactly one of *img_urls* (remote images) or *img_paths* (local
        files, inlined as base64 data URLs) must be provided.

        Raises:
            BadRequestError: neither image source was given, or OpenAI
                rejected the request.
            VendorError: any other vendor failure, or an unparseable
                response body.
        """
        logger.info("Extracting info via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
                    product_taxonomy=product_taxonomy,
                    product_data=product_data_to_str(product_data),
                ),
            },
        ]
        if img_urls is not None:
            image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": img_url,
                    },
                }
                for img_url in img_urls
            ]
        elif img_paths is not None:
            image_content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
                    },
                }
                for img_path in img_paths
            ]
        else:
            # Previously image_content was left undefined here, producing a
            # NameError below; fail fast with a clear error instead.
            raise BadRequestError("Either img_urls or img_paths must be provided.")

        try:
            response = await self.client.beta.chat.completions.parse(
                model=ai_model,
                messages=[
                    {
                        "role": "system",
                        "content": prompts.EXTRACT_INFO_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content + image_content,
                    },
                ],
                max_tokens=1000,
                response_format=attributes_model,
                logprobs=False,
                # top_logprobs=2,
                temperature=0.0,
            )
        except openai.BadRequestError as e:
            raise BadRequestError(str(e))
        except Exception as e:
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        except (TypeError, ValueError):
            # was a bare `except:`; content may be None (TypeError) or
            # invalid JSON (json.JSONDecodeError, a ValueError subclass)
            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)

        return parsed_data

    @weave.op
    async def follow_schema(
        self, schema: Dict[str, Any], data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Reshape *data* so it conforms to *schema* using strict JSON mode.

        Returns the reshaped payload, or ``{"status": "refused"}`` when the
        model refuses the request.

        Raises:
            VendorError: the OpenAI call itself failed.
            ValueError: the response body was not valid JSON.
                NOTE(review): extract_attributes raises VendorError for the
                same condition; kept as ValueError so existing callers that
                catch it are unaffected.
        """
        logger.info("Following structure via OpenAI...")
        text_content = [
            {
                "type": "text",
                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
            },
        ]

        try:
            response = await self.client.beta.chat.completions.parse(
                model="gpt-4o-2024-11-20",
                messages=[
                    {
                        "role": "system",
                        "content": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE,
                    },
                    {
                        "role": "user",
                        "content": text_content,
                    },
                ],
                max_tokens=1000,
                response_format=get_response_format(schema),
                logprobs=False,
                # top_logprobs=2,
                temperature=0.0,
            )
        except Exception as e:
            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))

        if response.choices[0].message.refusal:
            logger.info("OpenAI refused to respond to the request")
            return {"status": "refused"}

        try:
            content = response.choices[0].message.content
            parsed_data = json.loads(content)
        except (TypeError, ValueError):
            # was a bare `except:` — see note in extract_attributes
            raise ValueError(errors.VENDOR_ERROR_INVALID_JSON)

        return parsed_data
app/utils/__init__.py ADDED
File without changes
app/utils/converter.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def product_data_to_str(product_data: dict[str, object]) -> str:
    """
    Convert product data to a newline-separated "key: value" string.

    Args:
    - product_data: a dictionary of product data, or None

    Returns:
    - a string representation of the product data ("" for None or empty)

    Fixed: the annotation previously used the builtin ``any`` (a function)
    as a type argument instead of a real type.
    """
    if not product_data:
        return ""

    # generator is enough for join; no need to materialize a list
    return "\n".join(f"{k}: {v}" for k, v in product_data.items())
app/utils/image_processing.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+
3
+
4
def get_image_data(image_path):
    """Read the file at *image_path* and return its contents base64-encoded
    as a UTF-8 string (suitable for inline image payloads)."""
    with open(image_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")
8
+
9
+
10
def get_data_format(image_path):
    """Return the media-type subtype for *image_path*'s file extension.

    The extension is lowercased (callers embed it in strings such as
    ``image/jpeg``, and ``image/JPG`` is not a valid media type) and "jpg"
    is normalized to "jpeg".
    """
    image_format = image_path.split(".")[-1].lower()
    if image_format == "jpg":
        image_format = "jpeg"
    return image_format
app/utils/logger.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from logging.handlers import RotatingFileHandler
4
+
5
+
6
# Configure logger
def setup_logger(name: str) -> logging.Logger:
    """Create (or fetch) a logger with console and rotating-file handlers.

    Level, file path, rotation size and backup count are read from the
    LOG_LEVEL, LOG_FILE, LOG_MAX_BYTES and LOG_BACKUP_COUNT environment
    variables.

    Fixed: calling this more than once for the same *name* previously
    stacked duplicate handlers, so every record was emitted N times; the
    handler guard makes the call idempotent.
    """
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    log_file = os.getenv("LOG_FILE", "app.log")
    max_bytes = int(os.getenv("LOG_MAX_BYTES", 10 * 1024 * 1024))  # 10 MB
    backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))

    logger = logging.getLogger(name)
    logger.setLevel(log_level)
    logger.propagate = False

    # Already configured (module imported twice / repeated setup calls):
    # return as-is instead of attaching duplicate handlers.
    if logger.handlers:
        return logger

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setLevel(log_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    # Rotating file handler
    file_handler = RotatingFileHandler(
        log_file, maxBytes=max_bytes, backupCount=backup_count
    )
    file_handler.setLevel(log_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger
app/utils/rate_limiter.py ADDED
File without changes
app/utils/token_counter.py ADDED
File without changes
clean_for_gradio.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
# Strip server-only modules so the repo can run as a Gradio Space.
# -f makes the script idempotent: plain `rm` errors (non-zero exit) when a
# file has already been removed on a previous run.
rm -rf app/api
rm -rf app/aws
rm -f app/main.py
rm -f worker.py
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.6
2
+ fastapi-cli==0.0.7
3
+ pydantic==2.7.4
4
+ pydantic_settings==2.7.0
5
+ openai
6
+ anthropic==0.42.0
7
+ Pillow==11.0.0
8
+ requests==2.32.3
9
+ jsonschema==4.23.0
10
+ jsf==0.11.2
11
+ pytest==8.3.4
12
+ boto3==1.35.87
13
+ redis==5.2.1
14
+ weave==0.51.39
15
+ gradio==5.22.0