| """Gemini BATCH API helper.""" | |
| from __future__ import annotations | |
| """Helpers for running Gemini BATCH API jobs.""" | |
| from dataclasses import dataclass | |
| import json | |
| import mimetypes | |
| import os | |
| import tempfile | |
| import time | |
| from typing import Any, Dict, Iterable, List, Optional | |
| import requests | |
| from calls_analyser.domain.exceptions import AIModelError | |
| try: # pragma: no cover - optional dependency wiring | |
| from google import genai | |
| from google.genai import types | |
| except Exception: # pragma: no cover - optional dependency wiring | |
| genai = None # type: ignore | |
| types = None # type: ignore | |
@dataclass
class BatchTask:
    """Represents a single audio file queued for BATCH processing."""

    key: str
    path: str
    mime_type: str
    file_uri: Optional[str] = None

class GeminiBatchRunner:
    """Create and poll Gemini BATCH jobs for multiple audio files."""

    def __init__(self, api_key: str, model: str) -> None:
        if genai is None:
            raise AIModelError("google-genai library is not available")
        self._api_key = api_key
        self._model = model
        self._client = genai.Client(api_key=api_key)

    def run_batch(
        self,
        tasks: Iterable[BatchTask],
        prompt_text: str,
        *,
        chunk_size: int = 20,
    ) -> Dict[str, str]:
        """Run batch jobs and return mapping ``key -> text``."""
        pending = list(tasks)
        if not pending:
            return {}
        results: Dict[str, str] = {}
        normalized_chunk_size = max(1, int(chunk_size))
        for chunk_idx in range(0, len(pending), normalized_chunk_size):
            chunk = pending[chunk_idx : chunk_idx + normalized_chunk_size]
            self._process_chunk(chunk, chunk_idx // normalized_chunk_size, prompt_text, results)
        return results
    def _process_chunk(
        self, chunk: List[BatchTask], chunk_index: int, prompt_text: str, results: Dict[str, str]
    ) -> None:
        if not chunk:
            return
        uploaded_file_names: List[str] = []
        uploaded_jsonl_name: Optional[str] = None
        try:
            for task in chunk:
                uploaded = self._client.files.upload(file=task.path)
                task.file_uri = uploaded.uri
                task.mime_type = uploaded.mime_type or task.mime_type
                uploaded_file_names.append(uploaded.name)
            with tempfile.TemporaryDirectory() as tmpdir:
                jsonl_path = os.path.join(tmpdir, f"batch_input_{chunk_index:03}.jsonl")
                self._prepare_chunk_jsonl(chunk, jsonl_path, prompt_text, chunk_index)
                uploaded_jsonl = self._upload_jsonl(jsonl_path, f"batch-input-{chunk_index:03}")
                uploaded_jsonl_name = uploaded_jsonl.name
            batch_name = self._create_batch_job_rest(
                model_id=self._model,
                input_file_name=uploaded_jsonl_name,
                display_name=f"audio-batch-{chunk_index:03}",
            )
            dest_file_name = self._poll_batch_job(batch_name)
            file_content = self._client.files.download(file=dest_file_name)
            self._process_results_jsonl_bytes(file_content, results)
        finally:
            for name in uploaded_file_names:
                try:
                    self._client.files.delete(name=name)
                except Exception:  # pragma: no cover - cleanup best effort
                    pass
            if uploaded_jsonl_name:
                try:
                    self._client.files.delete(name=uploaded_jsonl_name)
                except Exception:  # pragma: no cover - cleanup best effort
                    pass

    # ------------------------- JSONL helpers -------------------------
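    # Each line written to the input JSONL follows the batchGenerateContent file
    # input shape built by _prepare_chunk_jsonl below (illustrative values only):
    #   {"key": "call-001",
    #    "request": {"contents": [{"role": "user", "parts": [
    #        {"text": "..."},
    #        {"file_data": {"mime_type": "audio/mpeg", "file_uri": "files/..."}}]}]}}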
    def _prepare_chunk_jsonl(
        self, tasks_chunk: List[BatchTask], jsonl_path: str, prompt_text: str, chunk_index: int
    ) -> None:
        os.makedirs(os.path.dirname(jsonl_path), exist_ok=True)
        with open(jsonl_path, "w", encoding="utf-8") as f:
            for i, task in enumerate(tasks_chunk):
                unique_key = task.key or f"chunk{chunk_index:03}_batch_{i:03}"
                parts = self._build_parts_for_task(task, prompt_text)
                request_entry = {
                    "key": unique_key,
                    "request": {
                        "contents": [
                            {
                                "role": "user",
                                "parts": parts,
                            }
                        ]
                    },
                }
                f.write(json.dumps(request_entry, ensure_ascii=False) + "\n")
    @staticmethod
    def _build_parts_for_task(task: BatchTask, prompt_text: str) -> List[Dict[str, Any]]:
        clean_prompt = (prompt_text or "").strip()
        parts: List[Dict[str, Any]] = []
        if clean_prompt:
            parts.append({"text": clean_prompt})
        parts.append(
            {
                "file_data": {
                    "mime_type": task.mime_type,
                    "file_uri": task.file_uri,
                }
            }
        )
        return parts
    def _upload_jsonl(self, jsonl_path: str, display_name: str):
        try:
            return self._client.files.upload(
                file=jsonl_path,
                config=types.UploadFileConfig(display_name=display_name, mime_type="jsonl"),
            )
        except Exception:
            # Fall back to letting the SDK detect the mime type if "jsonl" is rejected.
            return self._client.files.upload(
                file=jsonl_path,
                config=types.UploadFileConfig(display_name=display_name),
            )
    # ------------------------- REST helpers --------------------------
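    # File upload/download goes through the google-genai client above, while the
    # batch job itself is created and polled via the raw v1beta REST endpoints below.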
    @staticmethod
    def _rest_model_name(model_id: str) -> str:
        return model_id.replace("models/", "")
    def _create_batch_job_rest(self, model_id: str, input_file_name: str, display_name: str) -> str:
        url = (
            "https://generativelanguage.googleapis.com/v1beta/models/"
            f"{self._rest_model_name(model_id)}:batchGenerateContent"
        )
        headers = {
            "x-goog-api-key": self._api_key,
            "Content-Type": "application/json",
        }
        payload = {
            "batch": {
                "display_name": display_name,
                "input_config": {"file_name": input_file_name},
            }
        }
        resp = requests.post(url, headers=headers, json=payload, timeout=60)
        if not resp.ok:
            raise AIModelError(f"REST create failed: {resp.status_code} {resp.text}")
        data = resp.json()
        name = data.get("name")
        if not name and isinstance(data.get("batch"), dict):
            name = data["batch"].get("name")
        if not name:
            raise AIModelError(f"REST create succeeded but no batch name found. Response: {data}")
        return name
    def _get_batch_job_rest(self, name: str) -> Dict[str, Any]:
        url = f"https://generativelanguage.googleapis.com/v1beta/{name}"
        headers = {"x-goog-api-key": self._api_key}
        resp = requests.get(url, headers=headers, timeout=60)
        if not resp.ok:
            raise AIModelError(f"REST get failed: {resp.status_code} {resp.text}")
        return resp.json()
    @staticmethod
    def _extract_state(rest_obj: Dict[str, Any]) -> Optional[str]:
        return (
            rest_obj.get("state")
            or (rest_obj.get("metadata") or {}).get("state")
            or (rest_obj.get("batch") or {}).get("state")
        )
    @staticmethod
    def _extract_result_file_name(rest_obj: Dict[str, Any]) -> Optional[str]:
        resp = rest_obj.get("response") or {}
        dest = resp.get("dest") or {}
        return (
            dest.get("file_name")
            or dest.get("fileName")
            or resp.get("file_name")
            or resp.get("fileName")
            or resp.get("responsesFile")
            or resp.get("responses_file")
        )
    def _poll_batch_job(self, batch_name: str) -> str:
        completed_states = {
            "BATCH_STATE_SUCCEEDED",
            "BATCH_STATE_FAILED",
            "BATCH_STATE_CANCELLED",
            "BATCH_STATE_EXPIRED",
            "BATCH_STATE_PAUSED",
        }
        while True:
            rest_job = self._get_batch_job_rest(batch_name)
            state = self._extract_state(rest_job)
            if state in completed_states:
                break
            time.sleep(30)
        if state != "BATCH_STATE_SUCCEEDED":
            err = rest_job.get("error") or (rest_job.get("response") or {}).get("error")
            raise AIModelError(f"Batch job failed with state {state}: {err}")
        result_file_name = self._extract_result_file_name(rest_job)
        if not result_file_name:
            raise AIModelError("Could not locate result file name in REST response")
        return result_file_name
    # ------------------------- Results processing --------------------
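    # Each line of the result JSONL echoes the request "key" and carries either a
    # "response" object with candidates or an "error" entry for that request.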
    @staticmethod
    def _process_results_jsonl_bytes(content_bytes: bytes, results: Dict[str, str]) -> None:
        content_str = content_bytes.decode("utf-8", errors="replace")
        for line in content_str.splitlines():
            if not line.strip():
                continue
            try:
                result = json.loads(line)
            except Exception:
                continue
            key = result.get("key")
            if not key:
                continue
            response_wrapper = result.get("response", {})
            if "error" in response_wrapper:
                results[key] = f"Error: {response_wrapper['error']}"
                continue
            candidates = response_wrapper.get("candidates", [])
            text: Optional[str] = None
            if candidates and "content" in candidates[0]:
                parts = candidates[0]["content"].get("parts", [])
                for part in parts:
                    if isinstance(part, dict) and part.get("text"):
                        text = part["text"]
                        break
            if text is None:
                continue
            results[key] = text

def guess_mime_type(path: str) -> str:
    mime_type, _ = mimetypes.guess_type(path)
    return mime_type or "application/octet-stream"
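

# Minimal usage sketch (not part of the module API above): build BatchTask objects
# from local audio files and run them through GeminiBatchRunner. The env var name,
# model id, and file paths are illustrative assumptions, not values defined here.
if __name__ == "__main__":
    audio_paths = ["./calls/call_001.mp3", "./calls/call_002.mp3"]  # hypothetical files
    tasks = [
        BatchTask(key=os.path.basename(p), path=p, mime_type=guess_mime_type(p))
        for p in audio_paths
    ]
    runner = GeminiBatchRunner(
        api_key=os.environ["GEMINI_API_KEY"],  # assumed env var name
        model="gemini-2.5-flash",  # assumed model id
    )
    transcripts = runner.run_batch(tasks, "Transcribe this call.", chunk_size=10)
    for key, text in transcripts.items():
        print(f"{key}: {text[:80]}")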