File size: 2,741 Bytes
d1430bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ac1c38
d1430bc
5ac1c38
d1430bc
 
5ac1c38
 
d1430bc
 
 
5ac1c38
d1430bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ac1c38
d1430bc
 
 
 
 
 
 
 
5ac1c38
d1430bc
 
 
 
 
 
 
 
 
 
 
 
 
5ac1c38
d1430bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ac1c38
d1430bc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
from dataclasses import asdict

import pandas as pd
from langchain.callbacks import get_openai_callback
from langchain.document_loaders import DataFrameLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.text_splitter import TokenTextSplitter
from langchain.vectorstores import Chroma
from tqdm import tqdm
from wandb.integration.langchain import WandbTracer

import wandb
from config import config


def get_data(artifact_name: str, total_episodes=None):
    podcast_artifact = wandb.use_artifact(artifact_name, type="dataset")
    podcast_artifact_dir = podcast_artifact.download(config.root_artifact_dir)
    filename = artifact_name.split(":")[0].split("/")[-1]
    df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv"))
    if total_episodes is not None:
        df = df.iloc[:total_episodes]
    return df


def create_embeddings(episode_df: pd.DataFrame, index: int):
    # load docs into langchain format
    loader = DataFrameLoader(episode_df, page_content_column="transcript")
    data = loader.load()

    # split the documents
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(data)

    title = data[0].metadata["title"]
    print(f"Number of documents for podcast {title}: {len(docs)}")

    # initialize embedding engine
    embeddings = OpenAIEmbeddings()

    db = Chroma.from_documents(
        docs,
        embeddings,
        persist_directory=os.path.join(config.root_data_dir / "chromadb", str(index)),
    )
    db.persist()


if __name__ == "__main__":
    # initialize wandb tracer
    WandbTracer.init(
        {
            "project": config.project_name,
            "job_type": "embed_transcripts",
            "config": asdict(config),
        }
    )

    # get data
    df = get_data(artifact_name=config.summarized_que_data_artifact)

    # create embeddings
    with get_openai_callback() as cb:
        for episode in tqdm(df.iterrows(), total=len(df), desc="Embedding transcripts"):
            episode_data = episode[1].to_frame().T

            create_embeddings(episode_data, index=episode[0])

        print("*" * 25)
        print(cb)
        print("*" * 25)

        wandb.log(
            {
                "total_prompt_tokens": cb.prompt_tokens,
                "total_completion_tokens": cb.completion_tokens,
                "total_tokens": cb.total_tokens,
                "total_cost": cb.total_cost,
            }
        )

    # log embeddings to wandb artifact
    artifact = wandb.Artifact("transcript_embeddings", type="dataset")
    artifact.add_dir(config.root_data_dir / "chromadb")
    wandb.log_artifact(artifact)

    WandbTracer.finish()