File size: 2,628 Bytes
9ae1b66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import asyncio
import os

import numpy as np
import pandas as pd
from datasets import Dataset, DownloadMode, load_dataset
from gradio_client import Client

from src.my_logger import setup_logger

# Required configuration, read once at import time; `os.environ[...]` raises
# KeyError if any of these variables is unset.
SUBREDDIT = os.environ["SUBREDDIT"]
USERNAME = os.environ["USERNAME"]
# Repo id of the source dataset, derived from the user and subreddit above.
OG_DATASET = f"{USERNAME}/dataset-creator-reddit-{SUBREDDIT}"
# Repo id of the dataset that carries the computed embeddings.
PROCESSED_DATASET = os.environ["PROCESSED_DATASET"]

# Gradio client for the embedding space (queried by `update_embeddings`) and
# the module logger.
client = Client("derek-thomas/nomic-embeddings")
logger = setup_logger(__name__)


def _download_fresh(repo_id):
    """Force a fresh download of *repo_id* with `load_dataset` and return it."""
    logger.debug(f"Trying to download {repo_id}")
    loaded = load_dataset(repo_id, download_mode=DownloadMode.FORCE_REDOWNLOAD)
    logger.debug(f"Loaded {repo_id}")
    return loaded


async def load_datasets():
    """Re-download the processed and original datasets, ignoring any local cache.

    `load_dataset` is a blocking call. Each download therefore runs on the
    event loop's default executor, so awaiting this coroutine no longer freezes
    the loop. The downloads still happen one after the other, in the same order
    as before.

    Returns:
        A ``(dataset, original_dataset)`` tuple: the result of loading
        ``PROCESSED_DATASET`` followed by the result of loading ``OG_DATASET``.
    """
    loop = asyncio.get_running_loop()
    dataset = await loop.run_in_executor(None, _download_fresh, PROCESSED_DATASET)
    original_dataset = await loop.run_in_executor(None, _download_fresh, OG_DATASET)
    return dataset, original_dataset


def merge_and_update_datasets(dataset, original_dataset):
    """Rebuild the processed ``train`` split from the original one, re-embedding as needed.

    Every row of ``original_dataset['train']`` is kept. A row reuses the
    embedding stored in ``dataset['train']`` for the same ``id`` only when the
    stored ``content`` equals the original ``content``. Rows that are new, whose
    content changed, or whose stored embedding is missing are re-embedded with
    `update_embeddings`.

    Args:
        dataset: Processed dataset dict. Its ``train`` split must provide ``id``,
            ``content`` and ``embedding`` columns. It is modified in place.
        original_dataset: Source dataset dict. Its ``train`` split must provide
            ``id`` and ``content`` columns.

    Returns:
        ``dataset``, with ``dataset['train']`` replaced by the refreshed split.
    """
    odf = original_dataset['train'].to_pandas()
    df = dataset['train'].to_pandas()

    # Left merge keeps every original row. The processed copy's columns keep
    # their names; the original's clashing 'content' becomes 'content_odf'.
    merged_df = pd.merge(odf, df[['id', 'content', 'embedding']], on='id', how='left', suffixes=('_odf', ''))

    # A stored embedding is stale when the content differs. Rows absent from the
    # processed copy have NaN content there, which also compares unequal.
    stale = merged_df['content_odf'] != merged_df['content']
    # np.where produces an object array, so each cell can later hold one ndarray.
    merged_df['embedding'] = np.where(stale, None, merged_df['embedding'])

    # Keep the original content and drop the processed copy. Also drop 'new' and
    # 'updated' so the columns match the processed dataset; tolerate their absence.
    merged_df = merged_df.drop(columns=['content'])
    merged_df = merged_df.drop(columns=['new', 'updated'], errors='ignore')
    merged_df = merged_df.rename(columns={'content_odf': 'content'})

    # Count what is actually re-embedded: stale rows plus rows whose stored
    # embedding was already missing.
    missing = merged_df['embedding'].isnull()
    updated_rows = int(missing.sum())

    logger.info(f"Updating {updated_rows} rows...")
    for index, content in merged_df.loc[missing, 'content'].items():
        merged_df.at[index, 'embedding'] = update_embeddings(content)

    dataset['train'] = Dataset.from_pandas(merged_df)
    logger.info(f"Updated {updated_rows} rows")
    return dataset


def update_embeddings(content):
    """Embed *content* by calling the ``/embed`` endpoint of the module-level Gradio client.

    Returns:
        The endpoint's response converted with ``np.array``.
    """
    response = client.predict(content, api_name="/embed")
    vector = np.array(response)
    return vector