Commit 8ba64a4
1 parent: 0a2ea2e

initial commit

Files changed:

- .gitignore +6 -0
- app.py +275 -0
- app/__init__.py +0 -0
- app/config.py +93 -0
- app/core/__init__.py +0 -0
- app/core/errors.py +12 -0
- app/core/prompts.py +41 -0
- app/core/security.py +0 -0
- app/request_handler/__init__.py +2 -0
- app/request_handler/extract_handler.py +111 -0
- app/request_handler/follow_handler.py +52 -0
- app/request_handler/validate.py +44 -0
- app/schemas/__init__.py +0 -0
- app/schemas/requests.py +35 -0
- app/schemas/responses.py +67 -0
- app/schemas/schema_tools.py +91 -0
- app/services/__init__.py +0 -0
- app/services/base.py +62 -0
- app/services/factory.py +20 -0
- app/services/service_anthropic.py +155 -0
- app/services/service_openai.py +172 -0
- app/utils/__init__.py +0 -0
- app/utils/converter.py +14 -0
- app/utils/image_processing.py +14 -0
- app/utils/logger.py +39 -0
- app/utils/rate_limiter.py +0 -0
- app/utils/token_counter.py +0 -0
- clean_for_gradio.sh +4 -0
- requirements.txt +15 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
+env
+.env
+app.log
+gradio_temp/
+
+__pycache__/
app.py
ADDED
@@ -0,0 +1,275 @@
+import os
+
+os.environ["HUGGINGFACE_DEMO"] = "1"  # set before importing from app
+
+from dotenv import load_dotenv
+load_dotenv()
+################################################################################################
+
+import gradio as gr
+import uuid
+import shutil
+
+from app.config import get_settings
+from app.schemas.requests import Attribute
+from app.request_handler import handle_extract
+from app.services.factory import AIServiceFactory
+
+
+settings = get_settings()
+IMAGE_MAX_SIZE = 1536
+
+
+async def forward_request(attributes, product_taxonomy, product_data, ai_model, pil_images):
+    # prepare temp folder
+    request_id = str(uuid.uuid4())
+    request_temp_folder = os.path.join('gradio_temp', request_id)
+    os.makedirs(request_temp_folder, exist_ok=True)
+
+    try:
+        # convert attributes to schema
+        attributes = "attributes_object = {" + attributes + "}"
+        try:
+            exec(attributes, globals())  # defines attributes_object in globals
+        except Exception:
+            raise gr.Error("Invalid `Attribute Schema`. Please insert a valid schema following the example.")
+        for key, value in attributes_object.items():  # type: ignore
+            attributes_object[key] = Attribute(**value)  # type: ignore
+
+        if product_data == "":
+            product_data = "{}"
+        product_data_code = f"product_data_object = {product_data}"
+
+        try:
+            exec(product_data_code, globals())
+        except Exception:
+            raise gr.Error('Invalid `Product Data`. Please insert a valid dictionary or leave it empty.')
+
+        if pil_images is None:
+            raise gr.Error('Please upload image(s) of the product')
+        pil_images = [pil_image[0] for pil_image in pil_images]
+        img_paths = []
+        for i, pil_image in enumerate(pil_images):
+            if max(pil_image.size) > IMAGE_MAX_SIZE:
+                ratio = IMAGE_MAX_SIZE / max(pil_image.size)
+                pil_image = pil_image.resize((int(pil_image.width * ratio), int(pil_image.height * ratio)))
+            img_path = os.path.join(request_temp_folder, f'{i}.jpg')
+            if pil_image.mode in ('RGBA', 'LA') or (pil_image.mode == 'P' and 'transparency' in pil_image.info):
+                pil_image = pil_image.convert("RGBA")
+                if pil_image.getchannel("A").getextrema() == (255, 255):  # if fully opaque, save as JPEG
+                    pil_image = pil_image.convert("RGB")
+                    image_format = 'JPEG'
+                else:
+                    image_format = 'PNG'
+            else:
+                image_format = 'JPEG'
+            pil_image.save(img_path, image_format, quality=100, subsampling=0)
+            img_paths.append(img_path)
+
+        # map the selected model to its vendor
+        if ai_model in settings.OPENAI_MODELS:
+            ai_vendor = 'openai'
+        elif ai_model in settings.ANTHROPIC_MODELS:
+            ai_vendor = 'anthropic'
+        service = AIServiceFactory.get_service(ai_vendor)
+
+        try:
+            json_attributes = await service.extract_attributes_with_validation(
+                attributes_object,  # type: ignore
+                ai_model,
+                None,
+                product_taxonomy,
+                product_data_object,  # type: ignore
+                img_paths=img_paths,
+            )
+        except Exception:
+            raise gr.Error('Failed to extract attributes. Something went wrong.')
+    finally:
+        # remove the temp folder in all cases
+        shutil.rmtree(request_temp_folder)
+
+    gr.Info('Process completed!')
+    return json_attributes
+
+
+def add_attribute_schema(attributes, attr_name, attr_desc, attr_type, allowed_values):
+    schema = f"""
+    "{attr_name}": {{
+        "description": "{attr_desc}",
+        "data_type": "{attr_type}",
+        "allowed_values": [
+            {', '.join([f'"{v.strip()}"' for v in allowed_values.split(',')])}
+        ]
+    }},
+    """
+    return attributes + schema, "", "", "", ""
+
+
+sample_schema = """"category": {
+    "description": "Category of the garment",
+    "data_type": "list[string]",
+    "allowed_values": [
+        "upper garment", "lower garment", "footwear", "accessory", "headwear", "dresses"
+    ]
+},
+
+"color": {
+    "description": "Color of the garment",
+    "data_type": "list[string]",
+    "allowed_values": [
+        "black", "white", "red", "blue", "green", "yellow", "pink", "purple", "orange", "brown", "grey", "beige", "multi-color", "other"
+    ]
+},
+
+"pattern": {
+    "description": "Pattern of the garment",
+    "data_type": "list[string]",
+    "allowed_values": [
+        "plain", "striped", "checkered", "floral", "polka dot", "camouflage", "animal print", "abstract", "other"
+    ]
+},
+
+"material": {
+    "description": "Material of the garment",
+    "data_type": "string",
+    "allowed_values": []
+}
+"""
+description = """
+This is a simple demo for Attribution. Follow the steps below:
+
+1. Upload image(s) of a product.
+2. Enter the product taxonomy (e.g. 'upper garment', 'lower garment', 'bag'). If only one product is in the image, you can leave this field empty.
+3. Select the AI model to use.
+4. Enter known attributes (optional).
+5. Enter the attribute schema or use the "Add Attributes" section to add attributes.
+6. Click "Extract Attributes" to get the extracted attributes.
+"""
+
+product_data_placeholder = """Example:
+{
+    "brand": "Leaf",
+    "size": "M",
+    "product_name": "Leaf T-shirt",
+    "color": "red"
+}
+"""
+product_data_value = """
+{
+    "data1": "",
+    "data2": ""
+}
+"""
+
+with gr.Blocks(title="Internal Demo for Attribution") as demo:
+    with gr.Row():
+        with gr.Column(scale=12):
+            gr.Markdown(
+                """<div style="text-align: center; font-size: 24px;"><strong>Internal Demo for Attribution</strong></div>"""
+            )
+            gr.Markdown(description)
+
+    with gr.Row():
+        with gr.Column(scale=12):
+            with gr.Row():
+                with gr.Column():
+                    gallery = gr.Gallery(
+                        label="Upload images of your product here", type="pil"
+                    )
+                    product_taxonomy = gr.Textbox(
+                        label="Product Taxonomy",
+                        placeholder="Enter product taxonomy here (e.g. 'upper garment', 'lower garment', 'bag')",
+                        lines=1,
+                        max_lines=1,
+                    )
+                    ai_model = gr.Dropdown(
+                        label="AI Model",
+                        choices=settings.SUPPORTED_MODELS,
+                        interactive=True,
+                    )
+                    product_data = gr.TextArea(
+                        label="Product Data (Optional)",
+                        placeholder=product_data_placeholder,
+                        value=product_data_value.strip(),
+                        interactive=True,
+                        lines=10,
+                        max_lines=10,
+                    )
+
+                    # track_count = gr.State(1)
+                    # @gr.render(inputs=track_count)
+                    # def render_tracks(count):
+                    #     ka_names = []
+                    #     ka_values = []
+                    #     with gr.Column():
+                    #         for i in range(count):
+                    #             with gr.Column(variant="panel"):
+                    #                 with gr.Row():
+                    #                     ka_name = gr.Textbox(placeholder="key", key=f"key-{i}", show_label=False)
+                    #                     ka_value = gr.Textbox(placeholder="data", key=f"data-{i}", show_label=False)
+                    #                     ka_names.append(ka_name)
+                    #                     ka_values.append(ka_value)
+
+                    # add_track_btn = gr.Button("Add Product Data")
+                    # remove_track_btn = gr.Button("Remove Product Data")
+                    # add_track_btn.click(lambda count: count + 1, track_count, track_count)
+                    # remove_track_btn.click(lambda count: count - 1, track_count, track_count)
+
+                with gr.Column():
+                    attributes = gr.TextArea(
+                        label="Attribute Schema",
+                        value=sample_schema,
+                        placeholder="Enter schema here or use Add Attributes below",
+                        interactive=True,
+                        lines=30,
+                        max_lines=30,
+                    )
+
+                    with gr.Accordion("Add Attributes", open=False):
+                        attr_name = gr.Textbox(
+                            label="Attribute name", placeholder="Enter attribute name"
+                        )
+                        attr_desc = gr.Textbox(
+                            label="Description", placeholder="Enter description"
+                        )
+                        attr_type = gr.Dropdown(
+                            label="Type",
+                            choices=[
+                                "string",
+                                "list[string]",
+                                "int",
+                                "list[int]",
+                                "float",
+                                "list[float]",
+                                "bool",
+                                "list[bool]",
+                            ],
+                            interactive=True,
+                        )
+                        allowed_values = gr.Textbox(
+                            label="Allowed values (separated by comma)",
+                            placeholder="yellow, red, blue",
+                        )
+                        add_btn = gr.Button("Add Attribute")
+
+            with gr.Row():
+                submit_btn = gr.Button("Extract Attributes")
+
+        with gr.Column(scale=6):
+            output_json = gr.Json(
+                label="Extracted Attributes", value={}, show_indices=False
+            )
+
+    add_btn.click(
+        add_attribute_schema,
+        inputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
+        outputs=[attributes, attr_name, attr_desc, attr_type, allowed_values],
+    )
+
+    submit_btn.click(
+        forward_request,
+        inputs=[attributes, product_taxonomy, product_data, ai_model, gallery],
+        outputs=output_json,
+    )
+
+demo.launch()
app/__init__.py
ADDED
File without changes
app/config.py
ADDED
@@ -0,0 +1,93 @@
+import os
+from functools import lru_cache
+from typing import Optional
+
+from pydantic_settings import BaseSettings
+
+
+if os.getenv("HUGGINGFACE_DEMO"):
+    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+    ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+else:
+    from app.aws.secrets import get_secret
+
+    secrets = get_secret()
+    os.environ["WANDB_API_KEY"] = secrets["WANDB_API_KEY"]
+    OPENAI_API_KEY = secrets["OPENAI_API_KEY"]
+    ANTHROPIC_API_KEY = secrets["ANTHROPIC_API_KEY"]
+    os.environ["WANDB_BASE_URL"] = "https://api.wandb.ai"
+
+
+class Settings(BaseSettings):
+    # Supported OpenAI models
+    OPENAI_MODELS: list = [
+        "gpt-4o",  # first model is the default of the vendor
+        "gpt-4o-2024-11-20",
+        "gpt-4o-mini",
+    ]
+
+    # Supported Anthropic models
+    ANTHROPIC_MODELS: list = [
+        "claude-3-5-sonnet-latest"  # first model is the default of the vendor
+    ]
+
+    # Supported AI services
+    SUPPORTED_MODELS: list = OPENAI_MODELS + ANTHROPIC_MODELS
+
+    # API Keys
+    OPENAI_API_KEY: str
+    ANTHROPIC_API_KEY: str
+
+    DEFAULT_MAX_ATTEMPTS: int = 1
+
+    # AI Service Configuration
+    DEFAULT_MODEL: str = OPENAI_MODELS[0]
+    MAX_TOKENS: int = 2000
+    TEMPERATURE: float = 0.0
+
+    # CORS Configuration
+    CORS_ALLOW_ORIGINS: bool = True
+
+    # API Configuration
+    API_V1_PREFIX: str = "/api/v1"
+    PROJECT_NAME: str = "Dreem Attribution"
+    DEBUG: bool = False
+
+    # Rate Limiting
+    RATE_LIMIT_CALLS: int = 100
+    RATE_LIMIT_PERIOD: int = 60
+
+    # Cache Configuration
+    REDIS_URL: Optional[str] = None
+    CACHE_TTL: int = 3600  # 1 hour
+
+    # Logging
+    LOG_LEVEL: str = "INFO"
+    LOG_FORMAT: str = "json"
+
+    # Timeout Configuration
+    OPENAI_TIMEOUT: float = 30.0
+    ANTHROPIC_TIMEOUT: float = 30.0
+
+    # API Keys (re-declared so the module-level values become the defaults)
+    OPENAI_API_KEY: str = OPENAI_API_KEY
+    ANTHROPIC_API_KEY: str = ANTHROPIC_API_KEY
+
+    def validate_api_keys(self):
+        """Validate that required API keys are present."""
+        if not self.OPENAI_API_KEY:
+            raise ValueError("OPENAI_API_KEY is required")
+        if not self.ANTHROPIC_API_KEY:
+            raise ValueError("ANTHROPIC_API_KEY is required")
+
+
+# Create a cached instance of settings
+@lru_cache
+def get_settings() -> Settings:
+    """
+    Create and cache a Settings instance.
+    Returns the same instance for subsequent calls.
+    """
+    settings = Settings()
+    settings.validate_api_keys()
+    return settings
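
For reference, a minimal sketch (not part of the commit) of how this cached settings object behaves; because get_settings() is wrapped in @lru_cache, every module that calls it shares one validated Settings instance:

    from app.config import get_settings

    settings = get_settings()
    assert get_settings() is settings   # lru_cache returns the same object
    print(settings.DEFAULT_MODEL)       # "gpt-4o", the first OPENAI_MODELS entry
    print(settings.SUPPORTED_MODELS)    # OpenAI models followed by Anthropic models
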
app/core/__init__.py
ADDED
File without changes
app/core/errors.py
ADDED
@@ -0,0 +1,12 @@
+VENDOR_ERROR_INVALID_JSON = "Vendor Error: Invalid JSON data"
+VENDOR_THROW_ERROR = "Vendor Error: {error_message}"
+
+
+class VendorError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
+
+
+class BadRequestError(Exception):
+    def __init__(self, message: str):
+        super().__init__(message)
app/core/prompts.py
ADDED
@@ -0,0 +1,41 @@
+from functools import lru_cache
+from typing import Optional
+
+from pydantic_settings import BaseSettings
+
+EXTRACT_INFO_SYSTEM = "You are an expert at structured data extraction. You will be given an image of a product and should output its properties in the given structure."
+
+EXTRACT_INFO_HUMAN = (
+    """Output the properties of the {product_taxonomy} product in the images. Use the following attributes to help you, if they exist:
+
+{product_data}
+
+If an attribute appears both in the image and in the attributes, prefer the one in the attributes."""
+).replace("    ", "")  # strip indentation from the template
+
+FOLLOW_SCHEMA_SYSTEM = "You are an expert at structured data extraction. You will be given a dictionary of product attributes and should output its properties in the given structure."
+
+FOLLOW_SCHEMA_HUMAN = """Convert the following attributes to the structured schema. Keep all the keys and the number of values. Only replace the values themselves:
+
+{json_info}"""
+
+
+class Prompts(BaseSettings):
+    EXTRACT_INFO_SYSTEM_MESSAGE: str = EXTRACT_INFO_SYSTEM
+
+    EXTRACT_INFO_HUMAN_MESSAGE: str = EXTRACT_INFO_HUMAN
+
+    FOLLOW_SCHEMA_SYSTEM_MESSAGE: str = FOLLOW_SCHEMA_SYSTEM
+
+    FOLLOW_SCHEMA_HUMAN_MESSAGE: str = FOLLOW_SCHEMA_HUMAN
+
+
+# Create a cached instance of prompts
+@lru_cache
+def get_prompts() -> Prompts:
+    """
+    Create and cache a Prompts instance.
+    Returns the same instance for subsequent calls.
+    """
+    prompts = Prompts()
+    return prompts
app/core/security.py
ADDED
File without changes
app/request_handler/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from app.request_handler.extract_handler import handle_extract
+from app.request_handler.follow_handler import handle_follow
app/request_handler/extract_handler.py
ADDED
@@ -0,0 +1,111 @@
+from io import BytesIO
+
+import requests
+from fastapi import HTTPException
+from PIL import Image
+
+from app.config import get_settings
+from app.core.errors import BadRequestError, VendorError
+from app.schemas.requests import ExtractionRequest
+from app.schemas.responses import APIResponse
+from app.services.factory import AIServiceFactory
+from app.utils.logger import setup_logger
+
+logger = setup_logger(__name__)
+settings = get_settings()
+
+
+async def handle_extract(request: ExtractionRequest):
+    request.max_attempts = max(request.max_attempts, 1)
+    request.max_attempts = min(request.max_attempts, 5)
+
+    for attempt in range(1, request.max_attempts + 1):
+        try:
+            logger.info(f"Attempt: {attempt}")
+            if request.ai_model in settings.OPENAI_MODELS:
+                ai_vendor = "openai"
+            elif request.ai_model in settings.ANTHROPIC_MODELS:
+                ai_vendor = "anthropic"
+            else:
+                raise ValueError(
+                    f"Invalid AI model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
+                )
+            service = AIServiceFactory.get_service(ai_vendor)
+
+            pil_images = []
+            for url in request.img_urls:
+                try:
+                    response = requests.get(url)
+                    response.raise_for_status()
+                    image = Image.open(BytesIO(response.content))
+                    pil_images.append(image)
+                except Exception as e:
+                    print(e)
+                    logger.error(f"Failed to download or process image from {url}: {e}")
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Failed to process image from {url}",
+                        headers={"attempt": str(attempt)},
+                    )
+
+            json_attributes = await service.extract_attributes_with_validation(
+                request.attributes,
+                request.ai_model,
+                request.img_urls,
+                request.product_taxonomy,
+                request.product_data,
+                pil_images=pil_images,
+            )
+            break
+        except BadRequestError as e:
+            logger.error("Bad request error: %s", e)
+            raise HTTPException(
+                status_code=400, detail=str(e), headers={"attempt": str(attempt)}
+            )
+        except ValueError as e:
+            logger.error("Value error: %s", e)
+            raise HTTPException(
+                status_code=400, detail=str(e), headers={"attempt": str(attempt)}
+            )
+        except VendorError as e:
+            logger.error("Vendor error: %s", e)
+            if attempt == request.max_attempts:
+                raise HTTPException(
+                    status_code=500, detail=str(e), headers={"attempt": str(attempt)}
+                )
+            else:
+                if request.ai_model in settings.ANTHROPIC_MODELS:
+                    request.ai_model = settings.OPENAI_MODELS[
+                        0
+                    ]  # switch to OpenAI and try again if max_attempts not reached
+                    logger.info(
+                        f"Switching from Anthropic to {request.ai_model} for attempt {attempt + 1}"
+                    )
+                elif request.ai_model in settings.OPENAI_MODELS:
+                    request.ai_model = settings.ANTHROPIC_MODELS[
+                        0
+                    ]  # switch to Anthropic and try again if max_attempts not reached
+                    logger.info(
+                        f"Switching from OpenAI to {request.ai_model} for attempt {attempt + 1}"
+                    )
+
+        except HTTPException as e:
+            logger.error("HTTP exception: %s", e)
+            raise e
+        except Exception as e:
+            logger.error("Exception: %s", e)
+            if (
+                "overload" in str(e).lower()
+                and request.ai_model in settings.ANTHROPIC_MODELS
+            ):
+                request.ai_model = settings.OPENAI_MODELS[
+                    0
+                ]  # switch to OpenAI and try again if max_attempts not reached
+            if attempt == request.max_attempts:
+                raise HTTPException(
+                    status_code=500,
+                    detail="Internal server error",
+                    headers={"attempt": str(attempt)},
+                )
+
+    return json_attributes, attempt
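
The retry loop above degrades across vendors: on a VendorError (or an Anthropic "overload" message) before the final attempt, it swaps request.ai_model to the other vendor's default and loops again. A condensed sketch of the intended flow, with an illustrative model name:

    # attempt 1: "claude-3-5-sonnet-latest" raises VendorError
    #   -> request.ai_model is set to settings.OPENAI_MODELS[0] ("gpt-4o")
    # attempt 2: the OpenAI call succeeds
    #   -> handle_extract returns (json_attributes, 2)
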
app/request_handler/follow_handler.py
ADDED
@@ -0,0 +1,52 @@
+from fastapi import APIRouter, HTTPException
+
+from app.config import get_settings
+from app.core.errors import VendorError
+from app.schemas.requests import FollowSchemaRequest
+from app.services.factory import AIServiceFactory
+from app.utils.logger import setup_logger
+
+logger = setup_logger(__name__)
+settings = get_settings()
+
+
+async def handle_follow(request: FollowSchemaRequest):
+
+    request.max_attempts = max(request.max_attempts, 1)
+    request.max_attempts = min(request.max_attempts, 5)
+
+    for attempt in range(1, request.max_attempts + 1):
+        try:
+            logger.info(f"Attempt: {attempt}")
+            if request.ai_model in settings.OPENAI_MODELS:
+                ai_vendor = "openai"
+            elif request.ai_model in settings.ANTHROPIC_MODELS:
+                ai_vendor = "anthropic"
+            else:
+                raise ValueError(
+                    f"Invalid AI model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
+                )
+            service = AIServiceFactory.get_service(ai_vendor)
+            json_attributes = await service.follow_schema_with_validation(
+                request.data_schema, request.data
+            )
+            break
+        except ValueError as e:
+            if attempt == request.max_attempts:
+                raise HTTPException(
+                    status_code=400, detail=str(e), headers={"attempt": str(attempt)}
+                )
+        except VendorError as e:
+            if attempt == request.max_attempts:
+                raise HTTPException(
+                    status_code=500, detail=str(e), headers={"attempt": str(attempt)}
+                )
+        except Exception as e:
+            if attempt == request.max_attempts:
+                raise HTTPException(
+                    status_code=500,
+                    detail="Internal server error",
+                    headers={"attempt": str(attempt)},
+                )
+
+    return json_attributes, attempt
app/request_handler/validate.py
ADDED
@@ -0,0 +1,44 @@
+from app.config import get_settings
+from app.schemas.requests import ExtractionRequest, FollowSchemaRequest
+from app.schemas.schema_tools import validate_json_schema
+from app.utils.logger import setup_logger
+
+logger = setup_logger(__name__)
+settings = get_settings()
+
+
+def validate_extract_request(request: ExtractionRequest):
+    """Validate the request to extract attributes."""
+    request.max_attempts = max(request.max_attempts, 1)
+    request.max_attempts = min(request.max_attempts, 5)
+
+    # Limit the number of images to 10
+    if len(request.img_urls) > 10:
+        logger.warning(
+            f"Number of images exceeds 10: {len(request.img_urls)}. Limiting to 10."
+        )
+        request.img_urls = request.img_urls[:10]
+
+    for url in request.img_urls:
+        if not url.startswith("http"):
+            raise ValueError(f"Invalid URL: {url}")
+
+    # validate_json_schema(request.data_schema)
+
+    if request.ai_model.lower() not in settings.SUPPORTED_MODELS:
+        raise ValueError(
+            f"Invalid ai_model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
+        )
+
+
+def validate_follow_request(request: FollowSchemaRequest):
+    """Validate the request to follow a schema."""
+    request.max_attempts = max(request.max_attempts, 1)
+    request.max_attempts = min(request.max_attempts, 5)
+
+    validate_json_schema(request.data_schema)
+
+    if request.ai_model.lower() not in settings.SUPPORTED_MODELS:
+        raise ValueError(
+            f"Invalid ai_model: {request.ai_model}, only support {settings.SUPPORTED_MODELS}"
+        )
app/schemas/__init__.py
ADDED
File without changes
app/schemas/requests.py
ADDED
@@ -0,0 +1,35 @@
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel
+
+from app.config import get_settings
+
+settings = get_settings()
+
+
+class Attribute(BaseModel):
+    description: str
+    data_type: str
+    allowed_values: Optional[List[str]] = []
+
+
+class ExtractionRequest(BaseModel):
+    attributes: Dict[str, Attribute]
+    img_urls: Optional[List[str]] = None
+    product_taxonomy: str
+    request_meta: Optional[Dict[str, str]] = None
+    product_data: Optional[Dict[str, str]] = None
+    ai_model: str = settings.DEFAULT_MODEL  # type: ignore
+    max_attempts: int = settings.DEFAULT_MAX_ATTEMPTS  # type: ignore
+
+
+class FollowSchemaRequest(BaseModel):
+    data_schema: Dict[str, Any]
+    data: Dict[str, Any]
+    request_meta: Optional[Dict[str, str]] = None
+    ai_model: str = settings.DEFAULT_MODEL
+    max_attempts: int = settings.DEFAULT_MAX_ATTEMPTS  # type: ignore
+
+
+class ResultRequest(BaseModel):
+    task_id: str
app/schemas/responses.py
ADDED
@@ -0,0 +1,67 @@
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel
+
+
+class SubmitResponse(BaseModel):
+    task_id: str
+
+
+class ResultResponse(BaseModel):
+    request_meta: Optional[Dict[str, str]] = None
+    task_id: str
+    result: dict
+    status_code: int
+    detail: str
+    attempt: int
+
+
+class HealthCheckResponse(BaseModel):
+    status: str
+
+
+class APIResponse(BaseModel):
+    detail: str
+    data: Dict[str, Any]
+    attempts: int
+
+
+class APIErrorResponse(BaseModel):
+    detail: str
+
+
+HEALTH_CHECK_RESPONSES = {}
+
+SUBMIT_EXTRACT_RESPONSES = {
+    400: {
+        "model": APIErrorResponse,
+    },
+    500: {"model": APIErrorResponse},
+}
+
+SUBMIT_FOLLOW_RESPONSES = {
+    400: {
+        "model": APIErrorResponse,
+    },
+    500: {"model": APIErrorResponse},
+}
+
+RESULT_RESPONSES = {
+    400: {
+        "model": APIErrorResponse,
+    },
+    404: {
+        "model": APIErrorResponse,
+    },
+    500: {"model": APIErrorResponse},
+}
+
+RESPONSES = {
+    400: {
+        "model": APIErrorResponse,
+    },
+    404: {
+        "model": APIErrorResponse,
+    },
+    500: {"model": APIErrorResponse},
+}
app/schemas/schema_tools.py
ADDED
@@ -0,0 +1,91 @@
+from enum import Enum  # do not remove this import for exec
+from typing import List  # do not remove this import for exec
+from typing import Any, Dict, Type
+
+import jsonschema
+from jsf import JSF
+from pydantic import BaseModel, Field  # do not remove this import for exec
+
+from app.core.errors import VendorError
+from app.schemas.requests import Attribute
+
+
+def validate_json_data(data: Dict[str, Any], schema: Dict[str, Any]):
+    """
+    Standalone JSON data validation utility
+    """
+    try:
+        jsonschema.validate(instance=data, schema=schema)
+    except jsonschema.ValidationError as e:
+        raise VendorError(f"Vendor generated invalid data: {e}")
+
+
+def validate_json_schema(schema: Dict[str, Any]):
+    """
+    Standalone JSON schema validation utility
+    """
+    if schema == {}:
+        raise ValueError("JSON Schema validation failed")
+
+    try:
+        faker = JSF(schema)
+        _ = faker.generate()
+    except Exception:
+        raise ValueError("JSON Schema validation failed")
+
+
+SUPPORTED_DATA_TYPE = [
+    "string",
+    "int",
+    "float",
+    "bool",
+    "list[string]",
+    "list[int]",
+    "list[float]",
+    "list[bool]",
+]
+
+
+def convert_attribute_to_model(attributes: Dict[str, Attribute]) -> Type[BaseModel]:
+    import_code = ""
+    enum_code_list = []
+    master_class_code = "class Product(BaseModel):\n"
+    for key, value in attributes.items():
+        description = value.description
+        data_type = value.data_type
+        allowed_values = value.allowed_values
+        is_list = False
+
+        if data_type not in SUPPORTED_DATA_TYPE:
+            raise ValueError(f"Data type {data_type} is not supported")
+
+        if "list" in data_type:
+            is_list = True
+
+        if "int" in data_type:
+            data_type = "int"
+        elif "float" in data_type:
+            data_type = "float"
+        elif "bool" in data_type:
+            data_type = "bool"
+        elif "string" in data_type:
+            data_type = "str"
+
+        if len(allowed_values) > 0:
+            enum_code = f"class {key.capitalize()}Enum(str, Enum):\n"
+            for allowed_value in allowed_values:
+                enum_code += f"    {allowed_value.replace(' ', '_').replace('-', '_').upper()} = '{allowed_value}'\n"
+            enum_code_list.append(enum_code)
+            data_type = f"{key.capitalize()}Enum"
+
+        if is_list:
+            data_type = f"List[{data_type}]"
+
+        master_class_code += (
+            f"    {key}: {data_type} = Field(..., description='{description}')\n"
+        )
+
+    entire_code = import_code + "\n".join(enum_code_list) + "\n" + master_class_code
+    exec(entire_code, globals())
+
+    return Product  # type: ignore
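
A minimal usage sketch (the attribute values here are illustrative, not taken from the commit): convert_attribute_to_model generates pydantic source code, exec()s it, and returns the resulting Product model, whose JSON schema is what the vendor services receive:

    from app.schemas.requests import Attribute
    from app.schemas.schema_tools import convert_attribute_to_model

    Product = convert_attribute_to_model({
        "color": Attribute(
            description="Color of the garment",
            data_type="list[string]",
            allowed_values=["red", "blue"],  # becomes a generated ColorEnum
        ),
    })
    print(Product.model_json_schema())  # the schema later passed to OpenAI/Anthropic
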
app/services/__init__.py
ADDED
File without changes
app/services/base.py
ADDED
@@ -0,0 +1,62 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List, Type
+
+from pydantic import BaseModel
+
+from app.schemas.schema_tools import (
+    convert_attribute_to_model,
+    validate_json_data,
+    validate_json_schema,
+)
+
+
+class BaseAttributionService(ABC):
+    @abstractmethod
+    async def extract_attributes(
+        self,
+        attributes_model: Type[BaseModel],
+        ai_model: str,
+        img_urls: List[str],
+        product_taxonomy: str,
+        pil_images: List[Any] = None,
+    ) -> Dict[str, Any]:
+        pass
+
+    @abstractmethod
+    async def follow_schema(
+        self, schema: Dict[str, Any], data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        pass
+
+    async def extract_attributes_with_validation(
+        self,
+        attributes: Dict[str, Any],
+        ai_model: str,
+        img_urls: List[str],
+        product_taxonomy: str,
+        product_data: Dict[str, str],
+        pil_images: List[Any] = None,
+        img_paths: List[str] = None,
+    ) -> Dict[str, Any]:
+        # validate_json_schema(schema)
+        attributes_model = convert_attribute_to_model(attributes)
+        schema = attributes_model.model_json_schema()
+        data = await self.extract_attributes(
+            attributes_model,
+            ai_model,
+            img_urls,
+            product_taxonomy,
+            product_data,
+            # pil_images=pil_images,  # temporarily removed to save cost
+            img_paths=img_paths,
+        )
+        validate_json_data(data, schema)
+        return data
+
+    async def follow_schema_with_validation(
+        self, schema: Dict[str, Any], data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        validate_json_schema(schema)
+        data = await self.follow_schema(schema, data)
+        validate_json_data(data, schema)
+        return data
app/services/factory.py
ADDED
@@ -0,0 +1,20 @@
+from typing import Type
+
+from ..config import get_settings
+from .base import BaseAttributionService
+from .service_anthropic import AnthropicService
+from .service_openai import OpenAIService
+
+settings = get_settings()
+
+
+class AIServiceFactory:
+    _services = {"openai": OpenAIService, "anthropic": AnthropicService}
+
+    @classmethod
+    def get_service(cls, ai_vendor: str = None) -> BaseAttributionService:
+        ai_vendor = ai_vendor or "openai"  # Settings defines no DEFAULT_VENDOR; fall back to OpenAI
+        service_class = cls._services.get(ai_vendor.lower())
+        if not service_class:
+            raise ValueError(f"Unsupported ai_vendor: {ai_vendor}")
+        return service_class()
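
A short usage sketch of the factory (a hypothetical call site, consistent with how app.py and the handlers use it); both returned services implement the BaseAttributionService interface, so callers can swap vendors without changing the extraction call:

    from app.services.factory import AIServiceFactory

    service = AIServiceFactory.get_service("anthropic")  # or "openai"
    # json_attributes = await service.extract_attributes_with_validation(...)
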
app/services/service_anthropic.py
ADDED
@@ -0,0 +1,155 @@
+import json
+import os
+from typing import Any, Dict, List, Type
+
+import anthropic
+import weave
+from anthropic import APIStatusError, AsyncAnthropic
+from pydantic import BaseModel
+
+from app.config import get_settings
+from app.core import errors
+from app.core.errors import BadRequestError, VendorError
+from app.core.prompts import get_prompts
+from app.services.base import BaseAttributionService
+from app.utils.converter import product_data_to_str
+from app.utils.image_processing import get_data_format, get_image_data
+from app.utils.logger import setup_logger
+
+deployment = os.getenv("DEPLOYMENT", "LOCAL")
+if deployment == "LOCAL":  # local or demo
+    weave_project_name = "cfai/attribution-exp"
+elif deployment == "DEV":
+    weave_project_name = "cfai/attribution-dev"
+elif deployment == "PROD":
+    weave_project_name = "cfai/attribution-prod"
+
+weave.init(project_name=weave_project_name)
+settings = get_settings()
+prompts = get_prompts()
+logger = setup_logger(__name__)
+
+
+class AnthropicService(BaseAttributionService):
+    def __init__(self):
+        self.client = AsyncAnthropic(api_key=settings.ANTHROPIC_API_KEY)
+
+    @weave.op
+    async def extract_attributes(
+        self,
+        attributes_model: Type[BaseModel],
+        ai_model: str,
+        img_urls: List[str],
+        product_taxonomy: str,
+        product_data: Dict[str, str],
+        pil_images: List[Any] = None,  # do not remove, this is for weave
+        img_paths: List[str] = None,
+    ) -> Dict[str, Any]:
+        logger.info("Extracting info via Anthropic...")
+        tools = [
+            {
+                "name": "extract_garment_info",
+                "description": "Extracts key information from the image.",
+                "input_schema": attributes_model.model_json_schema(),
+                "cache_control": {"type": "ephemeral"},
+            }
+        ]
+
+        if img_urls is not None:
+            image_messages = [
+                {
+                    "type": "image",
+                    "source": {"type": "url", "url": img_url},
+                }
+                for img_url in img_urls
+            ]
+        elif img_paths is not None:
+            image_messages = [
+                {
+                    "type": "image",
+                    "source": {
+                        "type": "base64",
+                        "media_type": f"image/{get_data_format(img_path)}",
+                        "data": get_image_data(img_path),
+                    },
+                }
+                for img_path in img_paths
+            ]
+        else:
+            # this is not expected; raise a proper error here later.
+            pass
+
+        system_message = [{"type": "text", "text": prompts.EXTRACT_INFO_SYSTEM_MESSAGE}]
+
+        text_messages = [
+            {
+                "type": "text",
+                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
+                    product_taxonomy=product_taxonomy,
+                    product_data=product_data_to_str(product_data),
+                ),
+            }
+        ]
+
+        messages = [{"role": "user", "content": text_messages + image_messages}]
+
+        try:
+            response = await self.client.messages.create(
+                model=ai_model,
+                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
+                max_tokens=2048,
+                system=system_message,
+                tools=tools,
+                messages=messages,
+            )
+        except anthropic.BadRequestError as e:
+            raise BadRequestError(e.message)
+        except Exception as e:
+            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))
+
+        for content in response.content:
+            if content.type == "tool_use":
+                return content.input
+
+    @weave.op
+    async def follow_schema(self, schema, data):
+        logger.info("Following structure via Anthropic...")
+        tools = [
+            {
+                "name": "extract_garment_info",
+                "description": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE,
+                "input_schema": schema,
+                "cache_control": {"type": "ephemeral"},
+            }
+        ]
+
+        text_messages = [
+            {
+                "type": "text",
+                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
+            }
+        ]
+
+        system_message = [
+            {"type": "text", "text": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE}
+        ]
+
+        messages = [{"role": "user", "content": text_messages}]
+        try:
+            response = await self.client.messages.create(
+                model=settings.ANTHROPIC_MODELS[0],  # Settings has no ANTHROPIC_DEFAULT_MODEL; use the first (default) model
+                extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
+                max_tokens=2048,
+                system=system_message,
+                tools=tools,
+                messages=messages,
+            )
+        except Exception as e:
+            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))
+
+        for content in response.content:
+            if content.type == "tool_use":
+                return content.input["json_info"]
+
+        return {"status": "ERROR: no tool_use found"}
app/services/service_openai.py
ADDED
@@ -0,0 +1,172 @@
+import json
+import os
+from typing import Any, Dict, List, Type
+
+import openai
+import weave
+from openai import AsyncOpenAI
+from pydantic import BaseModel
+
+from app.utils.converter import product_data_to_str
+from app.utils.image_processing import get_data_format, get_image_data
+from app.utils.logger import setup_logger
+
+from ..config import get_settings
+from ..core import errors
+from ..core.errors import BadRequestError, VendorError
+from ..core.prompts import get_prompts
+from .base import BaseAttributionService
+
+deployment = os.getenv("DEPLOYMENT", "LOCAL")
+if deployment == "LOCAL":  # local or demo
+    weave_project_name = "cfai/attribution-exp"
+elif deployment == "DEV":
+    weave_project_name = "cfai/attribution-dev"
+elif deployment == "PROD":
+    weave_project_name = "cfai/attribution-prod"
+
+weave.init(project_name=weave_project_name)
+settings = get_settings()
+prompts = get_prompts()
+logger = setup_logger(__name__)
+
+
+def get_response_format(json_schema: Dict[str, Any]) -> Dict[str, Any]:
+    # OpenAI requires each $def to have additionalProperties set to False
+    json_schema["additionalProperties"] = False
+
+    # check if the schema has a $defs key
+    if "$defs" in json_schema:
+        for keys in json_schema["$defs"].keys():
+            json_schema["$defs"][keys]["additionalProperties"] = False
+    response_format = {
+        "type": "json_schema",
+        "json_schema": {"strict": True, "name": "GarmentSchema", "schema": json_schema},
+    }
+
+    return response_format
+
+
+class OpenAIService(BaseAttributionService):
+    def __init__(self):
+        self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
+
+    @weave.op
+    async def extract_attributes(
+        self,
+        attributes_model: Type[BaseModel],
+        ai_model: str,
+        img_urls: List[str],
+        product_taxonomy: str,
+        product_data: Dict[str, str],
+        pil_images: List[Any] = None,  # do not remove, this is for weave
+        img_paths: List[str] = None,
+    ) -> Dict[str, Any]:
+        logger.info("Extracting info via OpenAI...")
+        text_content = [
+            {
+                "type": "text",
+                "text": prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
+                    product_taxonomy=product_taxonomy,
+                    product_data=product_data_to_str(product_data),
+                ),
+            },
+        ]
+        if img_urls is not None:
+            image_content = [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": img_url,
+                    },
+                }
+                for img_url in img_urls
+            ]
+        elif img_paths is not None:
+            image_content = [
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/{get_data_format(img_path)};base64,{get_image_data(img_path)}",
+                    },
+                }
+                for img_path in img_paths
+            ]
+
+        try:
+            response = await self.client.beta.chat.completions.parse(
+                model=ai_model,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": prompts.EXTRACT_INFO_SYSTEM_MESSAGE,
+                    },
+                    {
+                        "role": "user",
+                        "content": text_content + image_content,
+                    },
+                ],
+                max_tokens=1000,
+                response_format=attributes_model,
+                logprobs=False,
+                # top_logprobs=2,
+                temperature=0.0,
+            )
+        except openai.BadRequestError as e:
+            raise BadRequestError(str(e))
+        except Exception as e:
+            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))
+
+        try:
+            content = response.choices[0].message.content
+            parsed_data = json.loads(content)
+        except Exception:
+            raise VendorError(errors.VENDOR_ERROR_INVALID_JSON)
+
+        return parsed_data
+
+    @weave.op
+    async def follow_schema(
+        self, schema: Dict[str, Any], data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        logger.info("Following structure via OpenAI...")
+        text_content = [
+            {
+                "type": "text",
+                "text": prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data),
+            },
+        ]
+
+        try:
+            response = await self.client.beta.chat.completions.parse(
+                model="gpt-4o-2024-11-20",
+                messages=[
+                    {
+                        "role": "system",
+                        "content": prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE,
+                    },
+                    {
+                        "role": "user",
+                        "content": text_content,
+                    },
+                ],
+                max_tokens=1000,
+                response_format=get_response_format(schema),
+                logprobs=False,
+                # top_logprobs=2,
+                temperature=0.0,
+            )
+        except Exception as e:
+            raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=str(e)))
+
+        if response.choices[0].message.refusal:
+            logger.info("OpenAI refused to respond to the request")
+            return {"status": "refused"}
+
+        try:
+            content = response.choices[0].message.content
+            parsed_data = json.loads(content)
+        except Exception:
+            raise ValueError(errors.VENDOR_ERROR_INVALID_JSON)
+
+        return parsed_data
app/utils/__init__.py
ADDED
File without changes
app/utils/converter.py
ADDED
@@ -0,0 +1,14 @@
+def product_data_to_str(product_data: dict[str, str]) -> str:
+    """
+    Convert product data to a string.
+
+    Args:
+    - product_data: a dictionary of product data
+
+    Returns:
+    - a string representation of the product data
+    """
+    if product_data is None:
+        return ""
+
+    return "\n".join([f"{k}: {v}" for k, v in product_data.items()])
app/utils/image_processing.py
ADDED
@@ -0,0 +1,14 @@
+import base64
+
+
+def get_image_data(image_path):
+    with open(image_path, "rb") as f:
+        image_data = base64.b64encode(f.read()).decode("utf-8")
+    return image_data
+
+
+def get_data_format(image_path):
+    image_format = image_path.split(".")[-1]
+    if image_format == "jpg":
+        image_format = "jpeg"
+    return image_format
app/utils/logger.py
ADDED
@@ -0,0 +1,39 @@
+import logging
+import os
+from logging.handlers import RotatingFileHandler
+
+
+# Configure logger
+def setup_logger(name: str) -> logging.Logger:
+    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
+    log_file = os.getenv("LOG_FILE", "app.log")
+    max_bytes = int(os.getenv("LOG_MAX_BYTES", 10 * 1024 * 1024))  # 10 MB
+    backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))
+
+    logger = logging.getLogger(name)
+    logger.setLevel(log_level)
+
+    # Console handler
+    console_handler = logging.StreamHandler()
+    console_handler.setLevel(log_level)
+    console_formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    console_handler.setFormatter(console_formatter)
+
+    # Rotating file handler
+    file_handler = RotatingFileHandler(
+        log_file, maxBytes=max_bytes, backupCount=backup_count
+    )
+    file_handler.setLevel(log_level)
+    file_formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+    file_handler.setFormatter(file_formatter)
+
+    # Add handlers
+    logger.addHandler(console_handler)
+    logger.addHandler(file_handler)
+    logger.propagate = False
+
+    return logger
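
A minimal sketch of the logger in use, assuming the default environment variables; each module gets a named logger that writes both to the console and to a rotating app.log:

    from app.utils.logger import setup_logger

    logger = setup_logger(__name__)
    logger.info("hello")  # printed to the console and appended to app.log (rotated at 10 MB)
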
app/utils/rate_limiter.py
ADDED
File without changes
app/utils/token_counter.py
ADDED
File without changes
clean_for_gradio.sh
ADDED
@@ -0,0 +1,4 @@
+rm -rf app/api
+rm -rf app/aws
+rm app/main.py
+rm worker.py
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+fastapi==0.115.6
+fastapi-cli==0.0.7
+pydantic==2.7.4
+pydantic_settings==2.7.0
+openai
+anthropic==0.42.0
+Pillow==11.0.0
+requests==2.32.3
+jsonschema==4.23.0
+jsf==0.11.2
+pytest==8.3.4
+boto3==1.35.87
+redis==5.2.1
+weave==0.51.39
+gradio==5.22.0