// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Accelerate
import CoreImage
import Foundation
import os.log
import TensorFlowLite
import UIKit
/// This class handles all data preprocessing and makes calls to run inference on a given frame
/// by invoking the `Interpreter`. It then formats the inferences obtained.
class ModelDataHandler {

  // MARK: - Private Properties

  /// TensorFlow Lite `Interpreter` object for performing inference on a given model.
  private var interpreter: Interpreter

  /// TensorFlow Lite `Tensor`s holding the model's input and output.
  private var inputTensor: Tensor
  private var outputTensor: Tensor

  // MARK: - Initialization

  /// A failable initializer for `ModelDataHandler`. A new instance is created if the model is
  /// successfully loaded from the app's main bundle. Default `threadCount` is 2.
  ///
  /// - Parameters:
  ///   - threadCount: Number of CPU threads the interpreter may use.
  ///   - delegate: Hardware backend to accelerate inference (CPU, GPU via Metal, or NPU via
  ///     Core ML).
  /// - Throws: An error if the `Interpreter` cannot be created or its tensors allocated.
  init(
    threadCount: Int = Constants.defaultThreadCount,
    delegate: Delegates = Constants.defaultDelegate
  ) throws {
    // Construct the path to the model file. A missing model file is a packaging error, so
    // crashing with a clear message is preferable to continuing in a broken state.
    guard
      let modelPath = Bundle.main.path(
        forResource: Model.file.name,
        ofType: Model.file.extension
      )
    else {
      fatalError("Failed to load the model file with name: \(Model.file.name).")
    }

    // Specify the options for the `Interpreter`.
    var options = Interpreter.Options()
    options.threadCount = threadCount

    // Specify the delegates for the `Interpreter`. The switch is exhaustive (no `default:`)
    // so adding a new `Delegates` case becomes a compile-time reminder to handle it here.
    var delegates: [Delegate]?
    switch delegate {
    case .Metal:
      delegates = [MetalDelegate()]
    case .CoreML:
      // The Core ML delegate initializer is failable; fall back to CPU when unavailable.
      if let coreMLDelegate = CoreMLDelegate() {
        delegates = [coreMLDelegate]
      } else {
        delegates = nil
      }
    case .CPU:
      delegates = nil
    }

    // Create the `Interpreter`.
    interpreter = try Interpreter(modelPath: modelPath, options: options, delegates: delegates)

    // Allocate memory for the model's input `Tensor`s.
    try interpreter.allocateTensors()

    // Cache the allocated input and output `Tensor`s.
    inputTensor = try interpreter.input(at: 0)
    outputTensor = try interpreter.output(at: 0)
  }

  /// Runs the Midas model on the given image, cropping from a source area and sizing the
  /// result for a destination area.
  ///
  /// - Parameters:
  ///   - pixelbuffer: Input image to run the model on.
  ///   - source: Region of the input image to run the model on.
  ///   - dest: Size of the view that will render the result.
  /// - Returns: The flattened output values, the model input width and height, and the times
  ///   consumed by each processing step — or `nil` if any step fails.
  func runMidas(on pixelbuffer: CVPixelBuffer, from source: CGRect, to dest: CGSize)
    -> ([Float], Int, Int, Times)?
  {
    // Processing times in milliseconds.
    let preprocessingTime: TimeInterval
    let inferenceTime: TimeInterval
    let postprocessingTime: TimeInterval

    let preprocessingStartTime = Date()
    guard let data = preprocess(of: pixelbuffer, from: source) else {
      os_log("Preprocessing failed", type: .error)
      return nil
    }
    preprocessingTime = Date().timeIntervalSince(preprocessingStartTime) * 1000

    let inferenceStartTime = Date()
    inference(from: data)
    inferenceTime = Date().timeIntervalSince(inferenceStartTime) * 1000

    // No postprocessing step is currently performed; the timer brackets an empty step so
    // the reported `Times` stays structurally complete for the UI.
    let postprocessingStartTime = Date()
    postprocessingTime = Date().timeIntervalSince(postprocessingStartTime) * 1000

    // Convert the raw output bytes to `[Float]`, dequantizing when the model emits uInt8.
    let results: [Float]
    switch outputTensor.dataType {
    case .uInt8:
      guard let quantization = outputTensor.quantizationParameters else {
        print("No results returned because the quantization values for the output tensor are nil.")
        return nil
      }
      let quantizedResults = [UInt8](outputTensor.data)
      results = quantizedResults.map {
        quantization.scale * Float(Int($0) - quantization.zeroPoint)
      }
    case .float32:
      results = [Float32](unsafeData: outputTensor.data) ?? []
    default:
      print("Output tensor data type \(outputTensor.dataType) is unsupported for this example app.")
      return nil
    }

    let times = Times(
      preprocessing: preprocessingTime,
      inference: inferenceTime,
      postprocessing: postprocessingTime)
    return (results, Model.input.width, Model.input.height, times)
  }

