// DebertaV3.cs — zero-shot text classification with DeBERTa-v3 (Unity Sentis sample).
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;
/// <summary>
/// Zero-shot text classification with a DeBERTa-v3 NLI model (Unity Sentis).
/// For each candidate class, a hypothesis is built from <see cref="hypothesisTemplate"/>
/// and scored for entailment against <see cref="text"/>; scores are logged on Start.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    // Model and vocabulary assets, assigned in the Inspector.
    public ModelAsset model;
    public TextAsset vocabulary;
    // When true (or with a single class), each class is scored independently via a
    // softmax over entailment vs. contradiction; when false, scores are normalized
    // across all candidate classes so they sum to 1.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    // "{}" is replaced by each class name to form the hypothesis sentence.
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };
    Ops ops;
    IWorker engine;
    ITensorAllocator allocator;
    string[] vocabularyTokens;
    // Special token ids used by the DeBERTa tokenizer.
    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Offset between a token's line index in the vocabulary file and its model token id.
    const int vocabToTokenOffset = 260;
    const BackendType backend = BackendType.GPUCompute;
    void Start()
    {
        // Validate configuration before allocating any engine/GPU resources.
        if (classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }
        // TrimEnd('\r') makes vocabulary lookup robust to CRLF line endings; for
        // LF-only files it is a no-op, so token indices are unchanged.
        vocabularyTokens = vocabulary.text.Split('\n').Select(t => t.TrimEnd('\r')).ToArray();
        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);
        Model loadedModel = ModelLoader.Load(model);
        engine = WorkerFactory.CreateWorker(backend, loadedModel);
        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);
        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }
    /// <summary>
    /// Runs the model on a tokenized batch and returns one entailment score per example.
    /// </summary>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);
        Dictionary<string, Tensor> inputs = new()
        {
            {"input_ids", inputIds},
            {"attention_mask", attentionMask}
        };
        engine.Execute(inputs);
        // PeekOutput returns a tensor owned by the worker — do not dispose it here.
        TensorFloat logits = (TensorFloat)engine.PeekOutput("logits");
        float[] scores = ScoresFromLogits(logits);
        return scores;
    }
    /// <summary>
    /// Tokenizes the prompt and all hypotheses into a single fixed-width batch.
    /// </summary>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();
        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);
        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();
        // Pad every example to the length of the longest hypothesis.
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);
        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding
        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();
        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.
        // (promptTokens already includes the start token; each +1 covers a separator.)
        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();
        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;
        return batch;
    }
    /// <summary>
    /// Converts raw model logits of shape [batch size, 2] into one score per example.
    /// </summary>
    float[] ScoresFromLogits(TensorFloat logits)
    {
        // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
        // They are of shape [batch size, 2], with two values per example.
        // To obtain a single value (score) per example, a softmax function is applied
        TensorFloat tensorScores;
        // shape.Length(0, 1) is the batch dimension; a single class always scores independently.
        if (multipleTrueClasses || logits.shape.Length(0, 1) == 1)
        {
            // Softmax over the entailment vs. contradiction dimension for each label independently
            tensorScores = ops.Softmax(logits, -1);
        }
        else
        {
            // Softmax over all candidate labels
            tensorScores = ops.Softmax(logits, 0);
        }
        // Download from the backend so the values can be read on the CPU.
        tensorScores.MakeReadable();
        float[] tensorArray = tensorScores.ToReadOnlyArray();
        tensorScores.Dispose();
        // Select the first column which is the column where the scores are stored
        float[] scores = new float[tensorArray.Length / 2];
        for (int i = 0; i < scores.Length; i++)
        {
            scores[i] = tensorArray[i * 2];
        }
        return scores;
    }
    /// <summary>
    /// Greedy longest-match subword tokenization (SentencePiece-style).
    /// Splits the input on whitespace, then for each word repeatedly matches the
    /// longest vocabulary entry; word-initial pieces carry the "▁" prefix.
    /// NOTE(review): pieces with no vocabulary match are silently skipped — confirm
    /// this matches the reference tokenizer's unknown-token handling.
    /// </summary>
    List<int> Tokenize(string input)
    {
        string[] words = input.Split(null);
        List<int> ids = new();
        foreach (string word in words)
        {
            int start = 0;
            // Try the longest candidate first, shrinking until a vocabulary hit.
            for(int i = word.Length; i >= 0;i--)
            {
                string subWord = start == 0 ? "▁" + word.Substring(start, i) : word.Substring(start, i-start);
                int index = Array.IndexOf(vocabularyTokens, subWord);
                if (index >= 0)
                {
                    ids.Add(index + vocabToTokenOffset);
                    if (i == word.Length) break;
                    // Restart the longest-match search on the remaining suffix:
                    // after the loop's i-- the next candidate is the full remainder.
                    start = i;
                    i = word.Length + 1;
                }
            }
        }
        return ids;
    }
    void OnDestroy()
    {
        // Release GPU/backend resources created in Start.
        engine?.Dispose();
        allocator?.Dispose();
        ops?.Dispose();
    }
    /// <summary>Flattened token/mask arrays for a [BatchCount, BatchLength] batch.</summary>
    struct Batch
    {
        public int BatchCount;
        public int BatchLength;
        public int[] BatchedTokens;
        public int[] BatchedMasks;
    }
}