# NOTE: removed non-Python page chrome ("nsthorat's picture / Push / 55dc3dd / raw /
# history blame / No virus / 2.59 kB") that was captured from the Hugging Face file
# viewer along with this source file; it was not part of the module.
"""CSV source."""
from typing import Iterable, Optional
import duckdb
import pandas as pd
from pydantic import Field
from typing_extensions import override
from ...schema import Item
from ...utils import download_http_files
from ..duckdb_utils import duckdb_setup
from .source import Source, SourceSchema, normalize_column_name, schema_from_df
# Name of the synthetic column added to each emitted item, holding the row's
# ordinal position (the DataFrame index) within the loaded CSV data.
LINE_NUMBER_COLUMN = '__line_number__'
class CSVDataset(Source):
  """CSV data loader

  CSV files can live locally as a filepath, or point to an external URL.
  """  # noqa: D415, D400
  name = 'csv'

  filepaths: list[str] = Field(description='A list of filepaths to CSV files.')
  delim: Optional[str] = Field(default=',', description='The CSV file delimiter to use.')
  header: Optional[bool] = Field(default=True, description='Whether the CSV file has a header row.')
  names: Optional[list[str]] = Field(
    default=None, description='Provide header names if the file does not contain a header.')

  # Populated by setup(); shared between source_schema() and process().
  _source_schema: SourceSchema
  _df: pd.DataFrame

  @override
  def setup(self) -> None:
    """Load all CSV files into an in-memory DataFrame and derive the source schema."""
    # Download CSV files to /tmp if they are via HTTP to speed up duckdb.
    filepaths = download_http_files(self.filepaths)
    con = duckdb.connect(database=':memory:')
    try:
      # DuckDB expects s3 protocol: https://duckdb.org/docs/guides/import/s3_import.html.
      s3_filepaths = [path.replace('gs://', 's3://') for path in filepaths]

      # NOTE: We use duckdb here to increase parallelism for multiple files.
      # NOTE: We turn off the parallel reader because of
      # https://github.com/lilacai/lilac/issues/373.
      # NOTE(review): `delim` and `names` are interpolated directly into the SQL text; a value
      # containing a single quote would break the query. Assumed to come from trusted config —
      # confirm upstream validation if these can be user-supplied.
      self._df = con.execute(f"""
        {duckdb_setup(con)}
        SELECT * FROM read_csv_auto(
          {s3_filepaths},
          SAMPLE_SIZE=500000,
          HEADER={self.header},
          {f'NAMES={self.names},' if self.names else ''}
          DELIM='{self.delim or ','}',
          IGNORE_ERRORS=true,
          PARALLEL=false
        )
      """).df()
    finally:
      # Close the in-memory connection; the original code leaked it.
      con.close()

    # Normalize all column names in a single rename call instead of one rename per column.
    self._df.rename(
      columns={name: normalize_column_name(name) for name in self._df.columns}, inplace=True)

    # Create the source schema in prepare to share it between process and source_schema.
    self._source_schema = schema_from_df(self._df, LINE_NUMBER_COLUMN)

  @override
  def source_schema(self) -> SourceSchema:
    """Return the source schema."""
    return self._source_schema

  @override
  def process(self) -> Iterable[Item]:
    """Process the source upload request.

    Yields one Item per DataFrame row: the row's values keyed by (normalized)
    column name, plus the synthetic line-number column holding the row index.
    """
    cols = self._df.columns.tolist()
    yield from ({
      LINE_NUMBER_COLUMN: idx,
      **dict(zip(cols, item_vals)),
    } for idx, *item_vals in self._df.itertuples())