juancopi81 commited on
Commit
bc718b3
1 Parent(s): 520354b

Change way to define inference model

Browse files
Files changed (1) hide show
  1. app.py +9 -16
app.py CHANGED
@@ -24,20 +24,12 @@ us["musescoreDirectPNGPath"] = "/usr/bin/mscore3"
24
  os.putenv("QT_QPA_PLATFORM", "offscreen")
25
  os.putenv("XDG_RUNTIME_DIR", environment.Environment().getRootTempDir())
26
 
27
- # Start inference model
28
- inference_model = InferenceModel("/home/user/app/checkpoints/mt3/", "mt3")
29
- current_model = "ismir2021"
30
-
31
- def change_model(model):
32
- global current_model
33
- global inference_model
34
- print("Inferece model", inference_model)
35
- print("Current model", current_model)
36
- checkpoint_path = f"/home/user/app/checkpoints/{model}/"
37
- if model == current_model:
38
- return
39
  inference_model = InferenceModel(checkpoint_path, model)
40
- current_model = model
 
41
 
42
  # Credits https://huggingface.co/spaces/rajesh1729/youtube-video-transcription-with-whisper
43
  def get_audio(url):
@@ -56,12 +48,14 @@ def populate_metadata(link):
56
  audio = get_audio(link)
57
  return yt.thumbnail_url, yt.title, audio, audio
58
 
59
- def inference(yt_audio_path):
60
 
61
  with open(yt_audio_path, 'rb') as fd:
62
  contents = fd.read()
63
 
64
  audio = upload_audio(contents,sample_rate=SAMPLE_RATE)
 
 
65
 
66
  est_ns = inference_model(audio)
67
 
@@ -104,7 +98,6 @@ with demo:
104
  label=model_label,
105
  value="mt3"
106
  )
107
- model.change(fn=change_model, inputs=model, outputs=[])
108
 
109
  with gr.Row():
110
  link = gr.Textbox(label="YouTube Link")
@@ -120,7 +113,7 @@ with demo:
120
  yt_audio_path = gr.Textbox(visible=False)
121
 
122
  preview_btn.click(fn=populate_metadata,
123
- inputs=[link],
124
  outputs=[img, title, yt_audio, yt_audio_path])
125
 
126
  with gr.Row():
 
24
  os.putenv("QT_QPA_PLATFORM", "offscreen")
25
  os.putenv("XDG_RUNTIME_DIR", environment.Environment().getRootTempDir())
26
 
27
+ def load_model(model=str):
28
+ checkpoint_path = f'/content/checkpoints/{model}/'
29
+ # Start inference model
 
 
 
 
 
 
 
 
 
30
  inference_model = InferenceModel(checkpoint_path, model)
31
+ return inference_model
32
+
33
 
34
  # Credits https://huggingface.co/spaces/rajesh1729/youtube-video-transcription-with-whisper
35
  def get_audio(url):
 
48
  audio = get_audio(link)
49
  return yt.thumbnail_url, yt.title, audio, audio
50
 
51
+ def inference(yt_audio_path, model):
52
 
53
  with open(yt_audio_path, 'rb') as fd:
54
  contents = fd.read()
55
 
56
  audio = upload_audio(contents,sample_rate=SAMPLE_RATE)
57
+
58
+ inference_model = load_model(model)
59
 
60
  est_ns = inference_model(audio)
61
 
 
98
  label=model_label,
99
  value="mt3"
100
  )
 
101
 
102
  with gr.Row():
103
  link = gr.Textbox(label="YouTube Link")
 
113
  yt_audio_path = gr.Textbox(visible=False)
114
 
115
  preview_btn.click(fn=populate_metadata,
116
+ inputs=[link, model],
117
  outputs=[img, title, yt_audio, yt_audio_path])
118
 
119
  with gr.Row():