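"""Summarize podcast transcripts with a LangChain map-reduce chain.

Pulls scraped episode data from a W&B artifact, summarizes each episode
with gpt-3.5-turbo, and logs the summaries, token usage, and cost back
to W&B.
"""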
import os
from dataclasses import asdict

import pandas as pd
from langchain.callbacks import get_openai_callback
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DataFrameLoader
from langchain.prompts import PromptTemplate
from langchain.text_splitter import TokenTextSplitter
from tqdm import tqdm
from wandb.integration.langchain import WandbTracer

import wandb
from config import config
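# NOTE (assumption): `config` is expected to expose at least project_name,
# root_artifact_dir, root_data_dir, and yt_podcast_data_artifact, all of
# which are referenced below.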


def get_data(artifact_name: str, total_episodes: int = None):
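    """Download the scraped-podcast artifact and load it into a DataFrame.

    If `total_episodes` is given, only the first `total_episodes` rows are kept.
    """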
    podcast_artifact = wandb.use_artifact(artifact_name, type="dataset")
    podcast_artifact_dir = podcast_artifact.download(config.root_artifact_dir)
    filename = artifact_name.split(":")[0].split("/")[-1]
    df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv"))
    if total_episodes is not None:
        df = df.iloc[:total_episodes]
    return df


def summarize_episode(episode_df: pd.DataFrame):
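    """Summarize one episode with a map-reduce chain: chunk the transcript,
    summarize each chunk, then merge the chunk summaries.
    """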
    # load docs into langchain format
    loader = DataFrameLoader(episode_df, page_content_column="transcript")
    data = loader.load()

    # split the documents
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(chunk_size=1000, chunk_overlap=0)
    docs = text_splitter.split_documents(data)
    print(f"Number of documents for podcast {data[0].metadata['title']}: {len(docs)}")

    # initialize LLM
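    # (assumes OPENAI_API_KEY is set in the environment)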
    llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)

    # define map prompt
    map_prompt = """Write a concise summary of the following short transcript from a podcast.
    Don't add your opinions or interpretations.

    {text}

    CONCISE SUMMARY:"""

    # define combine prompt
    combine_prompt = """You have been provided with summaries of chunks of transcripts from a podcast.
    Your task is to merge these intermediate summaries to create a brief and comprehensive summary of the entire podcast.
    The summary should encompass all the crucial points of the podcast.
    Ensure that the summary is at least two paragraphs long and effectively captures the essence of the podcast.

    {text}

    SUMMARY:"""

    map_prompt_template = PromptTemplate(template=map_prompt, input_variables=["text"])
    combine_prompt_template = PromptTemplate(template=combine_prompt, input_variables=["text"])

    # initialize the map-reduce summarizer chain: the map prompt runs on each
    # chunk, and the combine prompt merges the intermediate summaries
    chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        return_intermediate_steps=True,
        map_prompt=map_prompt_template,
        combine_prompt=combine_prompt_template,
    )

    summary = chain({"input_documents": docs})
    return summary


if __name__ == "__main__":
    # initialize wandb tracer
    WandbTracer.init(
        {
            "project": config.project_name,
            "job_type": "summarize",
            "config": asdict(config),
        }
    )

    # get scraped data
    df = get_data(artifact_name=config.yt_podcast_data_artifact)

    summaries = []
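    # get_openai_callback aggregates token counts and estimated cost across all calls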
    with get_openai_callback() as cb:
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Summarizing episodes"):
            # wrap the single row back into a one-row DataFrame for DataFrameLoader
            episode_data = row.to_frame().T

            summary = summarize_episode(episode_data)
            summaries.append(summary["output_text"])

        print("*" * 25)
        print(cb)
        print("*" * 25)

        wandb.log(
            {
                "total_prompt_tokens": cb.prompt_tokens,
                "total_completion_tokens": cb.completion_tokens,
                "total_tokens": cb.total_tokens,
                "total_cost": cb.total_cost,
            }
        )

    df["summary"] = summaries

    # save data
    path_to_save = os.path.join(config.root_data_dir, "summarized_podcasts.csv")
    df.to_csv(path_to_save, index=False)

    # log to wandb artifact
    artifact = wandb.Artifact("summarized_podcasts", type="dataset")
    artifact.add_file(path_to_save)
    wandb.log_artifact(artifact)

    # create wandb table
    table = wandb.Table(dataframe=df)
    wandb.log({"summarized_podcasts": table})

    WandbTracer.finish()