asigalov61 commited on
Commit
0304307
1 Parent(s): 9d6cd5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -49
app.py CHANGED
@@ -4,12 +4,8 @@ import time as reqtime
4
  import datetime
5
  from pytz import timezone
6
 
7
- import torch
8
-
9
- import spaces
10
  import gradio as gr
11
 
12
- from x_transformer_1_23_2 import *
13
  import random
14
  import tqdm
15
 
@@ -22,48 +18,11 @@ in_space = os.getenv("SYSTEM") == "spaces"
22
 
23
  # =================================================================================================
24
 
25
- @spaces.GPU
26
  def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
27
  print('=' * 70)
28
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
29
  start_time = reqtime.time()
30
 
31
- print('Loading model...')
32
-
33
- SEQ_LEN = 8192 # Models seq len
34
- PAD_IDX = 707 # Models pad index
35
- DEVICE = 'cuda' # 'cuda'
36
-
37
- # instantiate the model
38
-
39
- model = TransformerWrapper(
40
- num_tokens = PAD_IDX+1,
41
- max_seq_len = SEQ_LEN,
42
- attn_layers = Decoder(dim = 2048, depth = 4, heads = 16, attn_flash = True)
43
- )
44
-
45
- model = AutoregressiveWrapper(model, ignore_index = PAD_IDX)
46
-
47
- model.to(DEVICE)
48
- print('=' * 70)
49
-
50
- print('Loading model checkpoint...')
51
-
52
- model.load_state_dict(
53
- torch.load('Chords_Progressions_Transformer_Small_2048_Trained_Model_12947_steps_0.9316_loss_0.7386_acc.pth',
54
- map_location=DEVICE))
55
- print('=' * 70)
56
-
57
- model.eval()
58
-
59
- if DEVICE == 'cpu':
60
- dtype = torch.bfloat16
61
- else:
62
- dtype = torch.float16
63
-
64
- ctx = torch.amp.autocast(device_type=DEVICE, dtype=dtype)
65
-
66
- print('Done!')
67
  print('=' * 70)
68
 
69
  fn = os.path.basename(input_midi.name)
@@ -363,15 +322,12 @@ if __name__ == "__main__":
363
 
364
  app = gr.Blocks()
365
  with app:
366
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Chords Progressions Transformer</h1>")
367
- gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Chords-conditioned music transformer</h1>")
368
  gr.Markdown(
369
- "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.Chords-Progressions-Transformer&style=flat)\n\n"
370
- "Generate music based on chords progressions\n\n"
371
- "Check out [Chords Progressions Transformer](https://github.com/asigalov61/Chords-Progressions-Transformer) on GitHub!\n\n"
372
- "[Open In Colab]"
373
- "(https://colab.research.google.com/github/asigalov61/Chords-Progressions-Transformer/blob/main/Chords_Progressions_Transformer.ipynb)"
374
- " for faster execution and endless generation"
375
  )
376
  gr.Markdown("## Upload your MIDI or select a sample example MIDI")
377
 
 
4
  import datetime
5
  from pytz import timezone
6
 
 
 
 
7
  import gradio as gr
8
 
 
9
  import random
10
  import tqdm
11
 
 
18
 
19
  # =================================================================================================
20
 
 
21
  def GenerateAccompaniment(input_midi, input_num_tokens, input_conditioning_type, input_strip_notes):
22
  print('=' * 70)
23
  print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
24
  start_time = reqtime.time()
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  print('=' * 70)
27
 
28
  fn = os.path.basename(input_midi.name)
 
322
 
323
  app = gr.Blocks()
324
  with app:
325
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Melody</h1>")
326
+ gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Add a unique melody to any MIDI</h1>")
327
  gr.Markdown(
328
+ "![Visitors](https://api.visitorbadge.io/api/visitors?path=asigalov61.MIDI-Melody&style=flat)\n\n"
329
+ "This is a demo for TMIDIX Python module from tegridy-tools\n\n"
330
+ "Check out [tegridy-tools](https://github.com/asigalov61/tegridy-tools) on GitHub!\n\n"
 
 
 
331
  )
332
  gr.Markdown("## Upload your MIDI or select a sample example MIDI")
333