"""A Parquet file writer that wraps the pyarrow writer.""" from typing import IO, Optional import pyarrow as pa import pyarrow.parquet as pq from .schema import Item, Schema, schema_to_arrow_schema class ParquetWriter: """A writer to parquet.""" def __init__(self, schema: Schema, codec: str = 'snappy', row_group_buffer_size: int = 128 * 1024 * 1024, record_batch_size: int = 10_000): self._schema = schema_to_arrow_schema(schema) self._codec = codec self._row_group_buffer_size = row_group_buffer_size self._buffer: list[list[Optional[Item]]] = [[] for _ in range(len(self._schema.names))] self._buffer_size = record_batch_size self._record_batches: list[pa.RecordBatch] = [] self._record_batches_byte_size = 0 self.writer: pq.ParquetWriter = None def open(self, file_handle: IO) -> None: """Open the destination file for writing.""" self.writer = pq.ParquetWriter(file_handle, self._schema, compression=self._codec) def write(self, record: Item) -> None: """Write the record to the destination file.""" if len(self._buffer[0]) >= self._buffer_size: self._flush_buffer() if self._record_batches_byte_size >= self._row_group_buffer_size: self._write_batches() # reorder the data in columnar format. for i, n in enumerate(self._schema.names): self._buffer[i].append(record.get(n)) def close(self) -> None: """Flushes the write buffer and closes the destination file.""" if len(self._buffer[0]) > 0: self._flush_buffer() if self._record_batches_byte_size > 0: self._write_batches() self.writer.close() def _write_batches(self) -> None: table = pa.Table.from_batches(self._record_batches, schema=self._schema) self._record_batches = [] self._record_batches_byte_size = 0 self.writer.write_table(table) def _flush_buffer(self) -> None: arrays: list[pa.array] = [[] for _ in range(len(self._schema.names))] for x, y in enumerate(self._buffer): arrays[x] = pa.array(y, type=self._schema.types[x]) self._buffer[x] = [] rb = pa.RecordBatch.from_arrays(arrays, schema=self._schema) self._record_batches.append(rb) size = 0 for x in arrays: for b in x.buffers(): # type: ignore if b is not None: size = size + b.size self._record_batches_byte_size = self._record_batches_byte_size + size