chips committed on
Commit ede43a0 · Parent: ca939d7

adding gemini

Files changed (2)
  1. app/services/service_gemini.py +375 -0
  2. requirements.txt +2 -1
app/services/service_gemini.py ADDED
@@ -0,0 +1,375 @@
+ import json
+ import os
+ import io
+ import base64
+ from typing import Any, Dict, List, Type, Union, Optional
+
+ import google.generativeai as genai
+ from google.generativeai.types import GenerationConfig, HarmCategory, HarmBlockThreshold  # For safety settings
+ import weave  # Assuming weave is still used
+ from pydantic import BaseModel, ValidationError  # For schema validation
+
+ # Assuming these utilities are in the same relative paths or accessible
+ from app.utils.converter import product_data_to_str
+ from app.utils.image_processing import (
+     get_data_format,  # Assuming this returns 'jpeg', 'png', etc.
+     get_image_base64_and_type,  # Assuming this fetches a URL and returns (base64_str, type_str)
+     get_image_data,  # Assuming this reads a local path and returns a base64 string
+ )
+ from app.utils.logger import exception_to_str, setup_logger
+
+ # Assuming these are correctly defined and accessible
+ from ..config import get_settings
+ from ..core import errors
+ from ..core.errors import BadRequestError, VendorError  # Using your custom errors
+ from ..core.prompts import get_prompts  # Assuming prompts are compatible or adapted
+ from .base import BaseAttributionService  # Assuming this base class exists
+
+ # Environment and Weave setup (kept as-is)
+ ENV = os.getenv("ENV", "LOCAL")
+ if ENV == "LOCAL":
+     weave_project_name = "cfai/attribution-exp"
+ elif ENV == "DEV":
+     weave_project_name = "cfai/attribution-dev"
+ elif ENV == "UAT":
+     weave_project_name = "cfai/attribution-uat"
+ elif ENV == "PROD":
+     pass  # No weave for PROD
+
+ if ENV != "PROD":
+     # weave.init(project_name=weave_project_name)  # Assuming weave.init() is called elsewhere if needed
+     print(f"Weave project name (potentially initialized elsewhere): {weave_project_name}")
+
+ settings = get_settings()
+ prompts = get_prompts()
+ logger = setup_logger(__name__)
+
+ # Configure the Gemini client
+ try:
+     if settings.GEMINI_API_KEY:
+         genai.configure(api_key=settings.GEMINI_API_KEY)
+     else:
+         logger.error("GEMINI_API_KEY not found in settings.")
+         # Potentially raise an error or handle this case as per application requirements
+ except AttributeError:
+     logger.error("Settings object does not have a GEMINI_API_KEY attribute.")
+     # Handle the missing settings attribute
+
+ # Default safety settings for Gemini; adjust as per your application's requirements
+ DEFAULT_SAFETY_SETTINGS = {
+     HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+     HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+     HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+     HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+ }
+
+
+ class GeminiService(BaseAttributionService):
+     def __init__(self, model_name: str = "gemini-1.5-flash-latest"):
+         """
+         Initializes the GeminiService.
+
+         Args:
+             model_name (str): The name of the Gemini model to use.
+         """
+         try:
+             self.model = genai.GenerativeModel(
+                 model_name,
+                 safety_settings=DEFAULT_SAFETY_SETTINGS,
+                 # system_instruction can be set here if a global system message is always used
+             )
+             logger.info(f"GeminiService initialized with model: {model_name}")
+         except Exception as e:
+             logger.error(f"Failed to initialize Gemini GenerativeModel: {exception_to_str(e)}")
+             # Depending on requirements, you might want to raise an error here.
+             # For now, we let it proceed; calls will fail if the model isn't initialized.
+             self.model = None
+
+     def _prepare_image_parts(
+         self,
+         img_urls: Optional[List[str]] = None,
+         img_paths: Optional[List[str]] = None,
+         pil_images: Optional[List[Any]] = None,  # PIL.Image.Image objects
+     ) -> List[Dict[str, Any]]:
+         """
+         Prepares image data in the format expected by the Gemini API.
+         Decodes base64 image data to bytes and converts PIL images to bytes.
+         """
+         image_parts = []
+
+         # Process image URLs
+         if img_urls:
+             for img_url in img_urls:
+                 try:
+                     base64_data, img_type = get_image_base64_and_type(img_url)
+                     if base64_data and img_type:
+                         # Gemini expects raw bytes, so decode the base64 payload
+                         image_bytes = base64.b64decode(base64_data)
+                         mime_type = f"image/{img_type.lower()}"
+                         image_parts.append({"mime_type": mime_type, "data": image_bytes})
+                     else:
+                         logger.warning(f"Could not retrieve or identify type for image URL: {img_url}")
+                 except Exception as e:
+                     logger.error(f"Error processing image URL {img_url}: {exception_to_str(e)}")
+
+         # Process image paths
+         if img_paths:
+             for img_path in img_paths:
+                 try:
+                     base64_data = get_image_data(img_path)  # Assuming this returns a base64 string
+                     img_type = get_data_format(img_path)  # Assuming this returns 'png', 'jpeg', etc.
+                     if base64_data and img_type:
+                         image_bytes = base64.b64decode(base64_data)
+                         mime_type = f"image/{img_type.lower()}"
+                         image_parts.append({"mime_type": mime_type, "data": image_bytes})
+                     else:
+                         logger.warning(f"Could not retrieve or identify type for image path: {img_path}")
+                 except Exception as e:
+                     logger.error(f"Error processing image path {img_path}: {exception_to_str(e)}")
+
+         # Process PIL images
+         if pil_images:
+             for i, pil_image in enumerate(pil_images):
+                 try:
+                     img_format = pil_image.format or "PNG"  # Default to PNG if the format is not available
+                     mime_type = f"image/{img_format.lower()}"
+                     with io.BytesIO() as img_byte_arr:
+                         pil_image.save(img_byte_arr, format=img_format)
+                         image_bytes = img_byte_arr.getvalue()
+                     image_parts.append({"mime_type": mime_type, "data": image_bytes})
+                 except Exception as e:
+                     logger.error(f"Error processing PIL image #{i}: {exception_to_str(e)}")
+
+         return image_parts
+
+     @weave.op()
+     async def extract_attributes(
+         self,
+         attributes_model: Type[BaseModel],
+         ai_model: str,  # The Gemini model name, e.g. "gemini-1.5-flash-latest"
+         img_urls: Optional[List[str]] = None,
+         product_taxonomy: str = "",
+         product_data: Optional[Dict[str, Union[str, List[str]]]] = None,
+         pil_images: Optional[List[Any]] = None,
+         img_paths: Optional[List[str]] = None,
+     ) -> Dict[str, Any]:
+         if not self.model:
+             raise VendorError("Gemini model not initialized.")
+         # GenerativeModel.model_name is fully qualified ("models/<name>"), so normalize before comparing.
+         requested_model_name = ai_model if ai_model.startswith("models/") else f"models/{ai_model}"
+         if self.model.model_name != requested_model_name:  # A different model is requested for this call
+             logger.info(f"Switching to model {ai_model} for this extraction request.")
+             # Note: this creates a new model object for the call. If this happens
+             # frequently, consider how model instances are managed.
+             current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
+         else:
+             current_model = self.model
+
+         # Construct the prompt text. System and human prompts are combined because
+         # Gemini takes a list of contents; system instructions can also be set at
+         # model initialization.
+         system_message = prompts.EXTRACT_INFO_SYSTEM_MESSAGE
+         human_message = prompts.EXTRACT_INFO_HUMAN_MESSAGE.format(
+             product_taxonomy=product_taxonomy,
+             product_data=product_data_to_str(product_data if product_data else {}),
+         )
+         full_prompt_text = f"{system_message}\n\n{human_message}"
+
+         # For logging or debugging the prompt
+         logger.info(f"Gemini Prompt Text: {full_prompt_text[:500]}...")  # Log a snippet
+
+         content_parts = [full_prompt_text]
+
+         # Prepare image parts
+         try:
+             image_parts = self._prepare_image_parts(img_urls, img_paths, pil_images)
+             content_parts.extend(image_parts)
+         except Exception as e:
+             logger.error(f"Failed during image preparation: {exception_to_str(e)}")
+             raise BadRequestError(f"Image processing failed: {e}")
+
+         if not image_parts and (img_urls or img_paths or pil_images):
+             logger.warning("Image sources provided, but no image parts were successfully prepared.")
+
+         # Define the generation config for JSON output. Pydantic's model_json_schema()
+         # generates an OpenAPI-compliant schema dictionary.
+         try:
+             schema_for_gemini = attributes_model.model_json_schema()
+         except Exception as e:
+             logger.error(f"Error generating JSON schema from Pydantic model: {exception_to_str(e)}")
+             raise VendorError(f"Could not generate schema for attributes_model: {e}")
+
+         generation_config = GenerationConfig(
+             response_mime_type="application/json",
+             response_schema=schema_for_gemini,  # Gemini expects the schema here
+             temperature=0.0,  # For deterministic output, similar to a low top_p
+             max_output_tokens=2048,  # Adjust as needed; was 1000 for OpenAI
+             # top_p and top_k can also be set if needed
+         )
+
+         logger.info(f"Extracting attributes via Gemini model: {current_model.model_name}...")
+         try:
+             response = await current_model.generate_content_async(
+                 contents=content_parts,
+                 generation_config=generation_config,
+                 # request_options={"timeout": 120}  # Example: set a timeout in seconds
+             )
+         except Exception as e:  # Catches google.api_core.exceptions and others
+             error_message = exception_to_str(e)
+             logger.error(f"Gemini API call failed: {error_message}")
+             # More specific error handling for Gemini can be added here, e.g.:
+             # if isinstance(e, google.api_core.exceptions.InvalidArgument):
+             #     raise BadRequestError(f"Invalid argument to Gemini: {error_message}")
+             raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
+
+         # Process the response
+         try:
+             # Check for safety blocks or refusals
+             if not response.candidates:
+                 # This can happen if all candidates were filtered due to safety or other reasons.
+                 block_reason_detail = "Unknown reason (no candidates)"
+                 if response.prompt_feedback and response.prompt_feedback.block_reason:
+                     block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
+                     if response.prompt_feedback.block_reason_message:
+                         block_reason_detail += f" - {response.prompt_feedback.block_reason_message}"
+                 logger.error(f"Gemini response was blocked or empty. {block_reason_detail}")
+                 raise VendorError(f"Gemini response blocked or empty. {block_reason_detail}")
+
+             # Assuming the first candidate is the one we want
+             candidate = response.candidates[0]
+
+             if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
+                 finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
+                 logger.warning(f"Gemini generation finished with reason: {finish_reason_str}")
+                 # Potentially raise an error for other finish reasons (RECITATION, etc.)
+                 if finish_reason_str == "SAFETY":
+                     safety_ratings_str = ", ".join(
+                         f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
+                     )
+                     raise VendorError(f"Gemini content generation stopped due to safety concerns. Ratings: [{safety_ratings_str}]")
+
+             if not candidate.content.parts or not candidate.content.parts[0].text:
+                 logger.error("Gemini response content is empty or not in the expected text format.")
+                 raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + " (empty response text)")
+
+             response_text = candidate.content.parts[0].text
+
+             # Parse and validate the JSON response using the Pydantic model
+             parsed_data = attributes_model.model_validate_json(response_text)
+             return parsed_data.model_dump()  # Return as dict
+
+         except ValidationError as ve:
+             logger.error(f"Pydantic validation failed for Gemini response: {ve}")
+             logger.debug(f"Invalid JSON received from Gemini: {response_text[:500]}...")  # Log a snippet of the invalid JSON
+             raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {ve}")
+         except json.JSONDecodeError as je:
+             logger.error(f"JSON decoding failed for Gemini response: {je}")
+             logger.debug(f"Non-JSON response received: {response_text[:500]}...")
+             raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" Details: {je}")
+         except VendorError:  # Re-raise VendorErrors
+             raise
+         except Exception as e:
+             error_message = exception_to_str(e)
+             logger.error(f"Error processing Gemini response: {error_message}")
+             # Log the raw response text if available and an error occurred
+             raw_response_snippet = response_text[:500] if "response_text" in locals() else "N/A"
+             logger.debug(f"Problematic Gemini response snippet: {raw_response_snippet}")
+             raise VendorError(f"Failed to process Gemini response: {error_message}")
+
+     @weave.op()
+     async def follow_schema(
+         self,
+         schema: Dict[str, Any],  # This should be an OpenAPI schema dictionary
+         data: Dict[str, Any],
+         ai_model: str = "gemini-1.5-flash-latest",  # Model for this specific task
+     ) -> Dict[str, Any]:
+         if not self.model:  # Check whether the main model was initialized
+             logger.warning("Main Gemini model not initialized. Attempting to initialize a temporary one for follow_schema.")
+             try:
+                 current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
+             except Exception as e:
+                 raise VendorError(f"Failed to initialize Gemini model for follow_schema: {exception_to_str(e)}")
+         elif self.model.model_name != (ai_model if ai_model.startswith("models/") else f"models/{ai_model}"):
+             logger.info(f"Switching to model {ai_model} for this follow_schema request.")
+             current_model = genai.GenerativeModel(ai_model, safety_settings=DEFAULT_SAFETY_SETTINGS)
+         else:
+             current_model = self.model
+
+         logger.info(f"Following schema via Gemini model: {current_model.model_name}...")
+
+         # Prepare the prompt. The system message can be part of the model or prepended here.
+         system_message = prompts.FOLLOW_SCHEMA_SYSTEM_MESSAGE
+         # The human message needs to contain the data to be transformed.
+         # Ensure the `json_info` placeholder is correctly used by your prompt string.
+         try:
+             data_as_json_string = json.dumps(data, indent=2)
+         except TypeError as te:
+             logger.error(f"Could not serialize 'data' to JSON for prompt: {te}")
+             raise BadRequestError(f"Input data for schema following is not JSON serializable: {te}")
+
+         human_message = prompts.FOLLOW_SCHEMA_HUMAN_MESSAGE.format(json_info=data_as_json_string)
+         full_prompt_text = f"{system_message}\n\n{human_message}"
+
+         content_parts = [full_prompt_text]
+
+         # Define the generation config for JSON output using the provided schema
+         generation_config = GenerationConfig(
+             response_mime_type="application/json",
+             response_schema=schema,  # The provided schema dictionary
+             temperature=0.0,  # For deterministic output
+             max_output_tokens=2048,  # Adjust as needed
+         )
+
+         try:
+             response = await current_model.generate_content_async(
+                 contents=content_parts,
+                 generation_config=generation_config,
+             )
+         except Exception as e:
+             error_message = exception_to_str(e)
+             logger.error(f"Gemini API call failed for follow_schema: {error_message}")
+             raise VendorError(errors.VENDOR_THROW_ERROR.format(error_message=error_message))
+
+         # Process the response
+         try:
+             if not response.candidates:
+                 block_reason_detail = "Unknown reason (no candidates)"
+                 if response.prompt_feedback and response.prompt_feedback.block_reason:
+                     block_reason_detail = f"Blocked due to: {response.prompt_feedback.block_reason.name}"
+                 logger.error(f"Gemini response was blocked or empty in follow_schema. {block_reason_detail}")
+                 # The OpenAI version returned {"status": "refused"}; mimic that for a block
+                 return {"status": "refused", "reason": block_reason_detail}
+
+             candidate = response.candidates[0]
+
+             if candidate.finish_reason not in [1, 2]:  # 1=STOP, 2=MAX_TOKENS
+                 finish_reason_str = candidate.finish_reason.name if candidate.finish_reason else "UNKNOWN"
+                 logger.warning(f"Gemini generation (follow_schema) finished with reason: {finish_reason_str}")
+                 if finish_reason_str == "SAFETY":
+                     safety_ratings_str = ", ".join(
+                         f"{sr.category.name}: {sr.probability.name}" for sr in candidate.safety_ratings
+                     )
+                     return {"status": "refused", "reason": f"Safety block. Ratings: [{safety_ratings_str}]"}
+
+             if not candidate.content.parts or not candidate.content.parts[0].text:
+                 logger.error("Gemini response content (follow_schema) is empty.")
+                 # Mimic OpenAI's refusal structure rather than raising
+                 return {"status": "refused", "reason": "Empty content from Gemini"}
+
+             response_text = candidate.content.parts[0].text
+             parsed_data = json.loads(response_text)  # The schema is enforced by Gemini
+             return parsed_data
+
+         except json.JSONDecodeError as je:
+             logger.error(f"JSON decoding failed for Gemini response (follow_schema): {je}")
+             logger.debug(f"Non-JSON response received: {response_text[:500]}...")
+             # The original OpenAI code raised ValueError(errors.VENDOR_ERROR_INVALID_JSON);
+             # VendorError is used here for consistency.
+             raise VendorError(errors.VENDOR_ERROR_INVALID_JSON + f" (follow_schema) Details: {je}")
+         except Exception as e:
+             error_message = exception_to_str(e)
+             logger.error(f"Error processing Gemini response (follow_schema): {error_message}")
+             raw_response_snippet = response_text[:500] if "response_text" in locals() else "N/A"
+             logger.debug(f"Problematic Gemini response snippet (follow_schema): {raw_response_snippet}")
+             raise VendorError(f"Failed to process Gemini response (follow_schema): {error_message}")
+
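For reference, a minimal driver sketch for the new service (not part of the commit): it assumes GEMINI_API_KEY is set in the app's settings, and the ProductAttributes model and sample inputs below are hypothetical.

import asyncio
from typing import List, Optional

from pydantic import BaseModel

from app.services.service_gemini import GeminiService


# Hypothetical attributes model; any Pydantic model with a JSON schema should work.
class ProductAttributes(BaseModel):
    color: Optional[str] = None
    material: Optional[str] = None
    features: List[str] = []


async def main():
    service = GeminiService()  # defaults to "gemini-1.5-flash-latest"
    result = await service.extract_attributes(
        attributes_model=ProductAttributes,
        ai_model="gemini-1.5-flash-latest",
        img_urls=["https://example.com/shirt.jpg"],  # illustrative URL
        product_taxonomy="Apparel > Shirts",
        product_data={"title": "Linen shirt", "tags": ["summer", "casual"]},
    )
    print(result)  # dict validated against ProductAttributes


asyncio.run(main())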
requirements.txt CHANGED
@@ -12,4 +12,5 @@ pytest==8.3.4
  boto3==1.35.87
  redis==5.2.1
  weave==0.51.39
- gradio==5.22.0
+ gradio==5.22.0
+ google-generativeai
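A note on the dependency change: unlike the other entries, google-generativeai is added unpinned. Pinning it like the rest would keep installs reproducible, e.g. (version illustrative, not taken from the commit):

google-generativeai==0.8.3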