  // MARK: - Private functions to run model

  /// Preprocesses the given rectangle of the image into `Data` of the model's input size by
  /// cropping and resizing it.
  ///
  /// - Parameters:
  ///   - pixelBuffer: Input image to crop and resize; must be 32BGRA.
  ///   - targetSquare: Area to be cropped and resized.
  /// - Returns: The cropped and resized image as RGB `Data`, or `nil` if it cannot be
  ///   processed.
  private func preprocess(of pixelBuffer: CVPixelBuffer, from targetSquare: CGRect) -> Data? {
    let sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer)
    assert(sourcePixelFormat == kCVPixelFormatType_32BGRA)

    // Resize `targetSquare` of the input image to the model's input size.
    let modelSize = CGSize(width: Model.input.width, height: Model.input.height)
    guard let thumbnail = pixelBuffer.resize(from: targetSquare, to: modelSize) else {
      return nil
    }

    // Remove the alpha component from the image buffer to get the RGB `Data`.
    guard
      let inputData = thumbnail.rgbData(
        isModelQuantized: Model.isQuantized
      )
    else {
      os_log("Failed to convert the image buffer to RGB data.", type: .error)
      return nil
    }
    return inputData
  }

  /// Runs inference with the given `Data` and refreshes `outputTensor` with the results.
  ///
  /// Failures are logged and swallowed; `runMidas` reads whatever `outputTensor` holds
  /// afterwards.
  ///
  /// - Parameter data: Preprocessed RGB `Data` of the input image.
  private func inference(from data: Data) {
    do {
      // Copy the initialized `Data` to the input `Tensor`.
      try interpreter.copy(data, toInputAt: 0)
      // Run inference by invoking the `Interpreter`.
      try interpreter.invoke()
      // Get the output `Tensor` to process the inference results.
      outputTensor = try interpreter.output(at: 0)
    } catch {
      os_log(
        "Failed to invoke the interpreter with error: %s", type: .error,
        error.localizedDescription)
    }
  }
}
// MARK: - Data types for inference result
/// A single detected keypoint: which body part it is, where it sits, and the model's
/// confidence in the detection.
struct KeyPoint {
  var bodyPart: BodyPart = .NOSE
  var position: CGPoint = .init()
  var score: Float = 0.0
}
/// A line segment between two keypoints, expressed in view coordinates.
struct Line {
  let from: CGPoint
  let to: CGPoint
}
/// Wall-clock durations, in milliseconds, for each stage of a single inference pass.
struct Times {
  var preprocessing: Double   // crop/resize/RGB conversion
  var inference: Double       // interpreter invocation
  var postprocessing: Double  // output formatting
}
/// Formatted inference output: keypoint dots, skeleton lines, and an overall confidence
/// score, all ready for rendering.
struct Result {
  var dots: [CGPoint]
  var lines: [Line]
  var score: Float
}
/// The 17 body parts a pose-estimation model can locate, with display-friendly raw values.
enum BodyPart: String, CaseIterable {
  case NOSE = "nose"
  case LEFT_EYE = "left eye"
  case RIGHT_EYE = "right eye"
  case LEFT_EAR = "left ear"
  case RIGHT_EAR = "right ear"
  case LEFT_SHOULDER = "left shoulder"
  case RIGHT_SHOULDER = "right shoulder"
  case LEFT_ELBOW = "left elbow"
  case RIGHT_ELBOW = "right elbow"
  case LEFT_WRIST = "left wrist"
  case RIGHT_WRIST = "right wrist"
  case LEFT_HIP = "left hip"
  case RIGHT_HIP = "right hip"
  case LEFT_KNEE = "left knee"
  case RIGHT_KNEE = "right knee"
  case LEFT_ANKLE = "left ankle"
  case RIGHT_ANKLE = "right ankle"

