"""Summarize podcast transcripts with a LangChain map-reduce chain.

The script downloads a CSV of podcast episodes from a Weights & Biases
artifact, summarizes each episode's transcript with an OpenAI chat model,
and logs the summaries back to W&B as a CSV artifact and as a table.
"""

import os
from dataclasses import asdict
from typing import Optional

import pandas as pd
from langchain.callbacks import get_openai_callback
from langchain.chains.summarize import load_summarize_chain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import DataFrameLoader
from langchain.prompts import PromptTemplate
from langchain.text_splitter import TokenTextSplitter
from tqdm import tqdm
from wandb.integration.langchain import WandbTracer

import wandb
from config import config

# Prompt applied to every transcript chunk (the "map" step).
_MAP_PROMPT = (
    "Write a concise summary of the following short transcript from a podcast. "
    "Don't add your opinions or interpretations. "
    "{text} "
    "CONCISE SUMMARY:"
)

# Prompt that merges the per-chunk summaries (the "combine" step).
_COMBINE_PROMPT = (
    "You have been provided with summaries of chunks of transcripts from a podcast. "
    "Your task is to merge these intermediate summaries to create a brief and "
    "comprehensive summary of the entire podcast. "
    "The summary should encompass all the crucial points of the podcast. "
    "Ensure that the summary is atleast 2 paragraph long and effectively captures "
    "the essence of the podcast. \n"
    "{text} "
    "SUMMARY:"
)

# The templates do not depend on the episode, so build them once.
_MAP_PROMPT_TEMPLATE = PromptTemplate(template=_MAP_PROMPT, input_variables=["text"])
_COMBINE_PROMPT_TEMPLATE = PromptTemplate(
    template=_COMBINE_PROMPT, input_variables=["text"]
)


def get_data(artifact_name: str, total_episodes: Optional[int] = None) -> pd.DataFrame:
    """Download a dataset artifact from W&B and load its CSV.

    The CSV is expected to be named after the artifact: the version/alias
    after ``:`` is dropped and only the last ``/``-separated component is
    kept, so ``"entity/project/name:latest"`` resolves to ``name.csv``
    inside the download directory.

    Args:
        artifact_name: Artifact reference passed to ``wandb.use_artifact``.
        total_episodes: If given, keep only the first ``total_episodes`` rows.

    Returns:
        The loaded (and optionally truncated) DataFrame.
    """
    podcast_artifact = wandb.use_artifact(artifact_name, type="dataset")
    podcast_artifact_dir = podcast_artifact.download(config.root_artifact_dir)

    filename = artifact_name.split(":")[0].split("/")[-1]
    df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv"))

    if total_episodes is not None:
        df = df.iloc[:total_episodes]
    return df


def summarize_episode(
    episode_df: pd.DataFrame,
    *,
    chunk_size: int = 1000,
    chunk_overlap: int = 0,
    model_name: str = "gpt-3.5-turbo",
    temperature: float = 0,
):
    """Summarize one podcast episode with a map-reduce summarization chain.

    Args:
        episode_df: DataFrame holding the episode. It must have a
            ``transcript`` column (used as the document text) and a ``title``
            column (read from the document metadata for the progress message).
        chunk_size: Maximum number of tokens per transcript chunk.
        chunk_overlap: Number of tokens shared between consecutive chunks.
        model_name: OpenAI chat model used for both map and combine steps.
        temperature: Sampling temperature passed to the chat model.

    Returns:
        The chain's result mapping; the final summary is stored under
        ``"output_text"``.

    Raises:
        ValueError: If ``episode_df`` has no rows.
    """
    # Fail early with a clear message; otherwise an empty frame only surfaces
    # later as an IndexError on ``data[0]``.
    if episode_df.empty:
        raise ValueError("episode_df must contain at least one row to summarize")

    # load docs into langchain format
    loader = DataFrameLoader(episode_df, page_content_column="transcript")
    data = loader.load()

    # split the documents
    text_splitter = TokenTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    docs = text_splitter.split_documents(data)
    print(f"Number of documents for podcast {data[0].metadata['title']}: {len(docs)}")

    # initialize LLM
    llm = ChatOpenAI(model_name=model_name, temperature=temperature)

    # initialize the summarizer chain
    chain = load_summarize_chain(
        llm,
        chain_type="map_reduce",
        return_intermediate_steps=True,
        map_prompt=_MAP_PROMPT_TEMPLATE,
        combine_prompt=_COMBINE_PROMPT_TEMPLATE,
    )

    return chain({"input_documents": docs})


def main() -> None:
    """Summarize every episode in the configured artifact and log the results."""
    # initialize wandb tracer
    WandbTracer.init(
        {
            "project": config.project_name,
            "job_type": "summarize",
            "config": asdict(config),
        }
    )

    # get scraped data
    df = get_data(artifact_name=config.yt_podcast_data_artifact)

    summaries = []
    with get_openai_callback() as cb:
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Summarizing episodes"):
            # summarize_episode expects a DataFrame, so turn the row Series
            # back into a one-row frame.
            episode_data = row.to_frame().T
            summary = summarize_episode(episode_data)
            summaries.append(summary["output_text"])

        print("*" * 25)
        print(cb)
        print("*" * 25)

        wandb.log(
            {
                "total_prompt_tokens": cb.prompt_tokens,
                "total_completion_tokens": cb.completion_tokens,
                "total_tokens": cb.total_tokens,
                "total_cost": cb.total_cost,
            }
        )

    df["summary"] = summaries

    # save data
    # Create the output directory first; to_csv fails when it is missing.
    # An empty path means the current directory, which os.makedirs rejects.
    if config.root_data_dir:
        os.makedirs(config.root_data_dir, exist_ok=True)
    path_to_save = os.path.join(config.root_data_dir, "summarized_podcasts.csv")
    df.to_csv(path_to_save, index=False)

    # log to wandb artifact
    artifact = wandb.Artifact("summarized_podcasts", type="dataset")
    artifact.add_file(path_to_save)
    wandb.log_artifact(artifact)

    # create wandb table
    table = wandb.Table(dataframe=df)
    wandb.log({"summarized_podcasts": table})

    WandbTracer.finish()


if __name__ == "__main__":
    main()