// sentis-blaze-face / RunBlazeFace.cs
// Author: Paul Bird
using UnityEngine;
using Unity.Sentis;
using UnityEngine.Video;
using UnityEngine.UI;
using Lays = Unity.Sentis.Layers;
/*
* Blaze Face Inference
* ====================
*
* Basic inference script for blaze face
*
* Put this script on the Main Camera
* Put blazeface.sentis in the Assets/StreamingAssets folder
* Create a RawImage of size 320x320 in the scene
* Put a link to that image in _previewUI
* Put a video in Assets/StreamingAssets folder and put its name in videoName
* Or put a test image in _inputImage
* Set _inputType to appropriate input
*/
public class RunBlazeFace : MonoBehaviour
{
//Drag a link to a raw image here:
public RawImage _previewUI = null;
// Put your bounding box sprite image here
public Sprite faceTexture;
// 6 optional sprite images (left eye, right eye, nose, mouth, left ear, right ear)
public Sprite[] markerTextures;
// Video file name, looked up under Application.streamingAssetsPath when _inputType == Video
const string videoName = "chatting.mp4";
//
public Texture2D _inputImage;
public InputType _inputType = InputType.Video;
// Working resolution: every input source is blitted into a RenderTexture of this size
Vector2Int _resolution = new Vector2Int(640, 640);
WebCamTexture _webcam;
VideoPlayer _video;
const BackendType backend = BackendType.GPUCompute;
// Target that Update() blits the current frame into; consumed by RunInference in LateUpdate
RenderTexture _targetTexture;
public enum InputType { Image, Video, Webcam };
//Some adjustable parameters for the model
[SerializeField, Range(0, 1)] float _iouThreshold = 0.5f;
[SerializeField, Range(0, 1)] float _scoreThreshold = 0.5f;
// Upper bound on boxes kept by the NonMaxSuppression layer added in SetupModel
int _maxOutputBoxes = 64;
IWorker _worker;
//Holds image size
// (read from the model's input shape in SetupModel; assumed square since it is used for both width and height)
int _size;
// Ops/allocator pair used for the GPU tensor math done outside the worker (Mad, GatherND)
Ops ops;
ITensorAllocator allocator;
Model _model;
//webcam device name:
// (empty string lets Unity pick the default device)
const string _deviceName = "";
// Set when Escape is pressed; suppresses further inference in LateUpdate while quitting
bool closing = false;
// Axis-aligned box in preview-UI local pixel coordinates (center + size)
public struct BoundingBox
{
public float centerX;
public float centerY;
public float width;
public float height;
}
// Creates the shared allocator/render target, then wires up input, model graph, and worker.
void Start()
{
allocator = new TensorCachingAllocator();
//(Note: if using a webcam on mobile get permissions here first)
_targetTexture = new RenderTexture(_resolution.x, _resolution.y, 0);
SetupInput();
SetupModel();
SetupEngine();
}
// Starts the selected input source. Image input is blitted once here;
// webcam/video frames are re-blitted every frame in Update().
void SetupInput()
{
switch (_inputType)
{
case InputType.Webcam:
{
_webcam = new WebCamTexture(_deviceName, _resolution.x, _resolution.y);
_webcam.requestedFPS = 30;
_webcam.Play();
break;
}
case InputType.Video:
{
_video = gameObject.AddComponent<VideoPlayer>();//new VideoPlayer();
// APIOnly: frames go to _video.texture rather than a camera/render target
_video.renderMode = VideoRenderMode.APIOnly;
_video.source = VideoSource.Url;
_video.url = Application.streamingAssetsPath + "/"+videoName;
_video.isLooping = true;
_video.Play();
break;
}
default:
{
// InputType.Image: static texture, a single blit is enough
Graphics.Blit(_inputImage, _targetTexture);
}
break;
}
}
// Copies the current frame into _targetTexture (letterboxed to preserve aspect ratio)
// and handles the Escape (quit) and P (toggle preview) hotkeys.
void Update()
{
if (_inputType == InputType.Webcam)
{
// Format video input
if (!_webcam.didUpdateThisFrame) return;
var aspect1 = (float)_webcam.width / _webcam.height;
var aspect2 = (float)_resolution.x / _resolution.y;
// Horizontal scale factor that fits the source aspect into the target aspect
var gap = aspect2 / aspect1;
var vflip = _webcam.videoVerticallyMirrored;
var scale = new Vector2(gap, vflip ? -1 : 1);
var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);
Graphics.Blit(_webcam, _targetTexture, scale, offset);
}
if (_inputType == InputType.Video)
{
// Same aspect-correcting blit as the webcam path, but video frames are never mirrored
var aspect1 = (float)_video.width / _video.height;
var aspect2 = (float)_resolution.x / _resolution.y;
var gap = aspect2 / aspect1;
var vflip = false;
var scale = new Vector2(gap, vflip ? -1 : 1);
var offset = new Vector2((1 - gap) / 2, vflip ? 1 : 0);
Graphics.Blit(_video.texture, _targetTexture, scale, offset);
}
if (_inputType == InputType.Image)
{
Graphics.Blit(_inputImage, _targetTexture);
}
if (Input.GetKeyDown(KeyCode.Escape))
{
closing = true;
Application.Quit();
}
if (Input.GetKeyDown(KeyCode.P))
{
_previewUI.enabled = !_previewUI.enabled;
}
}
// Runs inference after Update() has refreshed _targetTexture for this frame.
void LateUpdate()
{
if (!closing)
{
RunInference(_targetTexture);
}
}
//Calculate the centers of the grid squares for two 16x16 grids and six 8x8 grids
// Returns 896 anchors x 4 floats (x offset, y offset, 0, 0): the first 512 entries
// come from two anchors per 16x16 cell, the remaining 384 from six anchors per 8x8 cell.
// Offsets are centered on the grid (-7.5..7.5 and -7..7); the last two floats of each
// anchor stay at their default 0 and are unused by the box math in SetupModel.
float[] GetGridBoxCoords()
{
var offsets = new float[896 * 4];
int n = 0;
for (int j = 0; j < 2 * 16 * 16; j++)
{
offsets[n++] = (j / 2) % 16 - 7.5f;
offsets[n++] = (j / 2 / 16) - 7.5f;
n += 2;
}
for (int j = 0; j < 6 * 8 * 8; j++)
{
offsets[n++] = (j / 6) % 8 - 7f;
offsets[n++] = (j / 6 / 8) - 7f;
n += 2;
}
return offsets;
}
// Loads blazeface.sentis and appends post-processing layers to the graph:
// slice box coords from the raw regressors, scale them into grid units, add the
// anchor offsets, and run NonMaxSuppression against the per-anchor scores.
void SetupModel()
{
float[] offsets = GetGridBoxCoords();
_model = ModelLoader.Load(Application.streamingAssetsPath + "/blazeface.sentis");
//We need to add extra layers to the model in order to aggregate the box predicions:
_size = _model.inputs[0].shape.ToTensorShape()[1]; // Input tensor width
_model.AddConstant(new Lays.Constant("zero", new int[] { 0 }));
_model.AddConstant(new Lays.Constant("two", new int[] { 2 }));
_model.AddConstant(new Lays.Constant("four", new int[] { 4 }));
// First 4 channels of "regressors" along axis 2 are the raw box (cx, cy, w, h)
_model.AddLayer(new Lays.Slice("boxes1", "regressors", "zero", "four", "two"));
// NMS wants scores shaped (batch, classes, boxes), so swap the last two axes
_model.AddLayer(new Lays.Transpose("scores", "classificators", new int[] { 0, 2, 1 }));
// 1/8 converts raw box values into grid-cell units — presumably the model's
// pixel stride relative to the 16x16 grid; TODO confirm against the model card
_model.AddConstant(new Lays.Constant("eighth", new float[] { 1 / 8f }));
_model.AddConstant(new Lays.Constant("offsets",
new TensorFloat(new TensorShape(1, 896, 4), offsets)
));
_model.AddLayer(new Lays.Mul("boxes1scaled", "boxes1", "eighth"));
// Anchor offsets shift each box from cell-relative to grid coordinates
_model.AddLayer(new Lays.Add("boxCoords", "boxes1scaled", "offsets"));
_model.AddOutput("boxCoords");
_model.AddConstant(new Lays.Constant("maxOutputBoxes", new int[] { _maxOutputBoxes }));
_model.AddConstant(new Lays.Constant("iouThreshold", new float[] { _iouThreshold }));
_model.AddConstant(new Lays.Constant("scoreThreshold", new float[] { _scoreThreshold }));
// CenterPointBox.Center matches the (cx, cy, w, h) layout produced above
_model.AddLayer(new Lays.NonMaxSuppression("NMS", "boxCoords", "scores",
"maxOutputBoxes", "iouThreshold", "scoreThreshold",
centerPointBox: Lays.CenterPointBox.Center
));
_model.AddOutput("NMS");
}
// Creates the inference worker and the Ops instance used for pre/post-processing.
public void SetupEngine()
{
_worker = WorkerFactory.CreateWorker(backend, _model);
ops = WorkerFactory.CreateOps(backend, allocator);
}
// Spawns UI boxes for each detected face plus its 6 facial landmarks.
// index3: gathered box coords (cx, cy, w, h) in grid units; regressors: gathered
// raw regressor rows whose channels 4..15 hold the landmark (x, y) pairs.
// scale converts grid units to preview-UI pixels.
void DrawFaces(TensorFloat index3, TensorFloat regressors, int NMAX, Vector2 scale)
{
for (int n = 0; n < NMAX; n++)
{
//Draw bounding box of face
var box = new BoundingBox
{
centerX = index3[0, n, 0] * scale.x,
centerY = index3[0, n, 1] * scale.y,
width = index3[0, n, 2] * scale.x,
height = index3[0, n, 3] * scale.y
};
DrawBox(box, faceTexture);
if (regressors == null) continue;
//Draw markers for eyes, ears, nose, mouth:
for (int j = 0; j < 6; j++)
{
// Landmarks are stored relative to the raw box center (channels 0,1),
// in the same 1/8-stride units used in SetupModel
var marker = new BoundingBox
{
centerX = box.centerX + (regressors[0, n, 4 + j * 2] - regressors[0, n, 0]) * scale.x / 8,
centerY = box.centerY + (regressors[0, n, 4 + j * 2 + 1] - regressors[0, n, 1]) * scale.y / 8,
width = 1.0f * scale.x,
height = 1.0f * scale.y,
};
// Fall back to the face sprite if fewer than 6 marker sprites were assigned
DrawBox(marker, j < markerTextures.Length ? markerTextures[j] : faceTexture);
}
}
}
// Full per-frame pipeline: texture -> tensor -> worker -> gather NMS survivors -> draw.
void ExecuteML(Texture source)
{
var transform = new TextureTransform();
transform.SetDimensions(_size, _size, 3);
// Layout (0,3,1,2) = NCHW, matching the model's expected input
transform.SetTensorLayout(0, 3, 1, 2);
using var image0 = TextureConverter.ToTensor(source, transform);
// Pre-process the image to make input in range (-1..1)
using var image = ops.Mad(image0, 2f, -1f);
_worker.Execute(image);
using var boxCoords = _worker.PeekOutput("boxCoords") as TensorFloat; //face coords
using var regressors = _worker.PeekOutput("regressors") as TensorFloat; //contains markers
// NMS output rows are index triples into the box/score tensors
// (assumed (numSelected, 3) per ONNX NonMaxSuppression — TODO confirm)
var NM1 = _worker.PeekOutput("NMS") as TensorInt;
// Prepend a batch axis so GatherND can consume the NMS index triples directly
using var boxCoords2 = boxCoords.ShallowReshape(
new TensorShape(1, boxCoords.shape[0], boxCoords.shape[1], boxCoords.shape[2])) as TensorFloat;
var output = ops.GatherND(boxCoords2, NM1, 0);
using var regressors2 = regressors.ShallowReshape(
new TensorShape(1, regressors.shape[0], regressors.shape[1], regressors.shape[2])) as TensorFloat;
var markersOutput = ops.GatherND(regressors2, NM1, 0);
// Bring GPU tensors back to the CPU before indexing them in DrawFaces
output.MakeReadable();
markersOutput.MakeReadable();
ClearAnnotations();
// Grid is 16 cells across, so UI-size/16 maps grid units to UI pixels
Vector2 markerScale = _previewUI.rectTransform.rect.size/ 16;
DrawFaces(output, markersOutput, output.shape[0], markerScale);
}
// Runs detection on the frame and shows the same frame in the preview UI.
void RunInference(Texture input)
{
// Face detection
ExecuteML(input);
_previewUI.texture = input;
}
// Instantiates a sliced UI Image under the preview as a box/marker annotation.
// Y is negated because UI local space grows upward while box coords grow downward.
public void DrawBox(BoundingBox box, Sprite sprite)
{
var panel = new GameObject("ObjectBox");
panel.AddComponent<CanvasRenderer>();
panel.AddComponent<Image>();
panel.transform.SetParent(_previewUI.transform, false);
var img = panel.GetComponent<Image>();
img.color = Color.white;
img.sprite = sprite;
img.type = Image.Type.Sliced;
panel.transform.localPosition = new Vector3(box.centerX, -box.centerY);
RectTransform rt = panel.GetComponent<RectTransform>();
rt.sizeDelta = new Vector2(box.width, box.height);
}
// Removes all annotation children spawned by DrawBox from the preview UI.
public void ClearAnnotations()
{
foreach (Transform child in _previewUI.transform)
{
Destroy(child.gameObject);
}
}
// Disposes GPU/tensor resources and input devices; safe to call once on teardown.
void CleanUp()
{
ops?.Dispose();
allocator?.Dispose();
if (_webcam) Destroy(_webcam);
if (_video) Destroy(_video);
RenderTexture.active = null;
_targetTexture.Release();
_worker?.Dispose();
_worker = null;
}
void OnDestroy()
{
CleanUp();
}
}