Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import { v } from 'convex/values'; | |
import { query, internalMutation } from './_generated/server'; | |
import Replicate, { WebhookEventType } from 'replicate'; | |
import { httpAction, internalAction } from './_generated/server'; | |
import { internal, api } from './_generated/api'; | |
function client(): Replicate { | |
const replicate = new Replicate({ | |
auth: process.env.REPLICATE_API_TOKEN || '', | |
}); | |
return replicate; | |
} | |
function replicateAvailable(): boolean { | |
return !!process.env.REPLICATE_API_TOKEN; | |
} | |
export const insertMusic = internalMutation({ | |
args: { storageId: v.string(), type: v.union(v.literal('background'), v.literal('player')) }, | |
handler: async (ctx, args) => { | |
await ctx.db.insert('music', { | |
storageId: args.storageId, | |
type: args.type, | |
}); | |
}, | |
}); | |
export const getBackgroundMusic = query({ | |
handler: async (ctx) => { | |
const music = await ctx.db | |
.query('music') | |
.filter((entry) => entry.eq(entry.field('type'), 'background')) | |
.order('desc') | |
.first(); | |
if (!music) { | |
return '/assets/background.mp3'; | |
} | |
const url = await ctx.storage.getUrl(music.storageId); | |
if (!url) { | |
throw new Error(`Invalid storage ID: ${music.storageId}`); | |
} | |
return url; | |
}, | |
}); | |
export const enqueueBackgroundMusicGeneration = internalAction({ | |
handler: async (ctx): Promise<void> => { | |
if (!replicateAvailable()) { | |
return; | |
} | |
const worldStatus = await ctx.runQuery(api.world.defaultWorldStatus); | |
if (!worldStatus) { | |
console.log('No active default world, returning.'); | |
return; | |
} | |
// TODO: MusicGen-Large on Replicate only allows 30 seconds. Use MusicGen-Small for longer? | |
await generateMusic('16-bit RPG adventure game with wholesome vibe', 30); | |
}, | |
}); | |
export const handleReplicateWebhook = httpAction(async (ctx, request) => { | |
const req = await request.json(); | |
if (req.id) { | |
const prediction = await client().predictions.get(req.id); | |
const response = await fetch(prediction.output); | |
const music = await response.blob(); | |
const storageId = await ctx.storage.store(music); | |
await ctx.runMutation(internal.music.insertMusic, { type: 'background', storageId }); | |
} | |
return new Response(); | |
}); | |
enum MusicGenNormStrategy { | |
Clip = 'clip', | |
Loudness = 'loudness', | |
Peak = 'peak', | |
Rms = 'rms', | |
} | |
enum MusicGenFormat { | |
wav = 'wav', | |
mp3 = 'mp3', | |
} | |
/** | |
* | |
* @param prompt A description of the music you want to generate. | |
* @param duration Duration of the generated audio in seconds. | |
* @param webhook webhook URL for Replicate to call when @param webhook_events_filter is triggered | |
* @param webhook_events_filter Array of event names to filter the webhook. See https://replicate.com/docs/reference/http#predictions.create--webhook_events_filter | |
* @param normalization_strategy Strategy for normalizing audio. | |
* @param top_k Reduces sampling to the k most likely tokens. | |
* @param top_p Reduces sampling to tokens with cumulative probability of p. When set to `0` (default), top_k sampling is used. | |
* @param temperature Controls the 'conservativeness' of the sampling process. Higher temperature means more diversity. | |
* @param classifer_free_gudance Increases the influence of inputs on the output. Higher values produce lower-varience outputs that adhere more closely to inputs. | |
* @param output_format Output format for generated audio. See @ | |
* @param seed Seed for random number generator. If None or -1, a random seed will be used. | |
* @returns object containing metadata of the prediction with ID to fetch once result is completed | |
*/ | |
export async function generateMusic( | |
prompt: string, | |
duration: number, | |
webhook: string = process.env.CONVEX_SITE_URL + '/replicate_webhook' || '', | |
webhook_events_filter: [WebhookEventType] = ['completed'], | |
normalization_strategy: MusicGenNormStrategy = MusicGenNormStrategy.Peak, | |
output_format: MusicGenFormat = MusicGenFormat.mp3, | |
top_k = 250, | |
top_p = 0, | |
temperature = 1, | |
classifer_free_gudance = 3, | |
seed = -1, | |
model_version = 'large', | |
) { | |
if (!replicateAvailable()) { | |
throw new Error('Replicate API token not set'); | |
} | |
return await client().predictions.create({ | |
// https://replicate.com/facebookresearch/musicgen/versions/7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906 | |
version: '7a76a8258b23fae65c5a22debb8841d1d7e816b75c2f24218cd2bd8573787906', | |
input: { | |
model_version, | |
prompt, | |
duration, | |
normalization_strategy, | |
top_k, | |
top_p, | |
temperature, | |
classifer_free_gudance, | |
output_format, | |
seed, | |
}, | |
webhook, | |
webhook_events_filter, | |
}); | |
} | |