noahsettersten committed on
Commit
36e17a1
1 Parent(s): b3d858d

feat: Create mix task to precompute label vectors

Browse files
Files changed (4) hide show
  1. README.md +1 -0
  2. lib/mix/build_code_vectors.ex +84 -0
  3. mix.exs +1 -0
  4. mix.lock +1 -1
README.md CHANGED
@@ -12,6 +12,7 @@
12
  To start your Phoenix server:
13
 
14
  * Run `mix setup` to install and setup dependencies
 
15
  * Start Phoenix endpoint with `mix phx.server` or inside IEx with `iex -S mix phx.server`
16
 
17
  Now you can visit [`localhost:4000`](http://localhost:4000) from your browser.
 
12
  To start your Phoenix server:
13
 
14
  * Run `mix setup` to install and setup dependencies
15
+ * Run `mix build_code_vectors` to download the ICD-9 codelist, precompute vectors, and store the results in the database.
16
  * Start Phoenix endpoint with `mix phx.server` or inside IEx with `iex -S mix phx.server`
17
 
18
  Now you can visit [`localhost:4000`](http://localhost:4000) from your browser.
lib/mix/build_code_vectors.ex ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
defmodule Mix.Tasks.BuildCodeVectors do
  @moduledoc """
  Populate database with vector embeddings from downloaded ICD-9 code list.

  Running `mix build_code_vectors`:

    1. starts the application,
    2. makes sure the ICD-9 code list CSV exists in the `AudioTagger.SampleData`
       cache directory, calling `AudioTagger.SampleData.get_icd9_code_list_csv/0`
       when it is missing,
    3. embeds the `long_description` of every row that has a code and hands the
       result to `MedicalTranscription.CodeVector.insert_vector/1`.

  Codes for which `MedicalTranscription.CodeVector.exists_for_code?/1` already
  returns a truthy value are skipped, so the task can be re-run without
  recomputing vectors that are already stored.
  """

  use Mix.Task
  alias MedicalTranscription.CodeVector

  @shortdoc "Downloads the ICD-9 codelist, calculates vector embeddings for each, and adds them to the database"

  @impl Mix.Task
  def run(_args) do
    # Start the application first so that its services (presumably the Repo used
    # by CodeVector) are available to the task.
    Mix.Task.run("app.start")

    # NOTE(review): presumably raises the level to keep debug output from
    # interleaving with the progress bar - confirm.
    Logger.configure(level: :info)

    if File.exists?(csv_file()) do
      IO.puts("CSV file found. Precomputing vectors...")
    else
      IO.puts("CSV file not found. Downloading and preparing...")
      AudioTagger.SampleData.get_icd9_code_list_csv()

      IO.puts("Precomputing vectors...")
    end

    precompute_vectors()

    :ok
  end

  # Loads the code list, stores a vector for every code that does not have one
  # yet, and prints the elapsed wall-clock time in milliseconds.
  defp precompute_vectors do
    time_start = System.monotonic_time()

    # The rows are materialised into a list (instead of being streamed) so the
    # progress bar total reflects the rows that are actually processed rather
    # than the raw row count of the CSV.
    rows =
      load_dataframe_from_csv()
      |> Explorer.DataFrame.to_rows_stream()
      |> Enum.filter(&code_present?/1)

    store_missing_vectors(rows)

    elapsed = System.monotonic_time() - time_start
    IO.puts("Finished in #{System.convert_time_unit(elapsed, :native, :millisecond)}ms")
  end

  # A row is only usable when its code is a non-empty string. The `is_binary/1`
  # check matters because an empty CSV cell may come back as `nil`, on which
  # `String.length/1` would raise.
  defp code_present?(%{"code" => code}), do: is_binary(code) and code != ""

  # Nothing to embed: skip preparing the model and skip the progress bar, whose
  # percentage cannot be computed for a total of zero.
  defp store_missing_vectors([]) do
    IO.puts("No codes found in CSV file. Nothing to precompute.")
  end

  # Embeds and inserts every row whose code is not in the database yet,
  # advancing the progress bar once per row (skipped rows included).
  defp store_missing_vectors(rows) do
    model_tuple = AudioTagger.Classifier.SemanticSearch.prepare_model()

    total = length(rows)
    ProgressBar.render(0, total, suffix: :count)

    rows
    |> Enum.with_index(1)
    |> Enum.each(fn {%{"code" => code, "long_description" => description}, position} ->
      if !CodeVector.exists_for_code?(code) do
        compute_vector_for_code(model_tuple, code, description)
      end

      # The last iteration renders `total` of `total`, which completes the bar;
      # no extra render is needed after the loop.
      ProgressBar.render(position, total, suffix: :count)
    end)
  end

  # Reads the cached CSV into a dataframe, forcing both columns used by this
  # task to strings. Raises if the file cannot be read or parsed.
  defp load_dataframe_from_csv do
    Explorer.DataFrame.from_csv!(
      csv_file(),
      dtypes: [
        {"code", :string},
        {"long_description", :string}
      ]
    )
  end

  # Embeds a single description with the prepared model and inserts the code,
  # its description, and the flattened `pooled_state` values via CodeVector.
  defp compute_vector_for_code({model_info, tokenizer}, code, description) do
    vector =
      AudioTagger.Vectors.embed_with_model(model_info, tokenizer, [description])

    # Flatten the tensor to a plain list of numbers before handing it to
    # CodeVector.insert_vector/1.
    vector_for_db = Nx.to_flat_list(vector.pooled_state)

    CodeVector.insert_vector(%{
      code: code,
      description: description,
      description_vector: vector_for_db
    })
  end

  # Path of the ICD-9 code list inside the AudioTagger sample-data cache directory.
  defp csv_file do
    AudioTagger.SampleData.cache_dir()
    |> Path.join("icd9_codelist.csv")
  end
end
mix.exs CHANGED
@@ -54,6 +54,7 @@ defmodule MedicalTranscription.MixProject do
54
  {:plug_cowboy, "~> 2.5"},
55
  {:credo, "~> 1.7.3"},
56
  {:audio_tagger, git: "https://github.com/headwayio/audio_tagger.git"},
 
57
  {:membrane_core, "~> 1.0"},
58
  {:membrane_raw_audio_format, "~> 0.12.0"}
59
  # {:membrane_portaudio_plugin, "~> 0.18.0"},
 
54
  {:plug_cowboy, "~> 2.5"},
55
  {:credo, "~> 1.7.3"},
56
  {:audio_tagger, git: "https://github.com/headwayio/audio_tagger.git"},
57
+ {:progress_bar, "~> 3.0"},
58
  {:membrane_core, "~> 1.0"},
59
  {:membrane_raw_audio_format, "~> 0.12.0"}
60
  # {:membrane_portaudio_plugin, "~> 0.18.0"},
mix.lock CHANGED
@@ -1,5 +1,5 @@
1
  %{
2
- "audio_tagger": {:git, "https://github.com/headwayio/audio_tagger.git", "b960515109d6249a792f37aa07452373d4d8bdd1", []},
3
  "aws_signature": {:hex, :aws_signature, "0.3.1", "67f369094cbd55ffa2bbd8cc713ede14b195fcfb45c86665cd7c5ad010276148", [:rebar3], [], "hexpm", "50fc4dc1d1f7c2d0a8c63f455b3c66ecd74c1cf4c915c768a636f9227704a674"},
4
  "axon": {:hex, :axon, "0.6.0", "fd7560079581e4cedebaf0cd5f741d6ac3516d06f204ebaf1283b1093bf66ff6", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "204e7aeb50d231a30b25456adf17bfbaae33fe7c085e03793357ac3bf62fd853"},
5
  "bimap": {:hex, :bimap, "1.3.0", "3ea4832e58dc83a9b5b407c6731e7bae87458aa618e6d11d8e12114a17afa4b3", [:mix], [], "hexpm", "bf5a2b078528465aa705f405a5c638becd63e41d280ada41e0f77e6d255a10b4"},
 
1
  %{
2
+ "audio_tagger": {:git, "https://github.com/headwayio/audio_tagger.git", "dc02cf6990ab5fef0100269c03faa6e2dbdbaea7", []},
3
  "aws_signature": {:hex, :aws_signature, "0.3.1", "67f369094cbd55ffa2bbd8cc713ede14b195fcfb45c86665cd7c5ad010276148", [:rebar3], [], "hexpm", "50fc4dc1d1f7c2d0a8c63f455b3c66ecd74c1cf4c915c768a636f9227704a674"},
4
  "axon": {:hex, :axon, "0.6.0", "fd7560079581e4cedebaf0cd5f741d6ac3516d06f204ebaf1283b1093bf66ff6", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: true]}, {:kino_vega_lite, "~> 0.1.7", [hex: :kino_vega_lite, repo: "hexpm", optional: true]}, {:nx, "~> 0.6.0", [hex: :nx, repo: "hexpm", optional: false]}, {:polaris, "~> 0.1", [hex: :polaris, repo: "hexpm", optional: false]}, {:table_rex, "~> 3.1.1", [hex: :table_rex, repo: "hexpm", optional: true]}], "hexpm", "204e7aeb50d231a30b25456adf17bfbaae33fe7c085e03793357ac3bf62fd853"},
5
  "bimap": {:hex, :bimap, "1.3.0", "3ea4832e58dc83a9b5b407c6731e7bae87458aa618e6d11d8e12114a17afa4b3", [:mix], [], "hexpm", "bf5a2b078528465aa705f405a5c638becd63e41d280ada41e0f77e6d255a10b4"},