Paul Bird committed on
Commit
3d94e2c
·
verified ·
1 Parent(s): 61773dc

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. RunTinyStories.cs +273 -0
  3. merges.txt +0 -0
  4. tinystories.sentis +3 -0
  5. vocab.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tinystories.sentis filter=lfs diff=lfs merge=lfs -text
RunTinyStories.cs ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ using System.Collections;
2
+ using System.Collections.Generic;
3
+ using UnityEngine;
4
+ using Unity.Sentis;
5
+ using System.IO;
6
+ using Newtonsoft.Json;
7
+ using System.Text;
8
+
9
+ /*
10
+ * Tiny Stories Inference Code
11
+ * ===========================
12
+ *
13
+ * Put this script on the Main Camera
14
+ *
15
+ * In Assets/StreamingAssets put:
16
+ *
17
+ * tinystories.sentis
18
+ * vocab.json
19
+ * merges.txt
20
+ *
21
+ * Install package com.unity.nuget.newtonsoft-json from the Package Manager
22
+ * Install package com.unity.sentis
23
+ *
24
+ */
25
+
26
+
27
// Generates text with the TinyStories model through Unity Sentis.
// Start() loads the vocabulary, merge rules and model and tokenizes the prompt;
// Update() then produces one new token per frame and logs the text so far, until
// the end-of-text token is sampled or stopAfter tokens have been generated.
public class RunTinyStories : MonoBehaviour
{
    const BackendType backend = BackendType.GPUCompute;

    // Prompt the story starts from; generated text is appended to this string.
    // Alternative prompt: "Once upon a time, there were three bears"
    string outputString = "One day an alien came down from Mars. It saw a chicken";

    // Length of the token window handed to the model on every step. It can be adjusted.
    const int maxTokens = 100;

    // The model scores are multiplied by this factor before the softmax.
    // Make this smaller for more randomness.
    const float predictability = 5f;

    // Special tokens
    const int END_OF_TEXT = 50256;

    Ops ops;
    ITensorAllocator allocator;

    // Store the vocabulary (token text indexed by token id)
    string[] tokens;

    IWorker engine;

    // Index in outputTokens of the most recently written token
    int currentToken = 0;
    int[] outputTokens = new int[maxTokens];

    // Used for special character decoding:
    // whiteSpaceCharacters[n] is the byte represented by code point 256 + n,
    // encodedWhiteSpace[b] is the code point that represents byte b.
    int[] whiteSpaceCharacters = new int[256];
    int[] encodedWhiteSpace = new int[256];

    bool runInference = false;

    // Stop after this many generated tokens
    const int stopAfter = 100;

    // Number of tokens generated so far (the prompt is not counted)
    int totalTokens = 0;

    string[] merges;
    Dictionary<string, int> vocab;

    // One char per byte; looked up once instead of on every conversion.
    static readonly Encoding latin1 = Encoding.GetEncoding("ISO-8859-1");

    void Start()
    {
        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);

        // The byte <-> code point tables must exist before the prompt is tokenized.
        SetupWhiteSpaceShifts();

        LoadVocabulary();

        Model model = ModelLoader.Load(Application.streamingAssetsPath + "/tinystories.sentis");

        engine = WorkerFactory.CreateWorker(backend, model);

        DecodePrompt(outputString);

