import copy
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd
from llama_index.readers.file import CSVReader
from llama_index.schema import Document
class EnhancedCSVReader:
    """Enhanced CSV reader with metadata extraction capabilities.

    Wraps ``CSVReader`` and attaches file-level metadata (file name, column
    names, dtypes, sample values and row count) to every document it returns.
    """

    # Class-level default so the value is available even on instances that
    # were created without running ``__init__``.
    sample_size: int = 3

    def __init__(self, sample_size: int = 3) -> None:
        """Create the reader.

        Args:
            sample_size: Maximum number of non-null sample values recorded
                per column. Defaults to 3.

        Raises:
            ValueError: If ``sample_size`` is negative.
        """
        if sample_size < 0:
            raise ValueError("sample_size must be non-negative")
        self.sample_size = sample_size
        self.csv_reader = CSVReader()

    def load_data(self, file_path: str) -> "List[Document]":
        """Load a CSV file and return its documents enriched with metadata.

        Args:
            file_path: Path of the CSV file to load.

        Returns:
            The documents produced by the wrapped ``CSVReader``, each with the
            metadata from ``_extract_metadata`` merged into ``doc.metadata``.
        """
        documents = self.csv_reader.load_data(file_path)
        csv_metadata = self._extract_metadata(file_path)
        for doc in documents:
            # Deep-copy so documents do not share the nested lists/dicts;
            # otherwise editing one document's metadata would alter them all.
            doc.metadata.update(copy.deepcopy(csv_metadata))
        return documents

    def _extract_metadata(self, file_path: str) -> Dict:
        """Extract useful metadata from a CSV file.

        Args:
            file_path: Path of the CSV file to inspect.

        Returns:
            A dict with the keys ``filename``, ``columns``, ``dtypes``
            (column name -> dtype name), ``samples`` (column name -> up to
            ``sample_size`` non-null values as strings) and ``row_count``.
            A completely empty file yields empty collections and a
            ``row_count`` of 0.
        """
        try:
            df = pd.read_csv(file_path)
        except pd.errors.EmptyDataError:
            # A zero-byte file has no header to parse; describe it as an
            # empty table instead of failing the whole load.
            df = pd.DataFrame()

        columns = df.columns.tolist()
        dtypes = {col: str(df[col].dtype) for col in columns}

        # First `sample_size` non-null values per column, stringified.
        limit = self.sample_size
        samples = {
            col: [str(val) for val in df[col].dropna().head(limit).tolist()]
            for col in columns
        }

        return {
            "filename": Path(file_path).name,
            "columns": columns,
            "dtypes": dtypes,
            "samples": samples,
            "row_count": len(df),
        }