/**
 * Browser-side chat client: keeps a dialogue in the DOM (via jQuery), opens a
 * WebSocket inference session on `/api/v2/generate`, sends the not-yet-sent
 * part of the dialogue and appends the streamed response to the last AI replica.
 *
 * Relies on globals provided by the page: jQuery (`$`) and `autosize`.
 */

// Per-model settings, keyed by the model id that is sent to the server.
//   modelCard / license  - URLs put into the page links when the model is selected
//   maxSessionLength     - upper bound for the `max_length` of an inference session
//   sepToken             - appended after every human replica when building the prompt
//   stopToken            - sent as `stop_sequence`; also appended after finished AI replicas
//   extraStopSequences   - sent as `extra_stop_sequences` (null when there are none)
// NOTE(review): the `[""]` entries below look like a token that was lost together with the
// HTML markup of this file (an empty stop sequence is meaningless) - confirm the intended value.
const models = {
  "meta-llama/Llama-2-70b-chat-hf": {
    modelCard: "https://huggingface.co/meta-llama/Llama-2-70b-chat-hf",
    license: "https://bit.ly/llama2-license",
    maxSessionLength: 8192,
    sepToken: "###",
    stopToken: "###",
    extraStopSequences: [""],
  },
  "meta-llama/Llama-2-70b-hf": {
    modelCard: "https://huggingface.co/meta-llama/Llama-2-70b-hf",
    license: "https://bit.ly/llama2-license",
    maxSessionLength: 8192,
    sepToken: "###",
    stopToken: "###",
    extraStopSequences: [""],
  },
  "timdettmers/guanaco-65b": {
    modelCard: "https://huggingface.co/timdettmers/guanaco-65b",
    license: "https://huggingface.co/timdettmers/guanaco-65b",
    maxSessionLength: 2048,
    sepToken: "###",
    stopToken: "###",
    extraStopSequences: [""],
  },
  "enoch/llama-65b-hf": {
    modelCard: "https://github.com/facebookresearch/llama/blob/llama_v1/MODEL_CARD.md",
    license: "https://bit.ly/llama-license",
    maxSessionLength: 2048,
    sepToken: "###",
    stopToken: "###",
    extraStopSequences: [""],
  },
  "bigscience/bloom": {
    modelCard: "https://huggingface.co/bigscience/bloom",
    license: "https://bit.ly/bloom-license",
    maxSessionLength: 2048,
    sepToken: "\n\n",
    stopToken: "\n\n",
    extraStopSequences: null,
  },
  "bigscience/bloomz": {
    modelCard: "https://huggingface.co/bigscience/bloomz",
    license: "https://bit.ly/bloom-license",
    maxSessionLength: 2048,
    sepToken: "\n\n",
    stopToken: "",
    extraStopSequences: ["\n\nHuman"],
  },
};

// Top-level `var`s are kept as `var` on purpose: in a classic script they are also
// window properties, and other scripts on the page may read them.
var curModel = "meta-llama/Llama-2-70b-chat-hf";

// Sampling parameters merged into every "generate" request.
const generationParams = {
  do_sample: 1,
  temperature: 0.9,
  top_p: 0.6,
};

var ws = null;      // Current WebSocket, or null when no session is open
var position = 0;   // Number of replicas (human + AI) already sent within the current session

const initialSessionLength = 512;
var sessionLength = initialSessionLength;

var connFailureBefore = false;   // True after one automatic retry of a failed connection
var totalElapsed, nRequests;     // Accumulators for the tokens/sec estimate (reset per replica)

const Regime = {
  CHATBOT: 1,
  FEW_SHOT: 2,
};
let curRegime = Regime.CHATBOT;
let stop = false;   // Set by the "stop generation" link, consumed by the next server message

/**
 * Opens a new WebSocket, asks the server for an inference session for the current
 * model, and calls sendReplica() once the server confirms the session.
 */
function openSession() {
  const protocol = location.protocol === "https:" ? "wss:" : "ws:";
  // The handlers use the local `socket` instead of the global `ws`, so they always
  // talk to the socket they were installed on.
  const socket = new WebSocket(`${protocol}//${location.host}/api/v2/generate`);
  ws = socket;

  socket.onopen = () => {
    socket.send(JSON.stringify({type: "open_inference_session", model: curModel, max_length: sessionLength}));
    socket.onmessage = event => {
      const response = JSON.parse(event.data);
      if (!response.ok) {
        handleFailure(response.traceback);
        return;
      }

      sendReplica();
    };
  };

  socket.onerror = _event => handleFailure(`Connection failed`);
  socket.onclose = _event => {
    // If an error is already displayed, do not replace it with a generic message
    if ($(".error-box").is(":hidden")) {
      handleFailure(`Connection was closed`, true);
    }
  };
}

