File size: 5,892 Bytes
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c698696
 
 
 
 
 
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c698696
7368ee6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
using System;
using System.Collections.Generic;
using System.Linq;
using Unity.Sentis;
using UnityEngine;

/// <summary>
/// Zero-shot text classification sample. Each entry of <see cref="classes"/> is turned into a
/// hypothesis via <see cref="hypothesisTemplate"/>, paired with <see cref="text"/>, run through the
/// model as a single batch, and the resulting entailment score per class is logged.
/// </summary>
public sealed class DebertaV3 : MonoBehaviour
{
    public ModelAsset model;
    public TextAsset vocabulary;
    // When true, every class is scored independently; otherwise the classes compete with each other.
    public bool multipleTrueClasses;
    public string text = "Angela Merkel is a politician in Germany and leader of the CDU";
    // "{}" is replaced by each class name.
    public string hypothesisTemplate = "This example is about {}";
    public string[] classes = { "politics", "economy", "entertainment", "environment" };

    Ops ops;
    IWorker engine;
    ITensorAllocator allocator;

    // Maps a vocabulary entry to its (first) line index in the vocabulary asset.
    Dictionary<string, int> vocabularyLookup;

    const int padToken = 0;
    const int startToken = 1;
    const int separatorToken = 2;
    // Added to a vocabulary line index to obtain the token id fed to the model.
    const int vocabToTokenOffset = 260;
    // Prefix carried by vocabulary pieces that begin a word.
    const string wordStartMarker = "▁";
    const BackendType backend = BackendType.GPUCompute;

    /// <summary>
    /// Loads the vocabulary and the model, classifies <see cref="text"/> against all
    /// <see cref="classes"/> and logs one score per class.
    /// </summary>
    void Start()
    {
        // Validate first so that nothing is allocated when there is nothing to classify.
        if (classes.Length == 0)
        {
            Debug.LogError("There need to be more than 0 classes");
            return;
        }

        vocabularyLookup = BuildVocabularyLookup(vocabulary.text);

        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);

        Model loadedModel = ModelLoader.Load(model);
        engine = WorkerFactory.CreateWorker(backend, loadedModel);

        string[] hypotheses = classes.Select(x => hypothesisTemplate.Replace("{}", x)).ToArray();
        Batch batch = GetTokenizedBatch(text, hypotheses);
        float[] scores = GetBatchScores(batch);

