"""Gradio front end that turns a text prompt into a drum-beat audio clip.

A fine-tuned OpenAI completion model converts the prompt into note text,
which ``t2a.text_to_audio`` renders to audio at the selected tempo.
"""

import os

import gradio as gr
import joblib
import numpy as np
import openai
from sentence_transformers import SentenceTransformer

from t2a import text_to_audio

# NOTE: `reg` and `model` are loaded at import time but are only referenced by
# the disabled BPM prediction described in `get_drummer_output`.
reg = joblib.load('text_reg.joblib')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Identifier of the fine-tuned completion model used to generate note text.
finetune = "davinci:ft-personal:autodrummer-v5-2022-11-04-22-34-07"

# File holding the running number of generated beats as a decimal integer.
_COUNT_FILE = 'count.txt'

# Tempo passed to `text_to_audio` for each choice offered in the UI.
_TEMPO_BY_LABEL = {"fast": 138, "slow": 100}


def get_note_text(prompt):
    """Return the fine-tuned model's completion for *prompt*.

    The prompt is suffixed with " ->" before being sent, and generation
    stops at the "###" marker. The text of the first choice is returned
    with surrounding whitespace stripped.
    """
    prompt = prompt + " ->"
    # Get the completion from the fine-tuned model.
    response = openai.Completion.create(
        engine=finetune,
        prompt=prompt,
        temperature=0.5,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    return response.choices[0].text.strip()


def increment_count():
    """Add one to the integer stored in ``count.txt``.

    A missing file is treated as a count of zero so the very first request
    does not fail. Content that is not a valid integer still raises
    ``ValueError`` rather than silently resetting the counter.
    """
    try:
        with open(_COUNT_FILE, 'r') as f:
            count = int(f.read())
    except FileNotFoundError:
        # First run: no counter has been written yet.
        count = 0
    count += 1
    with open(_COUNT_FILE, 'w') as f:
        f.write(str(count))


def get_drummer_output(prompt, tempo):
    """Generate a drum beat for *prompt* and return it for Gradio playback.

    Args:
        prompt: Free-text description of the desired beat.
        tempo: "fast" or "slow"; any other value is forwarded to
            ``text_to_audio`` unchanged.

    Returns:
        A ``(96000, samples)`` tuple where ``samples`` is a float32 NumPy
        array built from ``audio.get_array_of_samples()``.

    Raises:
        KeyError: If the ``key`` environment variable is not set.
    """
    openai.api_key = os.environ['key']
    tempo = _TEMPO_BY_LABEL.get(tempo, tempo)

    note_text = get_note_text(prompt)
    # Disabled experiments, kept for reference:
    #   note_text = note_text + " " + note_text
    #   prompt_enc = model.encode([prompt])
    #   bpm = int(reg.predict(prompt_enc)[0]) + 20
    audio = text_to_audio(note_text, tempo)
    audio = np.array(audio.get_array_of_samples(), dtype=np.float32)
    increment_count()
    # NOTE(review): 96000 is hard-coded as the first tuple element (presumably
    # the sample rate Gradio plays the array at) -- confirm it matches what
    # `text_to_audio` produces.
    return (96000, audio)


iface = gr.Interface(
    fn=get_drummer_output,
    inputs=[
        "text",
        gr.Radio(["fast", "slow"], label="Tempo", default="fast"),
    ],
    examples=[
        ["hiphop groove 808", "fast"],
        ["rock metal", "fast"],
        ["disco funk", "fast"],
    ],
    outputs="audio",
    title='Autodrummer',
    description=(
        "Stable Diffusion for drum beats. Type in a genre and some descriptors "
        "(e.g., 'hiphop groove 808') to the prompt box and get a drum beat in "
        "that genre"
    ),
)

iface.launch()