import asyncio
import csv
import json
import time
import typing
from typing import Optional

import requests
from fastavro import parse_schema, reader, writer

from . import EmbedResponse, EmbedResponse_EmbeddingsFloats, EmbedResponse_EmbeddingsByType, ApiMeta, \
    EmbedByTypeResponseEmbeddings, ApiMetaBilledUnits, EmbedJob, CreateEmbedJobResponse, Dataset
from .datasets import DatasetsCreateResponse, DatasetsGetResponse


def get_terminal_states():
    return get_success_states() | get_failed_states()


def get_success_states():
    return {"complete", "validated"}


def get_failed_states():
    return {"unknown", "failed", "skipped", "cancelled"}


def get_id(
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse]):
    # Responses expose their identifier under different attribute names, so try
    # each shape in turn: embed jobs use `job_id`, create responses use `id`,
    # and dataset get responses nest the id under `.dataset`.
    return getattr(awaitable, "job_id", None) or getattr(awaitable, "id", None) or getattr(
        getattr(awaitable, "dataset", None), "id", None)


def get_validation_status(awaitable: typing.Union[EmbedJob, DatasetsGetResponse]):
    # Embed jobs carry a top-level `status`; datasets nest `validation_status`.
    return getattr(awaitable, "status", None) or getattr(
        getattr(awaitable, "dataset", None), "validation_status", None)


def get_job(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse],
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    # Dispatch on the concrete response type to fetch a fresh view of the
    # job or dataset being awaited.
    if awaitable.__class__.__name__ in ("EmbedJob", "CreateEmbedJobResponse"):
        return cohere.embed_jobs.get(id=get_id(awaitable))
    elif awaitable.__class__.__name__ in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return cohere.datasets.get(id=get_id(awaitable))
    else:
        raise ValueError(f"Unexpected awaitable type {awaitable}")


async def async_get_job(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse, EmbedJob, DatasetsGetResponse],
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    if awaitable.__class__.__name__ in ("EmbedJob", "CreateEmbedJobResponse"):
        return await cohere.embed_jobs.get(id=get_id(awaitable))
    elif awaitable.__class__.__name__ in ("DatasetsGetResponse", "DatasetsCreateResponse"):
        return await cohere.datasets.get(id=get_id(awaitable))
    else:
        raise ValueError(f"Unexpected awaitable type {awaitable}")


def get_failure_reason(job: typing.Union[EmbedJob, DatasetsGetResponse]) -> Optional[str]:
    if isinstance(job, EmbedJob):
        return f"Embed job {job.job_id} failed with status {job.status}"
    elif isinstance(job, DatasetsGetResponse):
        return (
            f"Dataset creation failed with status {job.dataset.validation_status} "
            f"and error: {job.dataset.validation_error}"
        )
    return None
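# A quick sanity check of the duck-typed helpers above, using a hypothetical
# stand-in object (the real inputs are the generated SDK response models):
#
#     >>> class _Job:
#     ...     job_id = "job-123"
#     ...     status = "processing"
#     >>> get_id(_Job())
#     'job-123'
#     >>> get_validation_status(_Job())
#     'processing'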
@typing.overload
def wait(
        cohere: typing.Any,
        awaitable: CreateEmbedJobResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> EmbedJob:
    ...


@typing.overload
def wait(
        cohere: typing.Any,
        awaitable: DatasetsCreateResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> DatasetsGetResponse:
    ...


def wait(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
        timeout: Optional[float] = None,
        interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    start_time = time.time()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.time() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")
        time.sleep(interval)
        print("...")  # progress indicator while polling
        job = get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: CreateEmbedJobResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> EmbedJob:
    ...


@typing.overload
async def async_wait(
        cohere: typing.Any,
        awaitable: DatasetsCreateResponse,
        timeout: Optional[float] = None,
        interval: float = 10,
) -> DatasetsGetResponse:
    ...


async def async_wait(
        cohere: typing.Any,
        awaitable: typing.Union[CreateEmbedJobResponse, DatasetsCreateResponse],
        timeout: Optional[float] = None,
        interval: float = 10,
) -> typing.Union[EmbedJob, DatasetsGetResponse]:
    start_time = time.time()
    terminal_states = get_terminal_states()
    failed_states = get_failed_states()

    job = await async_get_job(cohere, awaitable)
    while get_validation_status(job) not in terminal_states:
        if timeout is not None and time.time() - start_time > timeout:
            raise TimeoutError(f"wait timed out after {timeout} seconds")
        await asyncio.sleep(interval)
        print("...")  # progress indicator while polling
        job = await async_get_job(cohere, awaitable)

    if get_validation_status(job) in failed_states:
        raise Exception(get_failure_reason(job))

    return job


def sum_fields_if_not_none(objs: typing.Any, field: str) -> Optional[int]:
    # Sum `field` across objects, skipping values that are None; return None
    # (rather than 0) when no object carries the field.
    non_none = [getattr(obj, field) for obj in objs if getattr(obj, field) is not None]
    return sum(non_none) if non_none else None


def merge_meta_field(metas: typing.List[ApiMeta]) -> ApiMeta:
    api_version = metas[0].api_version
    billed_units = [meta.billed_units for meta in metas if meta.billed_units]
    input_tokens = sum_fields_if_not_none(billed_units, "input_tokens")
    output_tokens = sum_fields_if_not_none(billed_units, "output_tokens")
    search_units = sum_fields_if_not_none(billed_units, "search_units")
    classifications = sum_fields_if_not_none(billed_units, "classifications")
    warnings = {warning for meta in metas if meta.warnings for warning in meta.warnings}
    return ApiMeta(
        api_version=api_version,
        billed_units=ApiMetaBilledUnits(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            search_units=search_units,
            classifications=classifications,
        ),
        warnings=list(warnings),
    )


def merge_embed_responses(responses: typing.List[EmbedResponse]) -> EmbedResponse:
    # Collapse several batched embed responses into one, concatenating texts and
    # embeddings in order and summing the billed units.
    meta = merge_meta_field([response.meta for response in responses if response.meta])
    response_id = ", ".join(response.id for response in responses)
    texts = [text for response in responses for text in response.texts]

    if responses[0].response_type == "embeddings_floats":
        floats_responses = typing.cast(typing.List[EmbedResponse_EmbeddingsFloats], responses)
        embeddings = [
            embedding
            for response in floats_responses
            for embedding in response.embeddings
        ]
        return EmbedResponse_EmbeddingsFloats(
            response_type="embeddings_floats",
            id=response_id,
            texts=texts,
            embeddings=embeddings,
            meta=meta,
        )
    else:
        typed_responses = typing.cast(typing.List[EmbedResponse_EmbeddingsByType], responses)
        embeddings_by_type = [response.embeddings for response in typed_responses]
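# Typical polling usage (a sketch; assumes a configured client `co`, and the
# create-call arguments are elided placeholders):
#
#     create_response = co.embed_jobs.create(...)
#     job = wait(co, create_response, timeout=600, interval=10)
#
# `merge_embed_responses` is the counterpart for client-side batching:
#
#     responses = [co.embed(texts=batch, model="embed-english-v3.0",
#                           input_type="search_document") for batch in batches]
#     merged = merge_embed_responses(responses)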
        # Only merge keys that were actually set on the pydantic model
        # (i.e. exclude fields that are set to None).
        fields = typed_responses[0].embeddings.dict(exclude_unset=True).keys()
        merged_dicts = {
            field: [
                embedding
                for embedding_by_type in embeddings_by_type
                for embedding in getattr(embedding_by_type, field)
            ]
            for field in fields
        }
        embeddings_by_type_merged = EmbedByTypeResponseEmbeddings.parse_obj(merged_dicts)
        return EmbedResponse_EmbeddingsByType(
            response_type="embeddings_by_type",
            id=response_id,
            embeddings=embeddings_by_type_merged,
            texts=texts,
            meta=meta,
        )


supported_formats = ["jsonl", "csv", "avro"]


def save_avro(dataset: Dataset, filepath: str):
    if not dataset.schema_:
        raise ValueError("Dataset does not have a schema")
    schema = parse_schema(json.loads(dataset.schema_))
    with open(filepath, "wb") as outfile:
        writer(outfile, schema, dataset_generator(dataset))


def save_jsonl(dataset: Dataset, filepath: str):
    with open(filepath, "w") as outfile:
        for data in dataset_generator(dataset):
            json.dump(data, outfile)
            outfile.write("\n")


def save_csv(dataset: Dataset, filepath: str):
    with open(filepath, "w", newline="") as outfile:
        csv_writer = None
        for i, data in enumerate(dataset_generator(dataset)):
            if i == 0:
                # Derive the header from the first record's keys.
                csv_writer = csv.DictWriter(outfile, fieldnames=list(data.keys()))
                csv_writer.writeheader()
            csv_writer.writerow(data)


def dataset_generator(dataset: Dataset):
    if not dataset.dataset_parts:
        raise ValueError("Dataset does not have dataset_parts")
    for part in dataset.dataset_parts:
        if not part.url:
            raise ValueError("Dataset part does not have a url")
        # Stream each Avro part straight from its URL and yield records lazily.
        resp = requests.get(part.url, stream=True)
        for record in reader(resp.raw):
            yield record


class SdkUtils:
    @staticmethod
    def save_dataset(dataset: Dataset, filepath: str, format: typing.Literal["jsonl", "csv", "avro"] = "jsonl"):
        if format == "jsonl":
            return save_jsonl(dataset, filepath)
        if format == "csv":
            return save_csv(dataset, filepath)
        if format == "avro":
            return save_avro(dataset, filepath)
        raise Exception(f"unsupported format: must be one of {supported_formats}")


class SyncSdkUtils(SdkUtils):
    pass


class AsyncSdkUtils(SdkUtils):
    pass
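# Example of downloading a validated dataset to disk (a sketch; `co` and the
# create-call arguments are placeholders, not values from this module):
#
#     ds_response = wait(co, co.datasets.create(...))  # poll until validated
#     SdkUtils.save_dataset(ds_response.dataset, "out.jsonl", format="jsonl")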