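<!--
  Repository viewer metadata captured together with this file (kept as a
  comment so it is not rendered and does not push the page into quirks mode):
  music-gen-kit / index.html
  yiwv's picture
  wip: use model in magenta html
  fde0931
  raw
  history blame contribute delete
  No virus
  5.15 kB
-->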
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Magenta.js Model Loader and Player</title>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.11.0"></script>
<script src="https://cdn.jsdelivr.net/npm/@magenta/music"></script>
</head>
<body>
<button onclick="generateAndPlayMusic()">Generate and Play Music</button>
<script>
/**
 * Custom TF.js layer that computes a scalar loss-like value from a pair of
 * tensors: mean squared error plus a penalty on negative predictions.
 *
 * NOTE(review): presumably defined so a saved model referencing the custom
 * "MSEWithPositivePressure" object can be deserialized — confirm against the
 * exported model; a graph model loaded with tf.loadGraphModel does not need it.
 */
class MSEWithPositivePressure extends tf.layers.Layer {
  constructor() {
    super({});
  }

  /**
   * @param {tf.Tensor[]} inputs - [yTrue, yPred]; indexed positionally
   * @returns {tf.Tensor} scalar: mean((yTrue - yPred)^2) + mean(max(0, -yPred))
   */
  call(inputs) {
    // tidy() frees every intermediate tensor created here; only the returned
    // scalar survives (the original leaked five intermediates per call).
    return tf.tidy(() => {
      const yTrue = inputs[0];
      const yPred = inputs[1];
      const mse = tf.mean(tf.square(tf.sub(yTrue, yPred)));
      // Penalizes negative predictions only; positive values contribute 0.
      const positivePressure = tf.mean(tf.maximum(tf.scalar(0), tf.neg(yPred)));
      return tf.add(mse, positivePressure);
    });
  }

  /** Name used by tf.serialization to look this class up. */
  static get className() {
    return 'MSEWithPositivePressure';
  }
}
// Register the custom class once, up front, so TF.js deserialization can
// resolve it by its className if a saved model references it.
tf.serialization.registerClass(MSEWithPositivePressure);

/** The loaded TF.js graph model; undefined until loadModel() succeeds. */
let model;
/** In-flight (or completed) load shared by concurrent loadModel() callers. */
let modelLoading = null;

/**
 * Loads the converted graph model from /js_model/model.json into `model`.
 *
 * The download happens at most once: concurrent callers (e.g. repeated button
 * clicks while the first load is still running) await the same promise
 * instead of fetching the model again. If loading fails, the cached promise is
 * cleared so that a later call can retry.
 *
 * @returns {Promise<void>} resolves once `model` is set; rejects if the load fails
 */
async function loadModel() {
  if (!modelLoading) {
    modelLoading = tf.loadGraphModel('/js_model/model.json').then(
      (loaded) => {
        model = loaded;
        console.log("Model loaded successfully!");
      },
      (err) => {
        modelLoading = null;
        throw err;
      },
    );
  }
  return modelLoading;
}
/**
 * Left-pads (with all-zero notes) or truncates a flat note list so that it
 * holds exactly `seqLength` notes, keeping the most recent ones.
 *
 * @param {number[]} noteValues - flat [pitch, step, duration, pitch, ...] list
 * @param {number} seqLength - number of notes in the model input window
 * @returns {number[]} flat list of exactly seqLength * 3 numbers
 * @throws {RangeError} if noteValues.length is not a multiple of 3
 */
function fitToWindow(noteValues, seqLength) {
  if (noteValues.length % 3 !== 0) {
    throw new RangeError(
      `Expected (pitch, step, duration) triples, got ${noteValues.length} values`,
    );
  }
  const windowSize = seqLength * 3;
  if (noteValues.length >= windowSize) {
    // The window has a fixed length, so only the newest notes are kept.
    return noteValues.slice(noteValues.length - windowSize);
  }
  const padding = Array.from({ length: windowSize - noteValues.length }, () => 0);
  return [...padding, ...noteValues];
}

/**
 * Runs the model once on the current window and samples the next note.
 * All tensors produced by the model call are disposed before returning.
 *
 * @param {tf.Tensor} inputSequence - current window, shape [1, seqLength, 3]
 * @param {number} temperature - divisor for the pitch logits (> 0)
 * @returns {Promise<number[]>} [pitch clamped to 21..108, step >= 0, duration >= 0]
 */
async function predictNextNote(inputSequence, temperature) {
  const predictions = await model.executeAsync(inputSequence);
  try {
    // NOTE(review): output order (0 = duration, 1 = step, 2 = pitch logits)
    // follows the original indexing; confirm against the model signature.
    const [durationOut, stepOut, pitchOut] = predictions;
    const pitch = tf.tidy(() => {
      // Flatten like the original dataSync() did (assumes batch size 1), then
      // apply temperature. tf.multinomial expects unnormalized logits by
      // default, so the scaled logits are passed directly; feeding softmax
      // probabilities here would flatten the distribution toward uniform.
      const scaledLogits = tf.div(tf.reshape(pitchOut, [-1]), temperature);
      return tf.multinomial(scaledLogits, 1).dataSync()[0];
    });
    console.log('pitch:', pitch);
    const step = Math.max(0, stepOut.dataSync()[0]);
    const duration = Math.max(0, durationOut.dataSync()[0]);
    // Clamp to the 88-key piano range (MIDI 21..108).
    return [Math.min(Math.max(pitch, 21), 108), step, duration];
  } finally {
    tf.dispose(predictions);
  }
}

/**
 * Converts generated [pitch, step, duration] notes into a NoteSequence-shaped
 * object laid out on a fixed grid (step/duration are not used for timing).
 *
 * @param {number[][]} notes - generated notes; only the pitch is read
 * @param {number} [secondsPerNote=0.5] - fixed length of each note slot
 * @returns {object} plain object accepted by mm players as a NoteSequence
 */
function toNoteSequence(notes, secondsPerNote = 0.5) {
  return {
    ticksPerQuarter: 220,
    totalTime: notes.length * secondsPerNote,
    timeSignatures: [{ time: 0, numerator: 4, denominator: 4 }],
    tempos: [{ time: 0, qpm: 120 }],
    notes: notes.map(([pitch], index) => ({
      startTime: index * secondsPerNote,
      endTime: (index + 1) * secondsPerNote,
      pitch,
      velocity: 80,
    })),
  };
}

/**
 * Generates notes autoregressively from a seed and plays them with a SoundFont.
 * Loads the model on first use.
 *
 * NOTE(review): seed pitches are raw MIDI numbers and sampled pitches are fed
 * back unscaled; if the model was trained on normalized pitch, scale here.
 *
 * @param {object} [options]
 * @param {number[]} [options.seedNotes] - flat (pitch, step, duration) triples
 * @param {number} [options.temperature=2.0] - > 0; higher gives more random pitches
 * @param {number} [options.numPredictions=40] - number of notes to generate
 * @param {number} [options.seqLength=25] - input window length in notes;
 *   presumably must match the length the model was trained with
 * @returns {Promise<void>} settles when player.start() settles
 * @throws {RangeError} for a non-positive temperature or malformed seed
 */
async function generateAndPlayMusic({
  seedNotes = [60, 0.5, 0.5, 62, 0.5, 0.5, 64, 0.5, 0.5],
  temperature = 2.0,
  numPredictions = 40,
  seqLength = 25,
} = {}) {
  if (!(temperature > 0)) {
    throw new RangeError(`temperature must be > 0, got ${temperature}`);
  }
  const windowValues = fitToWindow(seedNotes, seqLength);
  if (!model) {
    await loadModel();
  }

  const generatedNotes = [];
  let inputSequence = tf.tensor3d(windowValues, [1, seqLength, 3]);
  try {
    for (let i = 0; i < numPredictions; i++) {
      const note = await predictNextNote(inputSequence, temperature);
      generatedNotes.push(note);
      // Slide the window: drop the oldest note, append the new one.
      const nextSequence = tf.tidy(() =>
        tf.concat(
          [inputSequence.slice([0, 1, 0], [-1, -1, -1]), tf.tensor3d([[note]], [1, 1, 3])],
          1,
        ),
      );
      inputSequence.dispose();
      inputSequence = nextSequence;
    }
  } finally {
    inputSequence.dispose();
  }

  const soundfontURL = 'https://storage.googleapis.com/magentadata/js/soundfonts/sgm_plus';
  const player = new mm.SoundFontPlayer(soundfontURL);
  // Await so sample-loading or playback failures surface instead of floating.
  await player.start(toNoteSequence(generatedNotes));
}
</script>
</body>
</html>