using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;

public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;
    public TextAsset vocabulary;
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    IWorker engine;
    string[] vocabularyTokens;

    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Offset between a token's line index in the vocabulary file and its id in the model's vocabulary.
    const int vocabToTokenOffset = 260;

    void Start()
    {
        if (classes.Length == 0)
        {
            Debug.LogError("There needs to be at least one class");
            return;
        }

        vocabularyTokens = vocabulary.text.Replace("\r", "").Split("\n");

        Model baseModel = ModelLoader.Load(model);
        Model modelWithScoring = Functional.Compile(
            input =>
            {
                // The logits are the model's predictions for entailment and non-entailment for each example in the batch.
                // They have shape [batch size, 2], i.e. two values per example.
                // To obtain a single score per example, a softmax function is applied.
                FunctionalTensor logits = baseModel.Forward(input)[0];
                if (multipleTrueClasses || classes.Length == 1)
                {
                    // Softmax over the entailment vs. non-entailment dimension for each label independently
                    logits = Functional.Softmax(logits);
                }
                else
                {
                    // Softmax over all candidate labels
                    logits = Functional.Softmax(logits, 0);
                }

                // The scores are stored along the first column
                return new[] { logits[.., 0] };
            },
            InputDef.FromModel(baseModel)
        );

        engine = WorkerFactory.CreateWorker(BackendType.GPUCompute, modelWithScoring);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        Dictionary<string, Tensor> inputs = new()
        {
            { "input_0", inputIds },
            { "input_1", attentionMask }
        };

        engine.Execute(inputs);
        TensorFloat scores = (TensorFloat)engine.PeekOutput("output_0");
        scores.CompleteOperationsAndDownload();
        return scores.ToReadOnlyArray();
    }

    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // The attention masks have the same length as the tokens.
        // Each attention mask contains a 1 for every real token and a 0 for every padding token.
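        // Illustrative layout (hypothetical ids p1..p2, h1..h4, not from the real vocabulary):
        // with a 2-token prompt and hypotheses of 3 and 1 tokens, maxTokenLength is 3 and
        // the two rows of the batch come out as:
        //   tokens: [1, p1, p2, 2, h1, h2, h3, 2]    mask: [1, 1, 1, 1, 1, 1, 1, 1]
        //   tokens: [1, p1, p2, 2, h4, 2,  0,  0]    mask: [1, 1, 1, 1, 1, 1, 0, 0]
        // Every row has the same length, and padded positions are masked out below.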
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;

        return batch;
    }

    List<int> Tokenize(string input)
    {
        // Greedy longest-match tokenization: for each whitespace-separated word, repeatedly
        // match the longest vocabulary entry against the remaining characters. Following
        // SentencePiece conventions, the first sub-word of each word is prefixed with "▁".
        string[] words = input.Split(null);

        List<int> ids = new();

        foreach (string word in words)
        {
            int start = 0;
            for (int i = word.Length; i >= 0; i--)
            {
                string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i - start);
                int index = Array.IndexOf(vocabularyTokens, subWord);
                if (index >= 0)
                {
                    ids.Add(index + vocabToTokenOffset);
                    // The whole remaining word has been matched, move on to the next word.
                    if (i == word.Length) break;
                    // Otherwise continue matching from the end of this sub-word.
                    start = i;
                    i = word.Length + 1;
                }
            }
        }

        return ids;
    }

    void OnDestroy() => engine?.Dispose();

    struct Batch
    {
        public int BatchCount;
        public int BatchLength;
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}
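// Usage sketch (assumptions: the ModelAsset is a DeBERTa-v3 NLI model exported to ONNX and
// imported into the project, and the vocabulary asset lists one SentencePiece token per line):
// 1. Attach this component to a GameObject and assign the `model` and `vocabulary` fields.
// 2. Optionally edit `text`, `hypothesisTemplate`, and `classes` in the Inspector.
// 3. On Start, the console logs one line per class, e.g. "[politics] Entailment Score: 0.93"
//    (the value shown is illustrative; actual scores depend on the model).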