|
using System.Collections; |
|
using System.Collections.Generic; |
|
using UnityEngine; |
|
using Unity.Sentis; |
|
using System.IO; |
|
using Newtonsoft.Json; |
|
using System.Text; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// Transcribes <see cref="audioClip"/> using three Sentis models loaded from StreamingAssets
/// (log-mel spectrogram, audio encoder, token decoder). The audio is encoded once in Start();
/// afterwards one token is decoded per frame in Update() and the running text is logged.
/// </summary>
public class RunWhisper : MonoBehaviour
{
    IWorker decoderEngine, encoderEngine, spectroEngine;

    const BackendType backend = BackendType.GPUCompute;

    // The clip to transcribe. A 16 kHz clip is expected; multi-channel clips are averaged to mono.
    public AudioClip audioClip;

    // Length of the token buffer fed to the decoder, including the 4-token prompt set in Start().
    const int maxTokens = 100;

    // Special token ids used for the decoder prompt and to interpret its output.
    const int END_OF_TEXT = 50257;
    const int START_OF_TRANSCRIPT = 50258;
    const int ENGLISH = 50259;
    const int TRANSCRIBE = 50359;
    // First timestamp token; ids at or above it are rendered as "(time=...)" in 0.02 s steps.
    const int START_TIME = 50364;

    // Sample rate the pipeline expects, and the fixed 30-second input window derived from it.
    const int sampleRate = 16000;
    const int maxSamples = 30 * sampleRate;

    // Latin-1 maps each char 0..255 to the byte with the same value; used to turn a shifted
    // token string back into raw bytes. Cached instead of being looked up per token.
    static readonly Encoding latin1 = Encoding.GetEncoding("ISO-8859-1");

    // Stateful UTF-8 decoder shared across tokens: when a multi-byte character is split over
    // two tokens, its leading bytes are buffered until the rest arrive instead of becoming U+FFFD.
    readonly System.Text.Decoder utf8Decoder = Encoding.UTF8.GetDecoder();

    Ops ops;
    ITensorAllocator allocator;

    int numSamples;
    float[] data;
    string[] tokens;

    int currentToken = 0;
    int[] outputTokens = new int[maxTokens];

    // shiftDownDict[k] is the byte value represented by char (256 + k); filled by SetupCharacterShifts.
    int[] shiftDownDict = new int[256];

    TensorFloat encodedAudio;

    bool transcribe = false;
    string outputString = "";

    /// <summary>
    /// Creates the ops/allocator, loads the vocabulary and the three models, writes the decoder
    /// prompt, then loads and encodes the audio. Decoding is only enabled if encoding succeeded.
    /// </summary>
    void Start()
    {
        allocator = new TensorCachingAllocator();
        ops = WorkerFactory.CreateOps(backend, allocator);

        SetupCharacterShifts();
        GetTokens();

        Model decoder = ModelLoader.Load(Application.streamingAssetsPath + "/AudioDecoder_Tiny.sentis");
        Model encoder = ModelLoader.Load(Application.streamingAssetsPath + "/AudioEncoder_Tiny.sentis");
        Model spectro = ModelLoader.Load(Application.streamingAssetsPath + "/LogMelSepctro.sentis");

        decoderEngine = WorkerFactory.CreateWorker(backend, decoder);
        encoderEngine = WorkerFactory.CreateWorker(backend, encoder);
        spectroEngine = WorkerFactory.CreateWorker(backend, spectro);

        // Decoder prompt: start-of-transcript, language, task, first timestamp.
        outputTokens[0] = START_OF_TRANSCRIPT;
        outputTokens[1] = ENGLISH;
        outputTokens[2] = TRANSCRIBE;
        outputTokens[3] = START_TIME;
        currentToken = 3;

        // Only decode when there is encoded audio; otherwise Update would hand a null
        // encodedAudio tensor to the decoder on every frame.
        if (LoadAudio() && EncodeAudio())
        {
            transcribe = true;
        }
    }

