# Hugging Face Space status when this file was captured: Runtime error
import os

import gradio as gr
import joblib
import numpy as np
import openai
from sentence_transformers import SentenceTransformer

from t2a import text_to_audio
# Identifier of the fine-tuned OpenAI completion model used by get_note_text.
finetune = "davinci:ft-personal:autodrummer-v5-2022-11-04-22-34-07"

# Prompt-embedding model plus a regressor loaded from disk; together they were
# used to predict a BPM from the prompt text.
# NOTE(review): no active code path in this file calls `reg` or `model`, yet
# both load at import time (the embedder may download weights). Consider
# lazy-loading or removing them once that experiment is confirmed dead.
reg = joblib.load('text_reg.joblib')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
def get_note_text(prompt):
    """Request drum-note text for *prompt* from the fine-tuned completion model.

    Args:
        prompt: Free-text description of the desired beat (e.g. a genre).

    Returns:
        The text of the first returned completion, whitespace-stripped.
    """
    # " ->" is appended to every prompt and generation stops at "###";
    # presumably these are the separators used when the model was fine-tuned
    # (confirm against the training data).
    completion_args = dict(
        engine=finetune,
        prompt=prompt + " ->",
        temperature=0.5,
        max_tokens=200,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=["###"],
    )
    result = openai.Completion.create(**completion_args)
    first_choice = result.choices[0]
    return first_choice.text.strip()
def increment_count(path='count.txt'):
    """Increment the usage counter persisted as a decimal integer in a text file.

    A missing, empty, or non-integer counter file is treated as holding 0, so a
    lost or corrupted counter restarts at 1 instead of raising. Errors while
    writing the file still propagate.

    Args:
        path: File holding the counter. Defaults to 'count.txt' in the current
            working directory (the original hard-coded location).

    Returns:
        The new count after incrementing.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            count = int(f.read())
    except (FileNotFoundError, ValueError):
        # Deliberate best-effort: the counter is bookkeeping only and must not
        # fail the request that triggered it.
        count = 0
    count += 1
    with open(path, 'w', encoding='utf-8') as f:
        f.write(str(count))
    return count
def get_drummer_output(prompt, tempo):
    """Generate a drum beat for *prompt* and return it as Gradio audio data.

    Args:
        prompt: Free-text description of the beat (genre and descriptors).
        tempo: "fast" (mapped to 138 BPM) or "slow" (mapped to 100 BPM).
            ``None`` (no radio option selected) falls back to the "fast"
            tempo; any other value is passed to ``text_to_audio`` unchanged.

    Returns:
        A ``(sample_rate, samples)`` tuple with samples as a float32 array.
    """
    # The OpenAI key comes from the 'key' environment variable on each request.
    openai.api_key = os.environ['key']
    # Gradio hands over None when no radio option is selected. Use the UI's
    # default ("fast") instead of passing None through as the tempo.
    if tempo is None or tempo == "fast":
        tempo = 138
    elif tempo == "slow":
        tempo = 100
    note_text = get_note_text(prompt)
    # Tempo comes only from the radio button. Predicting a BPM from the prompt
    # (model.encode + reg.predict, plus 20) was tried and is disabled.
    audio = text_to_audio(note_text, tempo)
    # NOTE(review): assumes text_to_audio returns a pydub-like segment exposing
    # get_array_of_samples(). The 96000 Hz rate below is hard-coded rather than
    # read from the segment; confirm it matches t2a's actual output rate.
    samples = np.array(audio.get_array_of_samples(), dtype=np.float32)
    increment_count()
    return (96000, samples)
# Gradio UI: a text prompt plus a tempo selector, producing an audio clip.
iface = gr.Interface(
    fn=get_drummer_output,
    inputs=[
        "text",
        # Gradio 3+ components set their initial selection with `value`; the
        # legacy `default` keyword is not recognized, so "fast" was never
        # preselected (or the constructor raised, on newer Gradio).
        gr.Radio(["fast", "slow"], label="Tempo", value="fast"),
    ],
    examples=[
        ["hiphop groove 808", "fast"],
        ["rock metal", "fast"],
        ["disco funk", "fast"],
    ],
    outputs="audio",
    title='Autodrummer',
    description=(
        "Stable Diffusion for drum beats. Type in a genre and some descriptors "
        "(e.g., 'hiphop groove 808') to the prompt box and get a drum beat in that genre"
    ),
)
iface.launch()