ncoop57 and reshinthadith committed
Commit 3e4a220
1 Parent(s): cf5eed6

Fix dataset loading bug (#1)


- Fix dataset loading bug (675e604cfa3ab96c24d910dc12094232cffb2db3)
- Update app.py (f9bf4f822308bdbb5c5906ba49a28d817522d6fe)


Co-authored-by: reshinth.adith <reshinthadith@users.noreply.huggingface.co>
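
In short: the previous revision left every load_dataset call commented out and filled dataset_data with random placeholder arrays; this commit loads each pile-v2-small subset for real, pointing data_dir at the subset's data.json file, and builds dataset_data from the resulting train splits. A minimal sketch of the new loading pattern for a single subset (a sketch only, assuming the data/<subset>/data.json layout shown in the diff below):

    from datasets import load_dataset

    # Load one subset of CarperAI/pile-v2-small; the commit targets the
    # subset's data.json file rather than the bare directory.
    ai4code_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/AI4Code/data.json")
    print(ai4code_ds["train"])  # the split the app indexes into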

Files changed (1)
  1. app.py +205 -177
app.py CHANGED
@@ -2,186 +2,214 @@ import gradio as gr
 import matplotlib.pyplot as plt
 import numpy as np
 from functools import partial
+import datasets
+from datasets import load_dataset
+
 
-# ai4code_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/AI4Code")
-# amps_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/AMPS")
-# apache_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/ASFPublicMail")
-# books3_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Books3")
-# cp_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/CPDataset")
-# dmmath_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/DMMath")
-# discourse_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Discourse")
-# wiki_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Enwiki")
-# euro_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/EuroParliamentProceedings")
-# freelaw_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/FreeLaw_Options")
-# ghdiffs_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/GitHubDiff")
-# ghissues_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/GitHubIssues")
-# gutenberg_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Gutenberg")
-# leet_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/LeetCode")
-# pileoflaw_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/PileOfLaw")
-# pubmed_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/PubMed")
-# s2orc_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/S2ORC")
-# se_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/StackExchange")
-# usenet_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/USENET")
-# uspto_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/USPTO")
-# ubuntuirc_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/UbuntuIRC")
-# arxiv_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/arXiv")
+ai4code_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/AI4Code/data.json")
+amps_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/AMPS/data.json")
+apache_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/ASFPublicMail/data.json")
+books3_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Books3/data.json")
+cp_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/CPDataset/data.json")
+dmmath_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/DMMath/data.json")
+discourse_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Discourse/data.json")
+wiki_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Enwiki/data.json")
+euro_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/EuroParliamentProceedings/data.json")
+freelaw_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/FreeLaw_Options/data.json")
+ghdiffs_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/GitHubDiff/data.json")
+ghissues_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/GitHubIssues/data.json")
+gutenberg_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/Gutenberg/data.json")
+leet_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/LeetCode/data.json")
+pileoflaw_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/PileOfLaw/data.json")
+pubmed_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/PubMed/data.json")
+s2orc_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/S2ORC/data.json")
+se_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/StackExchange/data.json")
+usenet_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/USENET/data.json")
+uspto_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/USPTO/data.json")
+ubuntuirc_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/UbuntuIRC/data.json")
+arxiv_ds = load_dataset("CarperAI/pile-v2-small", data_dir="data/arXiv/data.json")
 
 dataset_data = {
-    "AI4Code": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "AMPS": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "ASFPublicMail": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "Books3": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "CPDataset": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "DMMath": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "Discourse": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "Enwiki": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "EuroParliamentProceedings": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "FreeLaw_Options": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "GitHubDiff": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "GitHubIssues": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "Gutenberg": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "LeetCode": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "PileOfLaw": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "PubMed": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "S2ORC": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "StackExchange": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "USENET": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "USPTO": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "UbuntuIRC": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-    "arXiv": {
-        # create fake data for the different ratios
-        "word_rep_ratios": np.random.randn(1000),
-        "char_rep_ratios": np.random.randn(1000),
-        "flagged_word_ratios": np.random.randn(1000),
-        "num_words": np.random.randint(0, 1000, 1000),
-    },
-}
+    "ai4code": ai4code_ds["train"],
+    "amps": amps_ds["train"],
+    "apache": apache_ds["train"],
+    "books3": books3_ds["train"],
+    "competitive_programming": cp_ds["train"],
+    "dmmath": dmmath_ds["train"],
+    "discourse": discourse_ds["train"],
+    "enwiki": wiki_ds["train"],
+    "euro": euro_ds["train"],
+    "freelaw": freelaw_ds["train"],
+    "ghdiffs": ghdiffs_ds["train"],
+    "ghissues": ghissues_ds["train"],
+    "gutenberg": gutenberg_ds["train"],
+    "leetcode": leet_ds["train"],
+    "pileoflaw": pileoflaw_ds["train"],
+    "pubmed": pubmed_ds["train"],
+    "s2orc": s2orc_ds["train"],
+    "se": se_ds["train"],
+    "usenet": usenet_ds["train"],
+    "uspto": uspto_ds["train"],
+    "ubuntuirc": ubuntuirc_ds["train"],
+    "arxiv": arxiv_ds["train"]
+}
+
+# dataset_data = {
+#     "AI4Code": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "AMPS": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "ASFPublicMail": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "Books3": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "CPDataset": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "DMMath": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "Discourse": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "Enwiki": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "EuroParliamentProceedings": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "FreeLaw_Options": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "GitHubDiff": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "GitHubIssues": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "Gutenberg": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "LeetCode": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "PileOfLaw": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "PubMed": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "S2ORC": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "StackExchange": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "USENET": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "USPTO": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "UbuntuIRC": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+#     "arXiv": {
+#         # create fake data for the different ratios
+#         "word_rep_ratios": np.random.randn(1000),
+#         "char_rep_ratios": np.random.randn(1000),
+#         "flagged_word_ratios": np.random.randn(1000),
+#         "num_words": np.random.randint(0, 1000, 1000),
+#     },
+# }
 
 def plt_plot(ratio, dataset, threshold):
     x = dataset_data[dataset][ratio]
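
Since all twenty-two loader calls in the new app.py differ only in the subset name, they could be collapsed into a single mapping. A sketch of that consolidation (assuming, as the diff does, that every subset lives at data/<name>/data.json in CarperAI/pile-v2-small and exposes a train split):

    from datasets import load_dataset

    # key used by the app -> directory name under data/ in the repo
    SUBSETS = {
        "ai4code": "AI4Code", "amps": "AMPS", "apache": "ASFPublicMail",
        "books3": "Books3", "competitive_programming": "CPDataset",
        "dmmath": "DMMath", "discourse": "Discourse", "enwiki": "Enwiki",
        "euro": "EuroParliamentProceedings", "freelaw": "FreeLaw_Options",
        "ghdiffs": "GitHubDiff", "ghissues": "GitHubIssues",
        "gutenberg": "Gutenberg", "leetcode": "LeetCode",
        "pileoflaw": "PileOfLaw", "pubmed": "PubMed", "s2orc": "S2ORC",
        "se": "StackExchange", "usenet": "USENET", "uspto": "USPTO",
        "ubuntuirc": "UbuntuIRC", "arxiv": "arXiv",
    }

    dataset_data = {
        key: load_dataset("CarperAI/pile-v2-small", data_dir=f"data/{name}/data.json")["train"]
        for key, name in SUBSETS.items()
    }

This keeps the keys that plt_plot looks up in dataset_data identical to the committed version while removing the per-subset boilerplate.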