riffusion123 / inference.py
nikhilchintawar's picture
Upload files
c131f40
raw
history blame contribute delete
No virus
1.81 kB
# @title Imports
from diffusers import DiffusionPipeline
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from io import BytesIO
# @title Define a `predict` function
# Spectrogram settings, constructed with no arguments at import time.
params = SpectrogramParams()
# Module-level converter shared by predict_function, which calls its
# audio_from_spectrogram_image() to turn a generated image into audio.
converter = SpectrogramImageConverter(params)
def preprocess_function(text):
    """Read a request file and split it into prompt and negative prompt.

    The file is expected to contain ``<prompt>;<negative prompt>``. The text
    before the first semicolon is the prompt; the text between the first and
    the second semicolon (or the end of the file) is the negative prompt.

    Args:
        text: Path of a UTF-8 encoded text file holding the request.

    Returns:
        A ``(prompt, negative_prompt)`` tuple. ``prompt`` is returned exactly
        as written in the file; ``negative_prompt`` has surrounding whitespace
        stripped and is ``""`` when the file contains no semicolon.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(text, "r", encoding="utf-8") as f:
        data = f.read()

    # Split once. A request without ";" simply has no negative prompt; it
    # must not fail with an IndexError.
    fields = data.split(";")
    prompt = fields[0]
    negative_prompt = fields[1].strip() if len(fields) > 1 else ""
    return (prompt, negative_prompt)
def predict_function(params, pipe):
    """Generate a spectrogram from the prompts and export it as a WAV file.

    Args:
        params: A ``(prompt, negative_prompt)`` pair, as returned by
            ``preprocess_function``.
        pipe: Callable pipeline whose result exposes an ``images`` sequence;
            presumably the object returned by ``model_load_function``.

    Returns:
        A ``("output.wav", spectrogram_image)`` tuple: the path the audio was
        exported to and the first image produced by the pipeline.
    """
    text_prompt, avoid_prompt = params

    # Only the first generated image is used.
    result = pipe(text_prompt, negative_prompt=avoid_prompt, width=768)
    spectrogram = result.images[0]

    # `converter` is the module-level SpectrogramImageConverter.
    audio_segment = converter.audio_from_spectrogram_image(image=spectrogram)

    output_path = "output.wav"
    audio_segment.export(output_path, format="wav")
    return (output_path, spectrogram)
def model_load_function(model_path):
    """Load a diffusion pipeline and move it to the CUDA device.

    Args:
        model_path: Location passed to ``DiffusionPipeline.from_pretrained``.

    Returns:
        The value returned by the pipeline's ``.to("cuda")`` call.
    """
    return DiffusionPipeline.from_pretrained(model_path).to("cuda")
def postprocess_function(audio_file, content_type=None):
    """Load an audio file into an in-memory response payload.

    Args:
        audio_file: Path of the audio file to read; opened in binary mode.
        content_type: Unused by this function.

    Returns:
        ``{"output": {"data": <BytesIO>, "ext": "wav"}}`` where the buffer
        holds the complete file contents and is positioned at offset 0.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # The context manager closes the handle even if read() raises.
    with open(audio_file, "rb") as f:
        audio = f.read()

    # A BytesIO built from initial bytes already starts at position 0.
    audio_bytes = BytesIO(audio)
    return {"output": {"data": audio_bytes, "ext": "wav"}}
## Test the script
# Manual smoke test for the whole pipeline. It is kept disabled inside a
# string literal; copy it out of the string to run it by hand.
"""
if __name__ == '__main__':
    # Path of a UTF-8 text file containing "<prompt>;<negative prompt>".
    text = "./prompt.txt"
    data = preprocess_function(text)
    model_path = "./model_files"
    pipe = model_load_function(model_path)
    # predict_function returns (wav_path, spectrogram_image).
    audio_file, spec = predict_function(data, pipe)
    out = postprocess_function(audio_file)
    print(out)
"""