        runInference = true;
    }

    // Update is called once per frame
    void Update()
    {
        if (runInference)
        {
            RunInference();
        }
    }

    // Runs the model once on the current token window, samples the next token,
    // stores it in the window and appends its text to outputString.
    void RunInference()
    {
        using var tokensSoFar = new TensorInt(new TensorShape(1, maxTokens), outputTokens);
        engine.Execute(tokensSoFar);

        // Not wrapped in `using`: per the Sentis docs the peeked tensor stays owned by the worker.
        var tokensOut = engine.PeekOutput() as TensorFloat;
        if (tokensOut == null)
        {
            Debug.LogError("RunTinyStories: model output is not a TensorFloat, stopping inference.");
            runInference = false;
            return;
        }

        // Take the scores at the position of the latest token (axis 1), sharpen or flatten
        // them with `predictability`, and turn them into probabilities over axis 2.
        // NOTE(review): presumably the output is (1, maxTokens, vocabulary size) - confirm against the model.
        using var row = ops.Slice(tokensOut, new[] { currentToken }, new[] { currentToken + 1 }, new[] { 1 }, new[] { 1 });
        using var rowB = ops.Mul(predictability, row);
        using var probs = ops.Softmax(rowB, 2);
        probs.MakeReadable();

        int ID = SelectRandomToken(probs.ToReadOnlyArray());

        // Window is full: drop the oldest token so the new one fits in the last slot.
        if (currentToken >= maxTokens - 1)
        {
            for (int i = 0; i < maxTokens - 1; i++) outputTokens[i] = outputTokens[i + 1];
            currentToken--;
        }

        outputTokens[++currentToken] = ID;
        totalTokens++;

        if (ID == END_OF_TEXT || totalTokens >= stopAfter)
        {
            runInference = false;
        }
        else outputString += GetUnicodeText(tokens[ID]);

        Debug.Log(outputString);
    }

    // Tokenizes the prompt and writes its token ids into the start of outputTokens,
    // leaving currentToken on the last prompt token. (Despite the name, this encodes.)
    void DecodePrompt(string text)
    {
        var inputTokens = GetTokens(text);

        if (inputTokens.Count == 0)
        {
            // Nothing to condition on. Without a seed currentToken would become -1 and be
            // used as a slice index, so start from the end-of-text token instead.
            outputTokens[0] = END_OF_TEXT;
            currentToken = 0;
            return;
        }

        // The window holds at most maxTokens ids; keep the most recent part of a longer prompt
        // instead of writing past the end of the array.
        int skipped = inputTokens.Count > maxTokens ? inputTokens.Count - maxTokens : 0;
        if (skipped > 0)
        {
            Debug.LogWarning("RunTinyStories: prompt is longer than " + maxTokens + " tokens, the first " + skipped + " were dropped.");
        }

        int kept = inputTokens.Count - skipped;
        for (int i = 0; i < kept; i++)
        {
            outputTokens[i] = inputTokens[skipped + i];
        }
        currentToken = kept - 1;
    }

    // Reads vocab.json (token text -> id) and merges.txt from StreamingAssets
    // and builds the reverse id -> token text table.
    void LoadVocabulary()
    {
        var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
        vocab = Newtonsoft.Json.JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);
        tokens = new string[vocab.Count];
        foreach (var item in vocab)
        {
            tokens[item.Value] = item.Key;
        }

        merges = File.ReadAllLines(Application.streamingAssetsPath + "/merges.txt");
    }

    // Samples an index from the given probabilities: draws a random number and
    // returns the first index at which the running sum exceeds it.
    int SelectRandomToken(float[] probs)
    {
        float p = UnityEngine.Random.Range(0, 1f);
        float t = 0;
        for (int i = 0; i < probs.Length; i++)
        {
            t += probs[i];
            if (p < t)
            {
                return i;
            }
        }
        // The running sum never exceeded p; fall back to the last index.
        return probs.Length - 1;
    }

    // Translates encoded special characters to Unicode:
    // vocabulary text -> one byte per char -> UTF-8 decoded string.
    // Each token is decoded on its own.
    string GetUnicodeText(string text)
    {
        var bytes = latin1.GetBytes(ShiftCharacterDown(text));
        return Encoding.UTF8.GetString(bytes);
    }

    // Inverse of GetUnicodeText: UTF-8 bytes of the text, one char per byte,
    // with the shifted bytes replaced by their stand-in code points.
    string GetASCIIText(string newText)
    {
        var bytes = Encoding.UTF8.GetBytes(newText);
        return ShiftCharacterUp(latin1.GetString(bytes));
    }

    // Replaces stand-in code points (256 and above) by the byte they represent.
    string ShiftCharacterDown(string text)
    {
        var outText = new StringBuilder(text.Length);
        foreach (char letter in text)
        {
            // Stand-ins start at 256 (see SetupWhiteSpaceShifts), so 256 itself must be
            // translated too. Anything beyond the table is left as it is.
            int shifted = letter - 256;
            bool isEncoded = shifted >= 0 && shifted < whiteSpaceCharacters.Length;
            outText.Append(isEncoded ? (char)whiteSpaceCharacters[shifted] : letter);
        }
        return outText.ToString();
    }

    // Replaces every char (expected to be below 256) by the code point that represents it.
    string ShiftCharacterUp(string text)
    {
        var outText = new StringBuilder(text.Length);
        foreach (char letter in text)
        {
            outText.Append((char)encodedWhiteSpace[(int)letter]);
        }
        return outText.ToString();
    }

    // Fills the two lookup tables. Bytes for which IsWhiteSpace is true get the
    // stand-in code points 256, 257, ... in ascending byte order; every other
    // byte is represented by itself.
    void SetupWhiteSpaceShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            encodedWhiteSpace[i] = i;
            if (IsWhiteSpace((char)i))
            {
                encodedWhiteSpace[i] = n + 256;
                whiteSpaceCharacters[n++] = i;
            }
        }
    }

    // True for characters outside the three ranges '!'..'~', U+00A1..U+00AC and
    // U+00AE..U+00FF, i.e. for the characters that need a stand-in code point.
    // The non-ASCII bounds are written as escapes so they survive re-encoding of this file.
    bool IsWhiteSpace(char c)
    {
        return !(('!' <= c && c <= '~') || ('\u00A1' <= c && c <= '\u00AC') || ('\u00AE' <= c && c <= '\u00FF'));
    }

    // Converts text to token ids: split into single characters, apply the merge
    // rules, then look every resulting piece up in the vocabulary.
    // Pieces that are not in the vocabulary are skipped.
    List<int> GetTokens(string text)
    {
        text = GetASCIIText(text);

        // Start with a list of single characters
        var inputTokens = new List<string>();
        foreach (var letter in text)
        {
            inputTokens.Add(letter.ToString());
        }

        ApplyMerges(inputTokens);

        // Find the ids of the words in the vocab
        var ids = new List<int>();
        foreach (var token in inputTokens)
        {
            if (vocab.TryGetValue(token, out int id))
            {
                ids.Add(id);
            }
        }

        return ids;
    }

    // Applies the merge rules in file order. For each rule "A B", every A that is
    // directly followed by B is joined with it into the single piece "AB" (in place).
    void ApplyMerges(List<string> inputTokens)
    {
        foreach (var merge in merges)
        {
            string[] pair = merge.Split(' ');

            // A rule needs exactly two parts; skip blank or malformed lines.
            if (pair.Length != 2) continue;

            int n = 0;
            while ((n = inputTokens.IndexOf(pair[0], n)) != -1)
            {
                if (n < inputTokens.Count - 1 && inputTokens[n + 1] == pair[1])
                {
                    inputTokens[n] += inputTokens[n + 1];
                    inputTokens.RemoveAt(n + 1);
                }
                // Continue the search after this position.
                n++;
            }
        }
    }

    // Releases the worker, the ops object and the allocator created in Start().
    private void OnDestroy()
    {
        engine?.Dispose();
        ops?.Dispose();
        allocator?.Dispose();
    }
}
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tinystories.sentis ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d8ed28a03db24da6fa58cc2bde739ecfb83b731ca47c263d17cdfec22e4b1698
3
+ size 478881707
vocab.json ADDED
The diff for this file is too large to render. See raw diff