/**
 * Drops the current session (if any) and forgets how much of the dialogue was sent.
 *
 * The socket's handlers are detached before closing. Otherwise the `close`/`error`
 * events of a socket that we closed deliberately arrive later and call
 * handleFailure(), which would reset (and retry) whatever session is current by then.
 */
function resetSession() {
  if (ws !== null) {
    ws.onopen = null;
    ws.onmessage = null;
    ws.onerror = null;
    ws.onclose = null;
    if (ws.readyState <= 1) {  // If readyState is "connecting" or "opened"
      ws.close();
    }
  }
  ws = null;
  position = 0;
}

/**
 * @returns {boolean} true when the dialogue currently ends with an input textarea,
 *   i.e. no generation is in progress.
 */
function isWaitingForInputs() {
  return $('.human-replica textarea').length >= 1;
}

/**
 * Sends everything in the dialogue that the current session has not seen yet.
 *
 * When called with an input textarea present, it first freezes the textarea into
 * plain text and appends an empty AI replica with its generation controls.
 * If there is no open session, it opens one and returns; openSession() calls this
 * function again once the session is confirmed.
 */
function sendReplica() {
  if (isWaitingForInputs()) {
    const aiPrompt = (curRegime === Regime.CHATBOT) ? 'Assistant:' : '';
    $('.human-replica:last').text($('.human-replica:last textarea').val());
    // NOTE(review): the markup of this template had been stripped from the file (only
    // the text pieces were left, inside string literals broken across lines, which is a
    // syntax error). It is reconstructed from the selectors this script uses and from
    // the five string segments that were left; confirm attributes/styling against the page.
    $('.dialogue').append($(
      '<p class="ai-replica">' +
        `<span class="text">${aiPrompt}</span>` +
        '<span class="loading-animation"></span>' +
        '<span class="speed" style="display: none;"></span>' +
        '<span class="generation-controls"><a class="stop-generation" href="#">stop generation</a></span>' +
        '<span class="suggest-join" style="display: none;"></span>' +
      '</p>'));
    animateLoading();
    $('.stop-generation').click(e => {
      e.preventDefault();
      console.log("Stop generation");
      stop = true;
    });
  } else {
    $('.loading-animation').show();
  }

  if (ws === null) {
    openSession();
    return;
  }

  // Build the prompt from the replicas that were not sent in this session yet
  const replicaDivs = $('.human-replica, .ai-replica .text');
  const replicas = [];
  for (let i = position; i < replicaDivs.length; i++) {
    const el = $(replicaDivs[i]);
    let phrase = el.text();
    if (el.is(".human-replica")) {
      phrase += models[curModel].sepToken;
    } else if (i < replicaDivs.length - 1) {
      // The last AI replica is the one being generated, so it gets no stop token
      phrase += models[curModel].stopToken;
    }
    replicas.push(phrase);
  }
  const inputs = replicas.join("");
  position = replicaDivs.length;

  totalElapsed = 0;
  nRequests = 0;
  receiveReplica(inputs);
}

/**
 * Sends a "generate" request and installs the handler that appends every server
 * message to the last AI replica until the server reports `stop` or the user
 * clicks "stop generation".
 *
 * @param {string} inputs - prompt text to send with the request
 */
function receiveReplica(inputs) {
  ws.send(JSON.stringify({
    type: "generate",
    inputs: inputs,
    max_new_tokens: 1,
    stop_sequence: models[curModel].stopToken,
    extra_stop_sequences: models[curModel].extraStopSequences,
    ...generationParams
  }));

  let lastMessageTime = null;
  ws.onmessage = event => {
    connFailureBefore = false;  // We've managed to connect after a possible failure

    const response = JSON.parse(event.data);
    if (!response.ok) {
      handleFailure(response.traceback);
      return;
    }

    // The first message has no predecessor, so it does not contribute to the speed estimate
    if (lastMessageTime != null) {
      totalElapsed += performance.now() - lastMessageTime;
      nRequests++;
    }
    lastMessageTime = performance.now();

    const lastReplica = $('.ai-replica .text').last();
    let newText = lastReplica.text() + response.outputs;
    newText = newText.replace(models[curModel].stopToken, "");
    if (models[curModel].extraStopSequences !== null) {
      for (const seq of models[curModel].extraStopSequences) {
        newText = newText.replace(seq, "");
      }
    }
    lastReplica.text(newText);

    if (!response.stop && !stop) {
      if (nRequests >= 1) {
        const speed = nRequests / (totalElapsed / 1000);
        $('.speed')
          .text(`Speed: ${speed.toFixed(1)} tokens/sec`)
          .show();
        if (speed < 0.5) {
          $('.suggest-join').show();
        }
      }
    } else {
      $('.loading-animation, .speed, .suggest-join, .generation-controls').remove();
      resetSession();
      appendTextArea();
      stop = false;
    }
  };
}