  /// Pairs of body parts connected by a skeleton line when the result is drawn.
  static let lines: [(from: BodyPart, to: BodyPart)] = [
    (from: .LEFT_WRIST, to: .LEFT_ELBOW),
    (from: .LEFT_ELBOW, to: .LEFT_SHOULDER),
    (from: .LEFT_SHOULDER, to: .RIGHT_SHOULDER),
    (from: .RIGHT_SHOULDER, to: .RIGHT_ELBOW),
    (from: .RIGHT_ELBOW, to: .RIGHT_WRIST),
    (from: .LEFT_SHOULDER, to: .LEFT_HIP),
    (from: .LEFT_HIP, to: .RIGHT_HIP),
    (from: .RIGHT_HIP, to: .RIGHT_SHOULDER),
    (from: .LEFT_HIP, to: .LEFT_KNEE),
    (from: .LEFT_KNEE, to: .LEFT_ANKLE),
    (from: .RIGHT_HIP, to: .RIGHT_KNEE),
    (from: .RIGHT_KNEE, to: .RIGHT_ANKLE),
  ]
}
// MARK: - Delegates Enum
/// Hardware backends available for running inference.
enum Delegates: Int, CaseIterable {
  case CPU
  case Metal
  case CoreML

  /// Human-readable label shown in the UI for each backend.
  var description: String {
    switch self {
    case .CPU: return "CPU"
    case .Metal: return "GPU"
    case .CoreML: return "NPU"
    }
  }
}
// MARK: - Custom Errors
/// Errors thrown while converting raw model output into a drawable `Result`.
enum PostprocessError: Error {
  /// A body part expected by the skeleton lines was not detected.
  case missingBodyPart(of: BodyPart)
}
// MARK: - Information about the model file.
/// Name and extension of a bundled resource file.
typealias FileInfo = (name: String, extension: String)

/// Static description of the bundled Midas model: resource location, tensor shapes, and
/// quantization.
enum Model {
  /// The `.tflite` model resource in the app's main bundle.
  static let file: FileInfo = (name: "model_opt", extension: "tflite")
  /// Input tensor layout: one 256x256 RGB image (NHWC).
  static let input = (batchSize: 1, height: 256, width: 256, channelSize: 3)
  /// Output tensor layout: one single-channel 256x256 map.
  static let output = (batchSize: 1, height: 256, width: 256, channelSize: 1)
  /// The model uses float32 values, not uInt8 quantization.
  static let isQuantized = false
}
extension Array {
  /// Creates a new array from the bytes of the given unsafe data.
  ///
  /// - Warning: The array's `Element` type must be trivial in that it can be copied bit for bit
  ///   with no indirection or reference-counting operations; otherwise, copying the raw bytes in
  ///   the `unsafeData`'s buffer to a new array returns an unsafe copy.
  /// - Note: Returns `nil` if `unsafeData.count` is not a multiple of
  ///   `MemoryLayout<Element>.stride`.
  /// - Parameter unsafeData: The data containing the bytes to turn into an array.
  init?(unsafeData: Data) {
    let elementStride = MemoryLayout<Element>.stride
    guard unsafeData.count % elementStride == 0 else { return nil }
    #if swift(>=5.0)
      self = unsafeData.withUnsafeBytes { .init($0.bindMemory(to: Element.self)) }
    #else
      // Pre-5.0 closures receive an UnsafePointer<Element>; build the buffer manually.
      self = unsafeData.withUnsafeBytes {
        .init(UnsafeBufferPointer<Element>(
          start: $0,
          count: unsafeData.count / elementStride
        ))
      }
    #endif  // swift(>=5.0)
  }
}