        for (int i = 0; i < scores.Length; i++)
        {
            Debug.Log($"[{classes[i]}] Entailment Score: {scores[i]}");
        }
    }

    /// <summary>
    /// Builds the entry-to-line-index lookup from the raw vocabulary text (one entry per line).
    /// </summary>
    /// <param name="vocabularyText">Full text of the vocabulary asset.</param>
    /// <returns>Dictionary from vocabulary entry to its zero-based line index.</returns>
    static Dictionary<string, int> BuildVocabularyLookup(string vocabularyText)
    {
        // Empty lines are kept on purpose: the line index is what determines the token id.
        string[] lines = vocabularyText.Split('\n');
        Dictionary<string, int> lookup = new(lines.Length, StringComparer.Ordinal);

        for (int i = 0; i < lines.Length; i++)
        {
            // Strip the '\r' left behind by CRLF line endings, otherwise no entry would ever match.
            // TryAdd keeps the first occurrence of a duplicated entry (same result as Array.IndexOf).
            lookup.TryAdd(lines[i].TrimEnd('\r'), i);
        }

        return lookup;
    }

    /// <summary>
    /// Runs the model on a tokenized batch and converts its logits into one score per example.
    /// </summary>
    /// <param name="batch">Token ids and attention masks produced by <see cref="GetTokenizedBatch"/>.</param>
    /// <returns>One score per example in the batch, in batch order.</returns>
    float[] GetBatchScores(Batch batch)
    {
        using var inputIds = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedTokens);
        using var attentionMask = new TensorInt(new TensorShape(batch.BatchCount, batch.BatchLength), batch.BatchedMasks);

        Dictionary<string, Tensor> inputs = new()
        {
            {"input_ids", inputIds},
            {"attention_mask", attentionMask}
        };

        engine.Execute(inputs);

        // The peeked output is not disposed here, matching the original sample.
        TensorFloat logits = (TensorFloat)engine.PeekOutput("logits");
        return ScoresFromLogits(logits);
    }

    /// <summary>
    /// Tokenizes the prompt and every hypothesis and lays them out as one flat, padded batch.
    /// </summary>
    /// <param name="prompt">Text to classify; shared by every example in the batch.</param>
    /// <param name="hypotheses">One hypothesis per example; must contain at least one element.</param>
    /// <returns>The flattened token ids and attention masks plus the batch dimensions.</returns>
    Batch GetTokenizedBatch(string prompt, string[] hypotheses)
    {
        Batch batch = new Batch();

        List<int> promptTokens = Tokenize(prompt);
        promptTokens.Insert(0, startToken);

        List<int>[] tokenizedHypotheses = hypotheses.Select(Tokenize).ToArray();

        // Every example is padded up to the longest hypothesis so all rows have the same length.
        int maxTokenLength = tokenizedHypotheses.Max(x => x.Count);

        // Each example in the batch follows this format:
        // Start Prompt Separator Hypothesis Separator Padding

        int[] batchedTokens = tokenizedHypotheses.SelectMany(hypothesis => promptTokens
                .Append(separatorToken)
                .Concat(hypothesis)
                .Append(separatorToken)
                .Concat(Enumerable.Repeat(padToken, maxTokenLength - hypothesis.Count)))
            .ToArray();

        // The attention masks have the same length as the tokens.
        // Each attention mask contains repeating 1s for each token, except for padding tokens.

        int[] batchedMasks = tokenizedHypotheses.SelectMany(hypothesis => Enumerable.Repeat(1, promptTokens.Count + 1)
                .Concat(Enumerable.Repeat(1, hypothesis.Count + 1))
                .Concat(Enumerable.Repeat(0, maxTokenLength - hypothesis.Count)))
            .ToArray();

        batch.BatchCount = hypotheses.Length;
        batch.BatchLength = batchedTokens.Length / hypotheses.Length;
        batch.BatchedTokens = batchedTokens;
        batch.BatchedMasks = batchedMasks;

        return batch;
    }

    /// <summary>
    /// Converts the model logits into a single score per example.
    /// </summary>
    /// <param name="logits">Model output; treated as two values per example, the first being the one reported.</param>
    /// <returns>The first-column softmax value of every example.</returns>
    float[] ScoresFromLogits(TensorFloat logits)
    {
        // The logits represent the model's predictions for entailment and non-entailment for each example in the batch.
        // They are of shape [batch size, 2], with two values per example.
        // To obtain a single value (score) per example, a softmax function is applied

        // NOTE(review): Length(0, 1) is presumably the size of the batch dimension, i.e. this detects
        // a single candidate label - confirm against the Sentis TensorShape documentation.
        bool scoreLabelsIndependently = multipleTrueClasses || logits.shape.Length(0, 1) == 1;

        // Independent labels: softmax over the entailment vs. contradiction dimension of each label.
        // Competing labels: softmax over all candidate labels.
        int softmaxAxis = scoreLabelsIndependently ? -1 : 0;

        // The using declaration releases the tensor even if reading it back throws.
        using TensorFloat tensorScores = ops.Softmax(logits, softmaxAxis);

        tensorScores.MakeReadable();
        float[] tensorArray = tensorScores.ToReadOnlyArray();

        // Select the first column which is the column where the scores are stored
        float[] scores = new float[tensorArray.Length / 2];
        for (int i = 0; i < scores.Length; i++)
        {
            scores[i] = tensorArray[i * 2];
        }

        return scores;
    }

    /// <summary>
    /// Splits the input on whitespace and greedily breaks every word into the longest vocabulary
    /// pieces available. The first piece of a word is looked up with the word-start marker prefixed.
    /// </summary>
    /// <param name="input">Text to tokenize.</param>
    /// <returns>Token ids (vocabulary line index plus <see cref="vocabToTokenOffset"/>) in reading order.</returns>
    List<int> Tokenize(string input)
    {
        List<int> ids = new();

        // A null separator splits on any whitespace. Empty entries (from repeated, leading or
        // trailing whitespace) are dropped so they do not each produce a stray marker-only token.
        string[] words = input.Split((char[])null, StringSplitOptions.RemoveEmptyEntries);

        foreach (string word in words)
        {
            int start = 0;

            // Tracked separately from 'start': when only the bare marker matches, no character is
            // consumed, yet the next piece must no longer be treated as the start of the word.
            bool atWordStart = true;

            while (start < word.Length)
            {
                // Only the word-start lookup may consume zero characters (the marker on its own).
                int shortestEnd = atWordStart ? start : start + 1;
                int matchedEnd = -1;

                for (int end = word.Length; end >= shortestEnd; end--)
                {
                    string piece = word.Substring(start, end - start);
                    if (atWordStart)
                    {
                        piece = wordStartMarker + piece;
                    }

                    if (vocabularyLookup.TryGetValue(piece, out int index))
                    {
                        ids.Add(index + vocabToTokenOffset);
                        matchedEnd = end;
                        break;
                    }
                }

                if (matchedEnd < 0)
                {
                    // No piece covers this character. Skip it so tokenization always terminates.
                    // NOTE(review): emitting the model's unknown-token id would be more faithful,
                    // but that id is not defined anywhere in this file.
                    Debug.LogWarning($"No vocabulary entry matches '{word[start]}' in \"{word}\"; the character is skipped");
                    matchedEnd = start + 1;
                }

                start = matchedEnd;
                atWordStart = false;
            }
        }

        return ids;
    }

    /// <summary>
    /// Releases the worker, the allocator and the ops object, if they were created.
    /// </summary>
    void OnDestroy()
    {
        engine?.Dispose();
        allocator?.Dispose();
        ops?.Dispose();
    }

    /// <summary>
    /// Flattened model input for a batch of equally long examples.
    /// </summary>
    struct Batch
    {
        // Number of examples in the batch.
        public int BatchCount;
        // Number of tokens per example, padding included.
        public int BatchLength;
        // Token ids of all examples, one example after the other.
        public int[] BatchedTokens;
        // Attention mask matching BatchedTokens: 1 for real tokens, 0 for padding.
        public int[] BatchedMasks;
    }
}