GPT-4-To-Midi / app.py
cbg342's picture
Update app.py
168ab78
import os, random, re
from fractions import Fraction
from midiutil.MidiFile import MIDIFile
import streamlit as st
import mido, openai
# Per-session defaults. A key is only written when it is missing from
# st.session_state, so values that are already stored are left untouched.
#   path         - directory containing this script (used for temporary files)
#   sessionID    - random integer appended to uploaded file names
#   history      - list of chat messages exchanged with the model
#   downloadable - whether a generated MIDI file is ready for download
for _stateKey, _makeDefault in (
    ('path', lambda: os.path.realpath(os.path.dirname(__file__))),
    ('sessionID', lambda: random.randint(0,99999999)),
    ('history', lambda: []),
    ('downloadable', lambda: False),
):
    # The factory is called lazily so no default is computed for existing keys.
    if _stateKey not in st.session_state:
        st.session_state[_stateKey] = _makeDefault()
# Accepted spellings for each of the 12 pitch classes, indexed by semitone
# above C. The first spelling of every entry is the one used when MIDI notes
# are written back out as text.
notes = [
    ['C'],
    ['Db', 'C#'],
    ['D'],
    ['Eb', 'D#'],
    ['E'],
    ['F'],
    ['Gb', 'F#'],
    ['G'],
    ['Ab', 'G#'],
    ['A'],
    ['Bb', 'A#'],
    ['B'],
]

# Token patterns for the two notations, indexed by int(notation == 'Polyphonic'):
#   [0] monophonic  "Note-duration"       e.g. C4-1/4
#   [1] polyphonic  "Note-duration-time"  e.g. C4-1/4-2.5
# The look-arounds keep a token from matching inside a longer word or number.
monsters = [
    r'(?<![A-Za-z\d])([A-G](?:#|b)?\d-(?:\d+\/\d+|\d+))(?![A-Za-z\d])',
    r'(?<![A-Za-z\d])([A-G](?:#|b)?\d(?:-\d+(?:\/\d+)?(?:-\d+(?:\.\d+)?)?)+)(?![A-Za-z\d])',
]

# Notation samples appended to the system prompt, indexed like `monsters`.
examples = [
    '\n\nNotation looks like this:\n(Note-duration)\nC4-1/4, Eb4-1/4, D4-1/8, Eb4-1/8, C4-1/4',
    '\n\nNotation looks like this:\n(Note-duration-time in beats)\nC4-1/4-0, Eb4-1/8-2.5, D4-1/4-3, F4-1/4-3 etc.',
]
def noteToInt(n):
    """Convert a note name such as 'C4', 'Eb3' or 'F#5' to a MIDI note number.

    The last character of *n* is the (single-digit) octave; everything before
    it is the note letter with an optional '#' or 'b'. 'C4' maps to 60.

    Every letter A-G combined with either accidental is supported, including
    the enharmonic spellings 'Cb', 'E#', 'Fb' and 'B#', which previously fell
    through the lookup and were returned as C.

    Parameters:
        n: note name, e.g. 'Db4'.

    Returns:
        The MIDI note number as an int.

    Raises:
        ValueError: if the last character of *n* is not a digit.
    """
    semitones = {'C': 0, 'D': 2, 'E': 4, 'F': 5, 'G': 7, 'A': 9, 'B': 11}
    accidentals = {'': 0, '#': 1, 'b': -1}

    octave = int(n[-1])
    letter = n[:-1]
    base, accidental = letter[:1], letter[1:]

    if base in semitones and accidental in accidentals:
        pitchClass = semitones[base] + accidentals[accidental]
    else:
        # Unrecognised spellings keep the historical fallback to C so that
        # existing callers never see a new exception here.
        pitchClass = 0

    # +12 because MIDI octave numbering puts C-1 at note 0, so C0 is note 12.
    return pitchClass + octave * 12 + 12
def _ticksToDuration(durTicks, ticksPerBeat):
    """Format a length in ticks as a fraction of a whole note (4 beats), e.g. '1/4'."""
    dur = Fraction(durTicks, 4 * ticksPerBeat)
    if len(str(dur)) >= 6:
        # Unquantised lengths give unwieldy fractions. Approximate them with a
        # small denominator instead of switching to a decimal, so the result
        # stays in the same unit and can still be parsed as "Note-duration".
        approx = dur.limit_denominator(64)
        # Never let a sounding note collapse to zero length.
        dur = approx if approx > 0 else Fraction(1, 64)
    return str(dur)