    /// <summary>
    /// Copies the samples of <see cref="audioClip"/> into <see cref="data"/> as a single channel
    /// and sets <see cref="numSamples"/>.
    /// </summary>
    /// <returns>False if no clip is assigned or its sample data could not be read.</returns>
    bool LoadAudio()
    {
        if (audioClip == null)
        {
            Debug.LogError("No AudioClip assigned.");
            return false;
        }

        if (audioClip.frequency != sampleRate)
        {
            // No resampling is done here; the clip is passed on at its native rate.
            Debug.LogWarning($"The audio clip should have frequency 16kHz. It has frequency {audioClip.frequency / 1000f}kHz");
        }

        int channels = audioClip.channels;
        numSamples = audioClip.samples;

        // AudioClip.samples is the per-channel length while GetData writes interleaved channels,
        // so the buffer must hold samples * channels values.
        var interleaved = new float[numSamples * channels];
        if (!audioClip.GetData(interleaved, 0))
        {
            Debug.LogError("Could not read sample data from the AudioClip.");
            return false;
        }

        if (channels == 1)
        {
            data = interleaved;
            return true;
        }

        // Average the interleaved channels down to mono.
        data = new float[numSamples];
        for (int i = 0; i < numSamples; i++)
        {
            float sum = 0f;
            for (int c = 0; c < channels; c++)
            {
                sum += interleaved[i * channels + c];
            }
            data[i] = sum / channels;
        }
        return true;
    }

    /// <summary>
    /// Reads vocab.json (token text to id) from StreamingAssets into <see cref="tokens"/>,
    /// indexed by id.
    /// </summary>
    void GetTokens()
    {
        // NOTE(review): File.ReadAllText cannot read StreamingAssets on platforms that pack them
        // into an archive (e.g. Android); confirm the target platforms.
        var jsonText = File.ReadAllText(Application.streamingAssetsPath + "/vocab.json");
        var vocab = JsonConvert.DeserializeObject<Dictionary<string, int>>(jsonText);

        // Size by the largest id rather than the entry count so a vocabulary with gaps
        // cannot index past the end of the array.
        int maxId = -1;
        foreach (int id in vocab.Values)
        {
            if (id > maxId)
            {
                maxId = id;
            }
        }

        tokens = new string[maxId + 1];
        foreach (var item in vocab)
        {
            if (item.Value >= 0)
            {
                tokens[item.Value] = item.Key;
            }
        }
    }

    /// <summary>
    /// Pads <see cref="data"/> to the 30-second window, runs the spectrogram and encoder models,
    /// and stores the encoder output in <see cref="encodedAudio"/>.
    /// </summary>
    /// <returns>False if the clip is empty, longer than 30 seconds, or produced no encoder output.</returns>
    bool EncodeAudio()
    {
        if (numSamples == 0)
        {
            Debug.Log("The AudioClip is empty.");
            return false;
        }

        // Checked before any tensor is created so the early-out path allocates nothing.
        if (numSamples > maxSamples)
        {
            Debug.Log("The AudioClip is too long.");
            return false;
        }

        // The input and padded tensors are owned here and released once both models have run.
        // The PeekOutput results are held by their workers and are not disposed here.
        using (var input = new TensorFloat(new TensorShape(1, numSamples), data))
        using (var input30seconds = ops.Pad(input, new int[] { 0, 0, 0, maxSamples - numSamples }))
        {
            spectroEngine.Execute(input30seconds);
            var spectroOutput = spectroEngine.PeekOutput() as TensorFloat;

            encoderEngine.Execute(spectroOutput);
            encodedAudio = encoderEngine.PeekOutput() as TensorFloat;
        }

        return encodedAudio != null;
    }

