// Upgraded to Sentis 1.4.0-pre3 (commit ca685d7)
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;
/// <summary>
/// Zero-shot text classification using a DeBERTa-v3 model trained for natural
/// language inference (NLI). Each candidate class is turned into a hypothesis
/// via <see cref="hypothesisTemplate"/>, batched with the prompt, and scored
/// by the model's entailment logit.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;
    public TextAsset vocabulary;
    // When true (or with a single class), each label is scored independently
    // (entailment vs. contradiction). When false, labels compete via a
    // softmax across the batch so scores sum to 1.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;
    string[] vocabularyTokens;

    // Special SentencePiece token ids expected by the model.
    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Offset between a token's line index in the vocabulary file and its model token id.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        if (classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }

        // One vocabulary token per line; strip CR so Windows line endings don't corrupt tokens.
        vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");

        Model baseModel = ModelLoader.Load(model);
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
                // They are of shape [batch size, 2] i.e. with two values per example.
                // To obtain a single score per example, a softmax function is applied
                FunctionalTensor logits = baseModel.Forward(input)[0];
                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Softmax over the entailment vs. contradiction dimension for each label independently
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Softmax over all candidate labels
                    logits = Functional.Softmax(logits, 0);
                }
                // The scores are stored along the first column
                return new []{logits[.., 0]};
            },
            InputDef.FromModel(baseModel)
        );
        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);
        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Runs the compiled model on a tokenized batch and returns one
    /// entailment score per example (blocking download from the GPU).
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);
        Dictionary<string, Tensor> inputs = new()
        {
            {"input_0", inputIds},
            {"input_1", attentionMask}
        };
        engine.Execute(inputs);
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();
        return scores.ToReadOnlyArray();
    }

    /// <summary>
    /// Builds one padded example per hypothesis, all sharing the same prompt.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;
        return batch;
    }

    /// <summary>
    /// Greedy longest-match SentencePiece-style tokenization: for each
    /// whitespace-separated word, repeatedly matches the longest vocabulary
    /// entry ("▁"-prefixed for a word start). Unknown words or untokenizable
    /// remainders are silently dropped, matching the original sample behavior.
    /// </summary>
    List<int> Tokenize(string input)
    {
        string[] words = input.Split(null);
        List<int> ids = new();
        foreach (string word in words)
        {
            int start = 0;
            for (int i = word.Length; i >= 0; i--)
            {
                // Guard against no-progress positions. Without this, an
                // untokenizable remainder either (a) decrements i below
                // start, making Substring throw with a negative length, or
                // (b) matches an empty vocabulary entry / the bare "▁"
                // token without advancing start, looping forever. The
                // i == word.Length exception keeps the original handling of
                // empty words produced by consecutive whitespace.
                if (i == start && i != word.Length) break;

                string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i - start);
                int index = Array.IndexOf(vocabularyTokens, subWord);
                if (index >= 0)
                {
                    ids.Add(index + vocabToTokenOffset);
                    if (i == word.Length) break;
                    // Continue tokenizing the remainder, restarting the
                    // longest-match scan from the full word length.
                    start = i;
                    i = word.Length + 1;
                }
            }
        }
        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    // Row-major batch of token ids and attention masks, shape [BatchCount, BatchLength].
    struct Batch
    {
        public int BatchCount;
        public int BatchLength;
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}