Spaces:
Sleeping
Sleeping
import json | |
import os | |
import sys | |
from pathlib import Path | |
from typing import Dict, Any, List | |
from loguru import logger | |
from omegaconf import OmegaConf | |
from sqlalchemy import create_engine | |
from sqlalchemy.orm import sessionmaker | |
def process_json_files(
    raw_data_path: str,
) -> tuple[List[Dict[str, str]], List[Dict[str, Any]]]:
    """
    Collect public and meme records from every ``*.json`` file in a folder.

    The public's identifier (``public_vk``) is the file name without its
    ``.json`` suffix. Each file must contain a top-level ``name`` and a
    ``posts`` list whose items carry ``text`` and ``image_url``.

    Files are visited in sorted name order. The returned lists, and the
    order rows are later inserted, therefore do not depend on the arbitrary
    order of ``os.listdir``.

    Args:
        raw_data_path (str): Path to the folder containing JSON files.

    Returns:
        tuple: ``(publics_to_add, memes_to_add)``.
            ``publics_to_add`` is a list of ``{"public_vk", "public_name"}``
            dicts. ``memes_to_add`` is a list of
            ``{"public_vk", "text", "image_url"}`` dicts.

    Raises:
        OSError: If the folder or one of the files cannot be read.
        json.JSONDecodeError: If a file is not valid JSON.
        KeyError: If a file or one of its posts lacks a required key.
    """
    publics_to_add: List[Dict[str, str]] = []
    memes_to_add: List[Dict[str, Any]] = []
    suffix = '.json'

    for filename in sorted(os.listdir(raw_data_path)):
        file_path = os.path.join(raw_data_path, filename)
        # Skip other file types and directories that merely end in ".json".
        if not filename.endswith(suffix) or not os.path.isfile(file_path):
            continue
        public_vk = filename[:-len(suffix)]

        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)
            publics_to_add.append({
                "public_vk": public_vk,
                "public_name": data['name'],
            })
            posts = data['posts']
            memes_to_add.extend(
                {
                    "public_vk": public_vk,
                    "text": post['text'],
                    "image_url": post['image_url'],
                }
                for post in posts
            )
        except (json.JSONDecodeError, KeyError) as err:
            # The bare traceback does not say which of many files was
            # malformed, so name it before propagating the same exception.
            logger.error(f"Failed to process file {filename}: {err!r}")
            raise

        logger.info(f"Processed file: {filename}, found {len(posts)} memes")

    return publics_to_add, memes_to_add
def main():
    """
    Rebuild the database from scratch and fill it with the raw JSON data.

    Steps:
        1. Read ``config.yaml`` from the current working directory.
        2. Drop every table registered on ``Base.metadata`` and recreate
           them. Any existing data in those tables is lost.
        3. Insert the publics found in ``data_folders.raw_data``.
        4. Insert their memes, each linked to its public through
           ``public_id``.

    The session is closed and the engine's connection pool is disposed even
    if a step fails.
    """
    # Presumably deferred to call time because ``src`` only becomes
    # importable after the ``__main__`` guard adds the project root to
    # sys.path.
    from src.db.models import Base
    from src.db import crud

    logger.add("logs/make_db.log", rotation="10 MB")

    # Load configuration. resolve=True expands ${...} interpolations, so
    # values such as the database URL reach create_engine as concrete strings.
    config = OmegaConf.to_container(OmegaConf.load('config.yaml'), resolve=True)

    engine = create_engine(config['database']['url'])
    try:
        # Destructive: drop all existing tables and create fresh ones.
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)

        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        db = SessionLocal()
        try:
            raw_data_path = config['data_folders']['raw_data']
            publics_to_add, memes_to_add = process_json_files(raw_data_path)

            # Add all publics to the database.
            added_publics = crud.add_publics(db, publics_to_add)

            # Map each public_vk to the ``id`` of the object crud.add_publics
            # returned for it.
            public_vk_to_id = {public.public_vk: public.id for public in added_publics}

            # Replace each meme's public_vk key with the matching public_id.
            for meme in memes_to_add:
                meme['public_id'] = public_vk_to_id[meme.pop('public_vk')]

            # Add all memes to the database.
            crud.add_memes(db, memes_to_add)
            logger.info(
                f"Added {len(added_publics)} publics and {len(memes_to_add)} memes to the database")
        finally:
            # Previously skipped whenever an earlier step raised.
            db.close()
    finally:
        # Release pooled connections held by the engine.
        engine.dispose()

    logger.info("Database population completed")
if __name__ == "__main__":
    # The project root is the parent of this script's directory. Putting it
    # first on the import path lets main() import the ``src`` package.
    sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
    main()