def midiToStr(mPath, nIndex):
    """Convert a MIDI file to the app's comma-separated text notation.

    Parameters:
        mPath: path of the MIDI file; a leading '~' is expanded.
        nIndex: truthy for polyphonic "Note-duration-time" tokens
            (e.g. 'C4-1/4-2.5'), falsy for monophonic "Note-duration"
            tokens (e.g. 'C4-1/4').

    Returns:
        The tokens joined with ', '. Durations are fractions of a whole
        note, start times are in beats rounded to 3 decimals. Tokens are
        emitted when a note ends, track after track. Notes that are still
        sounding when their track ends are not emitted.
    """
    midIn = mido.MidiFile(os.path.expanduser(mPath))
    ticks = midIn.ticks_per_beat
    midOut = []
    for track in midIn.tracks:
        # Delta times are relative to the previous message of the same track,
        # so the clock and the set of sounding notes restart for every track.
        tick = 0
        opens = {}  # pitch -> absolute tick at which it started sounding
        for msg in track:
            # Every message advances the clock, not only note events;
            # skipping the others would make later notes drift early.
            tick += msg.time
            if msg.type != 'note_on' and msg.type != 'note_off':
                continue
            if msg.note in opens:
                startTick = opens.pop(msg.note)
                fields = [
                    notes[msg.note%12][0]+str(msg.note//12-1),
                    _ticksToDuration(tick - startTick, ticks),
                ]
                if nIndex:
                    noteTime = startTick / ticks
                    noteTime = int(noteTime) if noteTime.is_integer() else noteTime
                    fields.append(str(round(noteTime,3)))
                midOut.append('-'.join(fields))
            # A note_on with velocity 0 is a note-off and must not start a note.
            if msg.type == 'note_on' and msg.velocity > 0:
                opens[msg.note] = tick
    return ', '.join(midOut)
def _parseNotes(text, notationIndex):
    """Extract note entries from *text* written in the app's notation.

    Parameters:
        text: free text containing note tokens.
        notationIndex: 1 for polyphonic "Note-duration-time" tokens,
            0 for monophonic "Note-duration" tokens.

    Returns:
        (noteInfo, skipped). noteInfo is a list of [pitch, duration] or
        [pitch, duration, time] lists with duration and time in beats.
        skipped is the number of matched tokens that could not be used.
    """
    noteInfo = []
    skipped = 0
    for token in re.findall(monsters[notationIndex], text):
        fields = token.split('-')
        try:
            entry = [noteToInt(fields[0]), float(Fraction(fields[1])) * 4]  # note, duration
            if notationIndex:
                entry.append(float(fields[2]))  # time
        except (IndexError, ValueError, ZeroDivisionError):
            # The polyphonic pattern also matches tokens that lack a start
            # time, and a duration such as '1/0' cannot be evaluated.
            skipped += 1
            continue
        if not 0 <= entry[0] <= 127:
            # Outside the range of MIDI note numbers.
            skipped += 1
            continue
        noteInfo.append(entry)
    return noteInfo, skipped


def _textToMidi(text, notationIndex):
    """Render the notes found in *text* as a one-track MIDI file.

    Monophonic notes are laid end to end; polyphonic notes are placed at
    their own start time. All notes use track 0, channel 0, volume 100.

    Returns:
        The MIDI file contents as bytes.
    """
    from io import BytesIO  # local import keeps this block self-contained

    noteInfo, skipped = _parseNotes(text, notationIndex)
    if skipped:
        st.warning('Skipped '+str(skipped)+' note(s) that could not be parsed', icon='⚠️')

    song = MIDIFile(1, deinterleave=False)
    time = 0
    for note in noteInfo:
        if notationIndex:
            pitch, dur, time = note
        else:
            pitch, dur = note
        song.addNote(0, 0, pitch, time, dur, 100)
        if not notationIndex:
            time += dur

    # Build the file in memory: a fixed path on disk would be shared by every
    # session of the app, letting one user download another user's song.
    buffer = BytesIO()
    song.writeFile(buffer)
    return buffer.getvalue()


def _uploadToStr(upload, nIndex):
    """Convert an uploaded MIDI file to text notation via a temporary file.

    The temporary file lives in the app directory, carries the session ID in
    its name, and is removed even when the conversion raises.
    """
    stem, ext = os.path.splitext(os.path.basename(upload.name))
    filename = stem + str(st.session_state['sessionID']) + ext
    midiPath = os.path.join(st.session_state['path'], filename)
    try:
        with open(midiPath, 'wb') as f:
            f.write(upload.getbuffer())
        return midiToStr(midiPath, nIndex)
    finally:
        if os.path.exists(midiPath):
            os.remove(midiPath)


st.markdown('# GPT-4 2 Midi\n#### AI Generated Polyphonic Music\n##### plus conversion tools for use with Chat-GPT\napp by [d3nt](https://github.com/d3n7/)')

notation = st.selectbox('Notation', ('Polyphonic', 'Monophonic'))
main, m2t, t2m = st.tabs(['GPT4-To-Midi', 'Midi-2-Text', 'Text-2-Midi'])

# --- Tab 1: prompt the model and turn its reply into a MIDI file -------------
with main:
    userPrompt = st.text_input('Prompt', 'Full piece of sad music with multiple parts. Plan out the structure beforehand, including chords, parts (soprano, alto, tenor, bass), meter, etc.')
    with st.expander('System Prompt'):
        sysPrompt = st.text_input('', 'You are MusicGPT, a music creation and completion chat bot that. When a user gives you a prompt, you return them a song showing the notes, durations, and times that they occur. Respond with just the music.')
    openaikey = st.text_input('OpenAI API Key', type='password')
    modelV = st.selectbox('Model', ('GPT-4', 'GPT-3.5-Turbo'))
    col1, col2 = st.columns(2)
    with col1:
        newSession = st.checkbox('New Session', True)
    with col2:
        showOutput = st.checkbox('Show Output', True)
    uploadMidi = st.file_uploader('Upload a midi file (OPTIONAL)')
    col3, col4 = st.columns(2)
    with col3:
        if st.button('Ask GPT'):
            if userPrompt != '' and sysPrompt != '' and openaikey != '':
                notationIndex = int(notation=='Polyphonic')
                # Also start a fresh history when none exists yet, otherwise
                # the first request would go out without the system prompt.
                if newSession or not st.session_state['history']:
                    st.session_state['history'] = [{'role': 'system', 'content': sysPrompt+examples[notationIndex]}]
                prompt = userPrompt
                if uploadMidi:
                    prompt += '\n'+_uploadToStr(uploadMidi, notationIndex)
                st.session_state['history'].append({'role': 'user', 'content': prompt})
                openai.api_key = openaikey
                with st.spinner('Talking to OpenAI...'):
                    r = openai.ChatCompletion.create(
                        model=modelV.lower(),
                        messages=st.session_state['history']
                    )
                response = r['choices'][0]['message']['content']
                st.session_state['history'].append({'role': 'assistant', 'content': response})
                # Keep the rendered file in this session's state only.
                st.session_state['midiBytes'] = _textToMidi(response, notationIndex)
                st.session_state['downloadable'] = True
            else:
                st.warning('Make sure OpenAI key, prompt, and system prompt are entered', icon='⚠️')
    with col4:
        if st.session_state['downloadable'] and 'midiBytes' in st.session_state:
            st.download_button('Download Midi', st.session_state['midiBytes'], file_name='song.mid', key='main')
    if showOutput:
        with st.container():
            for entry in st.session_state['history']:
                st.text(entry['role']+': '+entry['content']+'\n')

# --- Tab 2: MIDI file -> text notation ---------------------------------------
with m2t:
    inMidi = st.file_uploader('Input')
    if st.button('Convert', key='1'):
        if inMidi:
            st.text_area('Output', _uploadToStr(inMidi, notation=='Polyphonic'))

# --- Tab 3: text notation -> MIDI file ---------------------------------------
with t2m:
    inText = st.text_input('Input')
    if st.button('Convert', key='2'):
        notationIndex = int(notation=='Polyphonic')
        st.download_button('Download Midi', _textToMidi(inText, notationIndex), file_name='song.mid', key='t2m')