/**
 * Resets the session and, if the user is waiting for a response, either retries
 * automatically or shows the error box.
 *
 * @param {string} message - error text (a server traceback or a connection message)
 * @param {boolean} [autoRetry=false] - retry without showing the error
 */
function handleFailure(message, autoRetry = false) {
  resetSession();
  if (isWaitingForInputs()) {
    // Show the error and the retry button only if a user is waiting for the generation results
    return;
  }

  let shouldRetry = autoRetry;
  if (message === "Connection failed" && !connFailureBefore) {
    shouldRetry = true;
    connFailureBefore = true;
  }
  if (/Session .+ expired/.test(message)) {
    shouldRetry = true;
  }
  if (/Maximum length exceeded/.test(message) && sessionLength < models[curModel].maxSessionLength) {
    // We gradually increase sessionLength to save server resources. Default: 512 -> 2048 -> 8192 (if supported)
    sessionLength = Math.min(sessionLength * 4, models[curModel].maxSessionLength);
    shouldRetry = true;
  }

  if (shouldRetry) {
    retry();
    return;
  }

  $('.loading-animation').hide();
  if (/attention cache is full/.test(message)) {
    $('.error-message').hide();
    $('.out-of-capacity').show();
  } else {
    $('.out-of-capacity').hide();
    $('.error-message').text(message).show();
  }
  $('.error-box').show();
}

/** Hides the error box and sends the pending part of the dialogue again. */
function retry() {
  $('.error-box').hide();
  sendReplica();
}

/** Appends a new human replica with an input textarea and prepares it for typing. */
function appendTextArea() {
  const humanPrompt = (curRegime === Regime.CHATBOT) ? "Human: " : "";
  // NOTE(review): the markup of this template had been stripped from the file (it was
  // left blank and `humanPrompt` was unused). It is reconstructed from the selectors
  // this script uses ('.human-replica textarea', '.dialogue p'); confirm
  // attributes/styling classes against the textarea that ships with the page.
  $('.dialogue').append($(
    `<p class="human-replica"><textarea>${humanPrompt}</textarea></p>`
  ));
  upgradeTextArea();
}

/**
 * Enables autosizing on the input textarea, focuses it with the caret at the end,
 * and makes Enter (without Shift) send the replica.
 */
function upgradeTextArea() {
  const textarea = $('.human-replica textarea');
  autosize(textarea);
  textarea[0].selectionStart = textarea[0].value.length;
  textarea.focus();

  textarea.on('keypress', e => {
    if (e.which === 13 && !e.shiftKey) {
      e.preventDefault();
      sendReplica();
    }
  });
}

const animFrames = ["⌛", "🧠"];
var curFrame = 0;

/** Shows the next frame of the loading animation. */
function animateLoading() {
  // NOTE(review): the leading '  ' may originally have been an HTML entity - confirm.
  $('.loading-animation').html('  ' + animFrames[curFrame]);
  curFrame = (curFrame + 1) % animFrames.length;
}

// Page setup: wire up the initial textarea, the model selectors and the retry link
$(() => {
  upgradeTextArea();

  $('.family-selector label').click(function (e) {
    if (!isWaitingForInputs()) {
      alert("Can't switch the model while the AI is writing a response. Please refresh the page");
      e.preventDefault();
      return;
    }

    const radio = $(`#${$(this).attr("for")}`);
    if (radio.is(":checked")) {
      setTimeout(() => $('.human-replica textarea').focus(), 10);
      return;
    }

    const curFamily = radio.attr("value");
    $('.model-selector').hide();
    const firstLabel = $(`.model-selector[data-family=${curFamily}]`).show().children('label:first');
    // NOTE(review): both calls trigger a click, so the model handler runs twice.
    // Kept as is in case the page's toggle styling depends on it - confirm and drop one.
    firstLabel.click();
    firstLabel.trigger('click');
  });

  $('.model-selector label').click(function (e) {
    if (!isWaitingForInputs()) {
      alert("Can't switch the model while the AI is writing a response. Please refresh the page");
      e.preventDefault();
      return;
    }

    curModel = $(`#${$(this).attr("for")}`).attr("value");
    if (curRegime === Regime.CHATBOT) {
      // Keep the first two paragraphs of the dialogue and drop the rest
      $('.dialogue p').slice(2).remove();
    } else {
      $('.dialogue').empty();
    }
    sessionLength = initialSessionLength;
    resetSession();
    appendTextArea();

    $('.model-name')
      .text($(this).text())
      .attr('href', models[curModel].modelCard);
    $('.license-link').attr('href', models[curModel].license);
    setTimeout(() => $('.human-replica textarea').focus(), 10);
  });

  $('.retry-link').click(e => {
    e.preventDefault();
    retry();
  });

  setInterval(animateLoading, 2000);
});