## Prerequisites

### Install Dependencies

In [None]:
# Print the NVIDIA driver/GPU status so it is visible whether a GPU is attached
# before the model is loaded.
!nvidia-smi

In [None]:
# Upgrade pip itself first, then install transformers straight from the GitHub
# repository (unpinned, so the installed revision depends on when this runs).
!pip install --upgrade --quiet pip
!pip install --quiet git+https://github.com/huggingface/transformers.git

In [None]:
!pip install typing-extensions==4.5.0
!pip install python-multipart
!pip install kaleido
!pip install notebook>=6.5.5
!pip install click>=8.0
!pip install fastapi
!pip install "uvicorn[standard]"
!pip install pyngrok

### Load the models

In [None]:
from transformers import MusicgenForConditionalGeneration, MusicgenProcessor, set_seed

# The model and its processor are loaded from the same pretrained checkpoint id.
checkpoint = "facebook/musicgen-small"
processor = MusicgenProcessor.from_pretrained(checkpoint)
model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)

In [None]:
import torch
from IPython.display import Audio

# Prefer the first CUDA device; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Sampling rate of the model's audio encoder, reused when writing WAV files.
sampling_rate = model.config.audio_encoder.sampling_rate

# The trailing semicolon keeps the notebook from echoing the model repr.
model.to(device);

## Music Generation functionality

#### Model Class

In [None]:
import numpy as np
import typing

class AudioPalette:
 def __init__(self):
 pass

 def set_prompt(self, caption: str | typing.List[str]):
 self.caption = caption

 def generate(self):
 if isinstance(self.caption, str):
 return self.generate_single(max_new_tokens=1024)
 else:
 return self.generate_multiple()

 def generate_single(self, prompt=None, max_new_tokens=512):
 if not prompt:
 prompt = self.caption
 inputs = processor(
 text=[prompt],
 padding=True,
 return_tensors="pt",
 sampling_rate=sampling_rate
 )

 audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=max_new_tokens)
 return audio_values

 def generate_audio_with_melody_conditioning(self, prompt, melody, max_new_tokens=256):
 inputs = processor(
 text=[prompt],
 audio=melody[0, 0].cpu().numpy(),
 padding=True,
 return_tensors="pt",
 sampling_rate=sampling_rate
 )

 # set_seed(1)
 audio_values = model.generate(**inputs.to(device), do_sample=True, guidance_scale=3, max_new_tokens=max_new_tokens)
 return audio_values

 def generate_multiple(self):
 for idx, prompt in enumerate(self.caption):
 if idx == 0:
 audio = self.generate_single(prompt, 256)
 else:
 audio = self.generate_audio_with_melody_conditioning(prompt, audio)
 return audio

In [None]:
# Single shared generator instance; the /generate handler sets its prompt and
# calls generate() on it for every request.
audiopalette = AudioPalette()

#### API Creation

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# CORS: the API is reached from browser front-ends on other origins (via the
# public ngrok URL), so every origin, method and header is allowed.
# Credentials are disabled: the endpoints use no cookies or auth headers, and
# combining a wildcard origin with credentialed requests would let any website
# issue requests carrying a visitor's credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)

In [None]:
import typing
import numpy as np

class Prompt(BaseModel):
    # Request body for /generate: either one caption or a list of captions.
    caption: typing.Union[str, typing.List[str]]

class FileData(BaseModel):
    # Request body for /download: the path string returned by /generate.
    file_path: str

# class Melody(BaseModel):
# audio: np.ndarray

# class Config:
# arbitrary_types_allowed = True

In [None]:
import tempfile
import scipy

from fastapi.responses import FileResponse

@app.get('/')
async def root():
    """Return a fixed greeting; handy for checking that the API is reachable."""
    greeting = {"message": "Hello World"}
    return greeting

# Imports needed only by the two handlers below.
import os

from fastapi import HTTPException
from scipy.io import wavfile  # explicit submodule import; bare `import scipy` may not expose it

# Paths of the WAV files written by /generate.  /download serves only these,
# so a client cannot use it to read arbitrary files from the host.
_generated_files = set()


@app.post('/download')
async def download(file_data: FileData):
    """Return an audio file previously produced by ``/generate``.

    Args:
        file_data: Body carrying the ``file_path`` string that ``/generate``
            returned.

    Raises:
        HTTPException: 404 when the path was not produced by ``/generate`` in
            this process or the file no longer exists.
    """
    file_path = file_data.file_path
    if file_path not in _generated_files or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="audio/wav")


@app.post('/generate')
async def gen_music(prompt: Prompt):
    """Generate audio for the prompt(s) and save it as a temporary WAV file.

    Args:
        prompt: Body with ``caption`` as one string or a list of strings.

    Returns:
        ``{"file_path": <path>}``; pass the path to ``/download`` to fetch the
        audio.

    Raises:
        HTTPException: 400 when ``caption`` is an empty list.
    """
    if isinstance(prompt.caption, list) and not prompt.caption:
        raise HTTPException(status_code=400, detail="At least one caption is required")

    audiopalette.set_prompt(prompt.caption)
    audio = audiopalette.generate()
    waveform = audio[0, 0].cpu().numpy()

    # delete=False keeps the file on disk after the handle closes so that
    # /download can serve it later; the .wav suffix marks it as WAV audio.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
        wavfile.write(f, rate=sampling_rate, data=waveform)
        file_path = f.name

    _generated_files.add(file_path)
    return {"file_path": file_path}


#### Run the API

In [None]:
from getpass import getpass

import nest_asyncio
import uvicorn
from pyngrok import ngrok

In [None]:
# Ask for the token interactively (getpass does not echo the input) and hand
# it to pyngrok.
ngrok_auth_token = getpass("Enter ngrok auth token: ")
ngrok.set_auth_token(ngrok_auth_token)

In [None]:
# One port shared by the tunnel and the server so the two cannot drift apart.
API_PORT = 8000

# Open the public tunnel first and print its URL, since uvicorn.run() blocks.
ngrok_tunnel = ngrok.connect(API_PORT)
print("Public URL:", ngrok_tunnel.public_url)

# Patch asyncio so uvicorn can start inside the notebook's running event loop.
nest_asyncio.apply()
uvicorn.run(app, port=API_PORT)

#### Kill ngrok Connection

In [None]:
# Stop the ngrok process started through pyngrok, which closes the public
# tunnel opened above.
ngrok.kill()