timgremore committed on
Commit
d8e387b
1 Parent(s): 5aee375

feat: Notebook for training a model

Browse files

This notebook assumes 2 files exist in your Livebook: fraudTest.csv and fraudTrain.csv. These files can be found in the "Credit Card Transactions Fraud Detection Dataset" on Kaggle.com.

Files changed (1) hide show
  1. livebooks/training.livemd +190 -0
livebooks/training.livemd ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- livebook:{"file_entries":[{"name":"fraudTest.csv","type":"attachment"},{"name":"fraudTrain.csv","type":"attachment"}]} -->
2
+
3
+ # Training
4
+
5
+ ```elixir
6
+ Mix.install(
7
+ [
8
+ {:kino_bumblebee, "~> 0.4.0"},
9
+ {:exla, ">= 0.0.0"},
10
+ {:kino, "~> 0.11.0"},
11
+ {:kino_explorer, "~> 0.1.11"}
12
+ ],
13
+ config: [nx: [default_backend: EXLA.Backend]]
14
+ )
15
+ ```
16
+
17
+ ## Section
18
+
19
+ ```elixir
20
+ # {:ok, spec} = Bumblebee.load_spec({:hf, ""})
21
+ ```
22
+
23
+ ```elixir
24
+ training_df =
25
+ Kino.FS.file_path("fraudTrain.csv")
26
+ |> Explorer.DataFrame.from_csv!()
27
+ |> Explorer.DataFrame.select(["merchant", "category"])
28
+ ```
29
+
30
+ ```elixir
31
+ test_df =
32
+ Kino.FS.file_path("fraudTest.csv")
33
+ |> Explorer.DataFrame.from_csv!()
34
+ |> Explorer.DataFrame.select(["merchant", "category"])
35
+ ```
36
+
37
+ ```elixir
38
+ labels =
39
+ training_df
40
+ |> Explorer.DataFrame.distinct(["category"])
41
+ |> Explorer.DataFrame.to_series()
42
+ |> Map.get("category")
43
+ |> Explorer.Series.to_list()
44
+ ```
45
+
46
+ ```elixir
47
+ model_name = "facebook/bart-large-mnli"
48
+
49
+ {:ok, spec} =
50
+ Bumblebee.load_spec({:hf, model_name},
51
+ architecture: :for_sequence_classification
52
+ )
53
+
54
+ num_labels = Enum.count(labels)
55
+
56
+ id_to_label =
57
+ labels
58
+ |> Enum.with_index(fn item, index -> {index, item} end)
59
+ |> Enum.into(%{})
60
+
61
+ spec =
62
+ Bumblebee.configure(spec, num_labels: num_labels, id_to_label: id_to_label)
63
+
64
+ {:ok, model_info} = Bumblebee.load_model({:hf, model_name}, spec: spec)
65
+ {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})
66
+
67
+ # serving =
68
+ # Bumblebee.Text.zero_shot_classification(model_info, tokenizer, labels,
69
+ # compile: [batch_size: 1, sequence_length: 100],
70
+ # defn_options: [compiler: EXLA]
71
+ # )
72
+ ```
73
+
74
+ ```elixir
75
+ defmodule Finance do
76
+ def load(df, tokenizer, opts \\ []) do
77
+ df
78
+ |> stream()
79
+ |> tokenize_and_batch(
80
+ tokenizer,
81
+ opts[:batch_size],
82
+ opts[:sequence_length],
83
+ opts[:id_to_label]
84
+ )
85
+ end
86
+
87
+ def stream(df) do
88
+ xs = df["merchant"]
89
+ ys = df["category"]
90
+
91
+ xs
92
+ |> Explorer.Series.to_enum()
93
+ |> Stream.zip(Explorer.Series.to_enum(ys))
94
+ end
95
+
96
+ def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length, id_to_label) do
97
+ stream
98
+ |> Stream.chunk_every(batch_size)
99
+ |> Stream.map(fn batch ->
100
+ {text, labels} = Enum.unzip(batch)
101
+
102
+ id_to_label_values = id_to_label |> Map.values()
103
+
104
+ label_ids =
105
+ Enum.map(labels, fn item ->
106
+ Enum.find_index(id_to_label_values, fn label_value -> label_value == item end)
107
+ end)
108
+
109
+ tokenized = Bumblebee.apply_tokenizer(tokenizer, text, length: sequence_length)
110
+ {tokenized, Nx.stack(label_ids)}
111
+ end)
112
+ end
113
+ end
114
+ ```
115
+
116
+ ```elixir
117
+ batch_size = 32
118
+ sequence_length = 64
119
+
120
+ train_data =
121
+ training_df
122
+ |> Finance.load(tokenizer,
123
+ batch_size: batch_size,
124
+ sequence_length: sequence_length,
125
+ id_to_label: id_to_label
126
+ )
127
+
128
+ test_data =
129
+ test_df
130
+ |> Finance.load(tokenizer,
131
+ batch_size: batch_size,
132
+ sequence_length: sequence_length,
133
+ id_to_label: id_to_label
134
+ )
135
+ ```
136
+
137
+ ```elixir
138
+ train_data = Enum.take(train_data, 250)
139
+ test_data = Enum.take(test_data, 50)
140
+ :ok
141
+ ```
142
+
143
+ ```elixir
144
+ %{model: model, params: params} = model_info
145
+
146
+ model
147
+ ```
148
+
149
+ ```elixir
150
+ [{input, _}] = Enum.take(train_data, 1)
151
+ Axon.get_output_shape(model, input)
152
+ ```
153
+
154
+ ```elixir
155
+ logits_model = Axon.nx(model, & &1.logits)
156
+ ```
157
+
158
+ ```elixir
159
+ loss =
160
+ &Axon.Losses.categorical_cross_entropy(&1, &2,
161
+ reduction: :mean,
162
+ from_logits: true,
163
+ sparse: true
164
+ )
165
+
166
+ optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)
167
+
168
+ loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
169
+ ```
170
+
171
+ ```elixir
172
+ accuracy = &Axon.Metrics.accuracy(&1, &2, from_logits: true, sparse: true)
173
+
174
+ loop = Axon.Loop.metric(loop, accuracy, "accuracy")
175
+ ```
176
+
177
+ ```elixir
178
+ loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
179
+ ```
180
+
181
+ ```elixir
182
+ trained_model_state =
183
+ logits_model
184
+ |> Axon.Loop.trainer(loss, optimizer, log: 1)
185
+ |> Axon.Loop.metric(accuracy, "accuracy")
186
+ |> Axon.Loop.checkpoint(event: :epoch_completed)
187
+ |> Axon.Loop.run(train_data, params, epochs: 3, compiler: EXLA, strict?: false)
188
+
189
+ :ok
190
+ ```