| import * as tf from '@tensorflow/tfjs-node'; |
| import * as nsfwjs from 'nsfwjs'; |
| import {NSFWJS} from 'nsfwjs'; |
| import {Tensor3D} from '@tensorflow/tfjs'; |
| import {model as config} from 'app/config/model'; |
|
|
// Switch TensorFlow.js to production mode once, at module load, before any
// classifier instance can be created. Per the TF.js API docs this disables
// internal correctness checks in favour of performance.
// NOTE(review): confirm that trade-off is also wanted outside production.
| tf.enableProdMode(); |
|
|
/**
 * Classifies images with an NSFWJS model loaded from `file://model/`.
 *
 * The model is loaded lazily on first use and then shared by every later
 * call made on the same instance, including calls that run concurrently.
 */
export class NsfwImageClassifier {
  /**
   * In-flight or completed model load. The promise (not the resolved model)
   * is cached so that concurrent callers share a single load.
   */
  #modelPromise?: Promise<NSFWJS>;

  /**
   * Decodes one encoded image and runs it through the model.
   *
   * @param imageBuffer Encoded image bytes, handed to `tf.node.decodeImage`.
   * @returns Probability per class, keyed by the lower-cased class name.
   * @throws Whatever model loading, decoding or inference throws/rejects
   *   with. The decoded tensor is released whether or not inference succeeds.
   */
  async classify(imageBuffer: Buffer): Promise<Record<string, number>> {
    // Resolve the model first so no tensor is allocated if loading fails.
    const model = await this.#getModel();

    // `expandAnimations = false` makes animated GIFs decode to their first
    // frame as a 3-D tensor instead of a 4-D stack of frames, which is what
    // the `Tensor3D` cast below relies on.
    const image = tf.node.decodeImage(imageBuffer, 3, 'int32', false) as Tensor3D;

    try {
      const predictions = await model.classify(image);

      return this.#transformData(predictions);
    } finally {
      // Tensors are not garbage collected; always free the native memory.
      image.dispose();
    }
  }

  /**
   * Classifies several images concurrently.
   *
   * @param imagesBuffers Encoded images; an empty array yields an empty result.
   * @returns One result per input, in input order.
   * @throws Rejects with the first failure among the individual classifications.
   */
  async classifyMany(imagesBuffers: readonly Buffer[]): Promise<Record<string, number>[]> {
    return Promise.all(imagesBuffers.map(buffer => this.classify(buffer)));
  }

  /**
   * Returns the shared model, starting the load on first use.
   *
   * A failed load is forgotten so that a later call can try again instead of
   * replaying the cached rejection forever.
   */
  #getModel(): Promise<NSFWJS> {
    const cached = this.#modelPromise;
    if (cached) {
      return cached;
    }

    const pending: Promise<NSFWJS> = nsfwjs.load('file://model/', {size: config.size});
    this.#modelPromise = pending;

    // The handler swallows the rejection only on this side branch; callers
    // still observe it through `pending` itself.
    pending.catch(() => {
      if (this.#modelPromise === pending) {
        this.#modelPromise = undefined;
      }
    });

    return pending;
  }

  /**
   * Flattens the model's prediction list into a `className -> probability`
   * map with lower-cased keys.
   */
  #transformData(data: { className: string; probability: number }[]): Record<string, number> {
    return Object.fromEntries(
      data.map(({className, probability}) => [className.toLowerCase(), probability]),
    );
  }
}
|
|