Commit
•
d8e387b
1
Parent(s):
5aee375
feat: Notebook for training a model
Browse files — This notebook assumes that two files exist in your Livebook: fraudTest.csv and fraudTrain.csv. Both can be found on Kaggle.com in the "Credit Card Transactions Fraud Detection Dataset".
- livebooks/training.livemd +190 -0
livebooks/training.livemd
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!-- livebook:{"file_entries":[{"name":"fraudTest.csv","type":"attachment"},{"name":"fraudTrain.csv","type":"attachment"}]} -->
|
2 |
+
|
3 |
+
# Training
|
4 |
+
|
5 |
+
```elixir
|
6 |
+
Mix.install(
|
7 |
+
[
|
8 |
+
{:kino_bumblebee, "~> 0.4.0"},
|
9 |
+
{:exla, ">= 0.0.0"},
|
10 |
+
{:kino, "~> 0.11.0"},
|
11 |
+
{:kino_explorer, "~> 0.1.11"}
|
12 |
+
],
|
13 |
+
config: [nx: [default_backend: EXLA.Backend]]
|
14 |
+
)
|
15 |
+
```
|
16 |
+
|
17 |
+
## Fine-tuning a merchant category classifier
|
18 |
+
|
19 |
+
```elixir
|
20 |
+
# {:ok, spec} = Bumblebee.load_spec({:hf, ""})
|
21 |
+
```
|
22 |
+
|
23 |
+
```elixir
|
24 |
+
# Load the training split. Only the merchant name (model input) and its
# category (target label) are used further down, so `:columns` restricts CSV
# parsing to those two columns instead of reading every column and then
# discarding most of them.
training_df =
  "fraudTrain.csv"
  |> Kino.FS.file_path()
  |> Explorer.DataFrame.from_csv!(columns: ["merchant", "category"])
|
28 |
+
```
|
29 |
+
|
30 |
+
```elixir
|
31 |
+
# Load the held-out split. Only the merchant name (model input) and its
# category (target label) are used further down, so `:columns` restricts CSV
# parsing to those two columns instead of reading every column and then
# discarding most of them.
test_df =
  "fraudTest.csv"
  |> Kino.FS.file_path()
  |> Explorer.DataFrame.from_csv!(columns: ["merchant", "category"])
|
35 |
+
```
|
36 |
+
|
37 |
+
```elixir
|
38 |
+
# Distinct category names found in the training data, as a plain list.
# The position of a name in this list is later used as its class id.
labels =
  training_df
  |> Explorer.DataFrame.distinct(["category"])
  |> Explorer.DataFrame.pull("category")
  |> Explorer.Series.to_list()
|
44 |
+
```
|
45 |
+
|
46 |
+
```elixir
|
47 |
+
# Hugging Face repository that provides the spec, the weights and the tokenizer.
model_name = "facebook/bart-large-mnli"

{:ok, spec} = Bumblebee.load_spec({:hf, model_name}, architecture: :for_sequence_classification)

num_labels = length(labels)

# Class id (0-based position in `labels`) => category name.
id_to_label =
  labels
  |> Enum.with_index()
  |> Map.new(fn {label, id} -> {id, label} end)

# Configure the spec with one label per category before loading the model.
spec = Bumblebee.configure(spec, num_labels: num_labels, id_to_label: id_to_label)

{:ok, model_info} = Bumblebee.load_model({:hf, model_name}, spec: spec)
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, model_name})

# serving =
#   Bumblebee.Text.zero_shot_classification(model_info, tokenizer, labels,
#     compile: [batch_size: 1, sequence_length: 100],
#     defn_options: [compiler: EXLA]
#   )
|
72 |
+
```
|
73 |
+
|
74 |
+
```elixir
|
75 |
+
defmodule Finance do
  @moduledoc """
  Turns a dataframe with `"merchant"` and `"category"` columns into a lazy
  stream of tokenized batches paired with integer label ids.
  """

  @doc """
  Builds a lazy stream of `{tokenized_inputs, label_ids}` batches from `df`.

  ## Options

    * `:batch_size` - number of rows per batch (required).
    * `:sequence_length` - forwarded to the tokenizer as its `:length` option.
    * `:id_to_label` - map of integer class id to category name (required).

  Raises `KeyError` when a required option is missing.
  """
  def load(df, tokenizer, opts \\ []) do
    df
    |> stream()
    |> tokenize_and_batch(
      tokenizer,
      # Fail with a KeyError naming the missing option rather than with an
      # opaque error deep inside Stream/Map when a required option is absent.
      Keyword.fetch!(opts, :batch_size),
      opts[:sequence_length],
      Keyword.fetch!(opts, :id_to_label)
    )
  end

  @doc """
  Lazily pairs each value of `df["merchant"]` with the value of
  `df["category"]` at the same position, yielding `{merchant, category}` tuples.
  """
  def stream(df) do
    merchants = Explorer.Series.to_enum(df["merchant"])
    categories = Explorer.Series.to_enum(df["category"])

    Stream.zip(merchants, categories)
  end

  @doc """
  Chunks a stream of `{text, label}` tuples into batches of `batch_size`,
  tokenizes the texts and converts the labels to their ids.

  Each emitted element is `{tokenized, label_ids}`, where `tokenized` is the
  result of `Bumblebee.apply_tokenizer/3` for the batch texts and `label_ids`
  is the result of `Nx.stack/1` over the batch's integer ids.

  Raises `KeyError` (lazily, when the batch is consumed) if a label is not
  among the values of `id_to_label`.
  """
  def tokenize_and_batch(stream, tokenizer, batch_size, sequence_length, id_to_label) do
    # Invert the mapping once, outside the per-batch closure. Looking ids up by
    # key instead of by position in `Map.values/1` keeps them correct even when
    # the map does not iterate in ascending id order, and replaces a linear
    # scan per row with a constant-time lookup.
    label_to_id = Map.new(id_to_label, fn {id, label} -> {label, id} end)

    stream
    |> Stream.chunk_every(batch_size)
    |> Stream.map(fn batch ->
      {texts, labels} = Enum.unzip(batch)

      # Map.fetch!/2 surfaces an unknown category immediately instead of
      # letting a nil id reach Nx.stack/1.
      label_ids = Enum.map(labels, &Map.fetch!(label_to_id, &1))

      tokenized = Bumblebee.apply_tokenizer(tokenizer, texts, length: sequence_length)
      {tokenized, Nx.stack(label_ids)}
    end)
  end
end
|
114 |
+
```
|
115 |
+
|
116 |
+
```elixir
|
117 |
+
# Batching configuration shared by the training and the test streams.
batch_size = 32
sequence_length = 64

# Lazy stream of tokenized training batches.
train_data =
  Finance.load(training_df, tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length,
    id_to_label: id_to_label
  )

# Lazy stream of tokenized test batches, built with the same settings.
test_data =
  Finance.load(test_df, tokenizer,
    batch_size: batch_size,
    sequence_length: sequence_length,
    id_to_label: id_to_label
  )
|
135 |
+
```
|
136 |
+
|
137 |
+
```elixir
|
138 |
+
# Evaluate the lazy batch streams and rebind them to lists holding at most the
# first 250 training batches and the first 50 test batches. The trailing `:ok`
# keeps the cell from rendering the batches as its output.
train_data = Enum.take(train_data, 250)
|
139 |
+
test_data = Enum.take(test_data, 50)
|
140 |
+
:ok
|
141 |
+
```
|
142 |
+
|
143 |
+
```elixir
|
144 |
+
# Pull the model and its parameters out of the loaded model info.
model = model_info.model
params = model_info.params

# Last expression of the cell: show the model.
model
|
147 |
+
```
|
148 |
+
|
149 |
+
```elixir
|
150 |
+
# Use the tokenized inputs of the first training batch as a template to
# inspect the model's output shape.
{input, _labels} = Enum.at(train_data, 0)
Axon.get_output_shape(model, input)
|
152 |
+
```
|
153 |
+
|
154 |
+
```elixir
|
155 |
+
# Wrap the model so that its output is only the `logits` field of the
# original model output.
logits_model = Axon.nx(model, & &1.logits)
|
156 |
+
```
|
157 |
+
|
158 |
+
```elixir
|
159 |
+
# Mean categorical cross-entropy computed from logits against sparse
# (integer id) targets.
loss = fn y_true, y_pred ->
  Axon.Losses.categorical_cross_entropy(y_true, y_pred,
    reduction: :mean,
    from_logits: true,
    sparse: true
  )
end

optimizer = Polaris.Optimizers.adam(learning_rate: 5.0e-5)

# Supervised training loop over the logits-only model.
loop = Axon.Loop.trainer(logits_model, loss, optimizer, log: 1)
|
169 |
+
```
|
170 |
+
|
171 |
+
```elixir
|
172 |
+
# Accuracy computed from logits against sparse (integer id) targets.
accuracy = fn y_true, y_pred ->
  Axon.Metrics.accuracy(y_true, y_pred, from_logits: true, sparse: true)
end

# Track the metric on the loop under the name "accuracy".
loop = Axon.Loop.metric(loop, accuracy, "accuracy")
|
175 |
+
```
|
176 |
+
|
177 |
+
```elixir
|
178 |
+
# Attach a checkpoint handler to the loop for the `:epoch_completed` event.
loop = Axon.Loop.checkpoint(loop, event: :epoch_completed)
|
179 |
+
```
|
180 |
+
|
181 |
+
```elixir
|
182 |
+
# Run the `loop` assembled in the preceding cells (trainer + accuracy metric +
# per-epoch checkpoint) rather than rebuilding an identical pipeline here, so
# the two definitions cannot drift apart. Training starts from the loaded
# `params` and runs for 3 epochs with the EXLA compiler.
trained_model_state =
  Axon.Loop.run(loop, train_data, params, epochs: 3, compiler: EXLA, strict?: false)

# Keep the cell from rendering the trained state as its output.
:ok
|
190 |
+
```
|