    /// <summary>
    /// While transcribing, runs the decoder once per frame on all tokens so far, takes the
    /// arg-max prediction at the current position, appends it, and logs the accumulated text.
    /// </summary>
    void Update()
    {
        if (!transcribe || currentToken >= outputTokens.Length - 1)
        {
            return;
        }

        int id;
        // Both tensors created here are disposed every frame; previously they leaked each frame.
        using (var tokensSoFar = new TensorInt(new TensorShape(1, outputTokens.Length), outputTokens))
        {
            var inputs = new Dictionary<string, Tensor>
            {
                { "encoded_audio", encodedAudio },
                { "tokens", tokensSoFar }
            };

            decoderEngine.Execute(inputs);
            var tokensOut = decoderEngine.PeekOutput() as TensorFloat;

            using (var tokensPredictions = ops.ArgMax(tokensOut, 2, false))
            {
                // Make the predictions readable on the CPU so they can be indexed.
                tokensPredictions.MakeReadable();
                id = tokensPredictions[currentToken];
            }
        }

        currentToken++;
        outputTokens[currentToken] = id;

        if (id == END_OF_TEXT)
        {
            transcribe = false;
        }
        else if (id >= START_TIME)
        {
            outputString += $"(time={(id - START_TIME) * 0.02f})";
        }
        else if (id < tokens.Length && tokens[id] != null)
        {
            outputString += GetUnicodeText(tokens[id]);
        }
        // Any other id is a special token with no vocabulary text; nothing is appended for it.

        Debug.Log(outputString);
    }

    /// <summary>
    /// Converts a vocabulary token string to readable text: shifted chars are mapped back to
    /// bytes, the chars are taken as Latin-1 bytes, and those bytes are decoded as UTF-8.
    /// </summary>
    string GetUnicodeText(string text)
    {
        byte[] bytes = latin1.GetBytes(ShiftCharacterDown(text));

        // flush: false keeps an incomplete trailing UTF-8 sequence in the decoder so the next
        // token can complete it. GetCharCount does not change the decoder's state.
        var chars = new char[utf8Decoder.GetCharCount(bytes, 0, bytes.Length, false)];
        int count = utf8Decoder.GetChars(bytes, 0, bytes.Length, chars, 0, false);
        return new string(chars, 0, count);
    }

    /// <summary>
    /// Maps every char at or above 256 back to the byte value it stands for (via
    /// <see cref="shiftDownDict"/>); chars below 256 already equal their byte and are kept.
    /// </summary>
    string ShiftCharacterDown(string text)
    {
        var outText = new StringBuilder(text.Length);
        foreach (char letter in text)
        {
            int shifted = letter - 256;
            // Strictly below 256: char 256 is the first shifted char (byte shiftDownDict[0]).
            if (letter < 256 || shifted >= shiftDownDict.Length)
            {
                outText.Append(letter);
            }
            else
            {
                outText.Append((char)shiftDownDict[shifted]);
            }
        }
        return outText.ToString();
    }

    /// <summary>
    /// Lists, in increasing order, every byte value 0..255 rejected by <see cref="IsWhiteSpace"/>,
    /// so that shiftDownDict[k] is the byte represented by char (256 + k).
    /// </summary>
    void SetupCharacterShifts()
    {
        for (int i = 0, n = 0; i < 256; i++)
        {
            if (IsWhiteSpace((char)i)) shiftDownDict[n++] = i;
        }
    }

    /// <summary>
    /// True for byte values outside the ranges '!'..'~', U+00A1..U+00AC and U+00AE..U+00FF.
    /// These are the bytes that <see cref="ShiftCharacterDown"/> recovers from chars at 256 and above.
    /// </summary>
    bool IsWhiteSpace(char c)
    {
        // Unicode escapes keep the Latin-1 bounds intact regardless of the source file's encoding.
        return !(('!' <= c && c <= '~') || ('\u00A1' <= c && c <= '\u00AC') || ('\u00AE' <= c && c <= '\u00FF'));
    }

    /// <summary>
    /// Releases the workers, the ops object and the tensor allocator created in Start().
    /// </summary>
    private void OnDestroy()
    {
        decoderEngine?.Dispose();
        encoderEngine?.Dispose();
        spectroEngine?.Dispose();
        ops?.Dispose();
        // Created in Start and handed to ops; released last.
        allocator?.Dispose();
    }
}
|
|