Gladiator commited on
Commit
1cc5040
1 Parent(s): 7cdb553

fix csv data saving + minor changes of names

Browse files
Files changed (3) hide show
  1. src/config.py +4 -2
  2. src/extract_questions.py +7 -5
  3. src/summarize.py +1 -1
src/config.py CHANGED
@@ -13,8 +13,10 @@ class Config:
13
  # wandb
14
  project_name: str = "gradient_dissent_qabot"
15
  yt_podcast_data_artifact: str = "gladiator/gradient_dissent_qabot/yt_podcast_transcript:latest"
16
- # summarized_data_artifact: str = "gladiator/gradient_dissent_bot/summary_data:latest"
17
- # summarized_que_data_artifact: str = "gladiator/gradient_dissent_bot/summary_que_data:latest"
 
 
18
 
19
 
20
  config = Config()
 
13
  # wandb
14
  project_name: str = "gradient_dissent_qabot"
15
  yt_podcast_data_artifact: str = "gladiator/gradient_dissent_qabot/yt_podcast_transcript:latest"
16
+ summarized_data_artifact: str = "gladiator/gradient_dissent_bot/summarized_podcasts:latest"
17
+ summarized_que_data_artifact: str = (
18
+ "gladiator/gradient_dissent_bot/summarized_que_podcasts:latest"
19
+ )
20
 
21
 
22
  config = Config()
src/extract_questions.py CHANGED
@@ -16,11 +16,13 @@ import wandb
16
  from config import config
17
 
18
 
19
- def get_data(artifact_name: str = "gladiator/gradient_dissent_bot/summary_data:latest"):
20
  podcast_artifact = wandb.use_artifact(artifact_name, type="dataset")
21
  podcast_artifact_dir = podcast_artifact.download(config.root_data_dir)
22
  filename = artifact_name.split(":")[0].split("/")[-1]
23
  df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv"))
 
 
24
  return df
25
 
26
 
@@ -66,7 +68,6 @@ if __name__ == "__main__":
66
  WandbTracer.init(
67
  {
68
  "project": "gradient_dissent_bot",
69
- "name": "extract_questions",
70
  "job_type": "extract_questions",
71
  "config": asdict(config),
72
  }
@@ -101,14 +102,15 @@ if __name__ == "__main__":
101
  df["questions"] = questions
102
 
103
  # log to wandb artifact
104
- path_to_save = os.path.join(config.root_data_dir, "summary_que_data.csv")
105
  df.to_csv(path_to_save, index=False)
106
- artifact = wandb.Artifact("summary_que_data", type="dataset")
107
  artifact.add_file(path_to_save)
108
  wandb.log_artifact(artifact)
109
 
110
  # create wandb table
 
111
  table = wandb.Table(dataframe=df)
112
- wandb.log({"summary_que_data": table})
113
 
114
  WandbTracer.finish()
 
16
  from config import config
17
 
18
 
19
+ def get_data(artifact_name: str, total_episodes: int = None):
20
  podcast_artifact = wandb.use_artifact(artifact_name, type="dataset")
21
  podcast_artifact_dir = podcast_artifact.download(config.root_data_dir)
22
  filename = artifact_name.split(":")[0].split("/")[-1]
23
  df = pd.read_csv(os.path.join(podcast_artifact_dir, f"{filename}.csv"))
24
+ if total_episodes is not None:
25
+ df = df.iloc[:total_episodes]
26
  return df
27
 
28
 
 
68
  WandbTracer.init(
69
  {
70
  "project": "gradient_dissent_bot",
 
71
  "job_type": "extract_questions",
72
  "config": asdict(config),
73
  }
 
102
  df["questions"] = questions
103
 
104
  # log to wandb artifact
105
+ path_to_save = os.path.join(config.root_data_dir, "summarized_que_podcasts.csv")
106
  df.to_csv(path_to_save, index=False)
107
+ artifact = wandb.Artifact("summarized_que_podcasts", type="dataset")
108
  artifact.add_file(path_to_save)
109
  wandb.log_artifact(artifact)
110
 
111
  # create wandb table
112
+ df["questions"] = df["questions"].apply(lambda x: "\n".join(x))
113
  table = wandb.Table(dataframe=df)
114
+ wandb.log({"summarized_que_podcasts": table})
115
 
116
  WandbTracer.finish()
src/summarize.py CHANGED
@@ -109,7 +109,7 @@ if __name__ == "__main__":
109
 
110
  # save data
111
  path_to_save = os.path.join(config.root_data_dir, "summarized_podcasts.csv")
112
- df.to_csv(path_to_save)
113
 
114
  # log to wandb artifact
115
  artifact = wandb.Artifact("summarized_podcasts", type="dataset")
 
109
 
110
  # save data
111
  path_to_save = os.path.join(config.root_data_dir, "summarized_podcasts.csv")
112
+ df.to_csv(path_to_save, index=False)
113
 
114
  # log to wandb artifact
115
  artifact = wandb.Artifact("summarized_podcasts", type="dataset")