metadata
datasets:
  - AdamCodd/Civitai-2m-prompts
metrics:
  - accuracy
  - f1
  - precision
  - recall
  - roc_auc
inference: true
base_model: distilroberta-base
model-index:
  - name: distilroberta-nsfw-prompt-stable-diffusion
    results:
      - task:
          type: text-classification
          name: Text Classification
        metrics:
          - type: loss
            value: 0.3103
          - type: accuracy
            value: 0.8642
            name: Accuracy
          - type: f1
            value: 0.8612
            name: F1
          - type: precision
            value: 0.8805
            name: Precision
          - type: recall
            value: 0.8427
            name: Recall
          - type: ROC_AUC
            value: 0.9408
            name: AUC
language:
  - en

DistilRoBERTa-nsfw-prompt-stable-diffusion

This model is a fine-tuned version of distilroberta-base for binary (SFW/NSFW) text classification, trained on the positive prompts of the AdamCodd/Civitai-2m-prompts dataset. It achieves the following results on the evaluation set:

  • Loss: 0.3103
  • Accuracy: 0.8642
  • F1: 0.8612
  • AUC: 0.9408
  • Precision: 0.8805
  • Recall: 0.8427

Model description

This model is designed to identify NSFW prompts for Stable Diffusion. It was trained on a dataset of ~2 million prompts evenly split between SFW and NSFW (1,043,475 samples of each class, 2,086,950 in total, ensuring a balanced dataset). Single-word prompts were excluded to improve the accuracy and relevance of the predictions.
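
The card does not say how single-word prompts were detected or how the classes were equalized. As a hypothetical illustration only, the sketch below assumes a "word" is a whitespace-separated token and balances the classes by random downsampling to the size of the smaller class; `is_single_word` and `balance` are made-up names, not part of the dataset or any library.

```python
import random


def is_single_word(prompt: str) -> bool:
    # HYPOTHETICAL rule: "single-word" means at most one whitespace-separated token.
    return len(prompt.split()) <= 1


def balance(rows, seed=0):
    """Drop single-word prompts, then randomly downsample every class to the
    size of the smallest one. `rows` is an iterable of (prompt, label) pairs.

    Illustrative only: the card does not describe the actual procedure.
    """
    by_label = {}
    for prompt, label in rows:
        if not is_single_word(prompt):
            by_label.setdefault(label, []).append((prompt, label))

    target = min(len(group) for group in by_label.values())
    rng = random.Random(seed)
    balanced = []
    for label in sorted(by_label):  # sorted for a deterministic order
        balanced.extend(rng.sample(by_label[label], target))
    rng.shuffle(balanced)
    return balanced


rows = [
    ("masterpiece, 1girl, yellow sundress", "SFW"),
    ("a castle on a hill at sunset", "SFW"),
    ("landscape", "SFW"),  # single word: dropped
    ("example nsfw prompt", "NSFW"),
]
print(balance(rows))  # one SFW and one NSFW prompt, in random order
```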

Although the model's accuracy is satisfactory (86.42% on the evaluation set), it is recommended to use it in conjunction with this image NSFW detector to improve overall detection and reduce false positives.

Usage

from transformers import pipeline

# Load the fine-tuned classifier from the Hugging Face Hub
prompt_detector = pipeline("text-classification", model="AdamCodd/distilroberta-nsfw-prompt-stable-diffusion")

# Classify a Stable Diffusion positive prompt; the result holds the top label and its score
predicted_class = prompt_detector("masterpiece, 1girl, yellow sundress, looking at viewer")
print(predicted_class)
# [{'label': 'SFW', 'score': 0.9983291029930115}]
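
By default the pipeline returns only the top label. For batch screening with an adjustable cut-off, here is a minimal sketch that assumes the model's two labels are named 'SFW' and 'NSFW' (only 'SFW' appears in the output above), so that P(NSFW) is one minus the SFW score. The helper names `nsfw_probability`, `flag_prompts` and `classify_prompts` are hypothetical.

```python
def nsfw_probability(prediction: dict) -> float:
    """P(NSFW) from a top-1 pipeline result such as {'label': 'SFW', 'score': 0.998}.

    ASSUMPTION: the model has exactly two labels, 'SFW' and 'NSFW', so their
    softmax scores sum to 1.
    """
    if prediction["label"] == "NSFW":
        return prediction["score"]
    return 1.0 - prediction["score"]


def flag_prompts(predictions, threshold: float = 0.5):
    """One boolean per prediction: True when P(NSFW) reaches the threshold."""
    return [nsfw_probability(p) >= threshold for p in predictions]


def classify_prompts(prompts, threshold: float = 0.5):
    """Classify a list of prompts and return one NSFW flag per prompt."""
    # Imported here so the helpers above can be used without transformers.
    from transformers import pipeline

    detector = pipeline(
        "text-classification",
        model="AdamCodd/distilroberta-nsfw-prompt-stable-diffusion",
    )
    # A list input yields one {'label', 'score'} dict per prompt;
    # truncation=True guards against prompts longer than the model's max length.
    predictions = detector(prompts, truncation=True)
    return flag_prompts(predictions, threshold)
```

For example, `classify_prompts(["masterpiece, 1girl, yellow sundress, looking at viewer"], threshold=0.8)` would download the model and return one boolean per prompt. Raising the threshold above 0.5 generally trades recall for precision, which is one way to cut down on false positives.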

Training and evaluation data

More information needed

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 3e-05
  • train_batch_size: 32
  • eval_batch_size: 32
  • seed: 42
  • optimizer: AdamW with betas=(0.9,0.999) and epsilon=1e-08
  • lr_scheduler_type: linear
  • lr_scheduler_warmup_steps: 150
  • num_epochs: 1
  • weight_decay: 0.01
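
With lr_scheduler_type set to linear and 150 warmup steps, the learning rate ramps from 0 to 3e-05 and then decays linearly to 0 at the last step. The function below is a stdlib restatement of that shape (mirroring the formula behind transformers' `get_linear_schedule_with_warmup`), not code from the training run. The total step count is an assumption: one epoch at batch size 32 on a single device without gradient accumulation gives ceil(N / 32) steps, where N, the number of training samples, is not stated in the card.

```python
import math

PEAK_LR = 3e-05      # learning_rate
WARMUP_STEPS = 150   # lr_scheduler_warmup_steps
BATCH_SIZE = 32      # train_batch_size


def steps_per_epoch(num_train_samples: int, batch_size: int = BATCH_SIZE) -> int:
    # ASSUMPTION: single device, no gradient accumulation, last partial batch kept.
    return math.ceil(num_train_samples / batch_size)


def linear_warmup_lr(step: int, total_steps: int,
                     peak_lr: float = PEAK_LR,
                     warmup_steps: int = WARMUP_STEPS) -> float:
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        # Ramp from 0 up to peak_lr over the warmup steps.
        return peak_lr * step / max(1, warmup_steps)
    # Decay from peak_lr (end of warmup) to 0 (final step).
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, remaining)


total = steps_per_epoch(1_669_560)  # hypothetical training-set size
for step in (0, 75, 150, total // 2, total):
    print(step, f"{linear_warmup_lr(step, total):.2e}")
```

If the training split held 1,669,560 prompts (a hypothetical figure: 80% of 2,086,950), one epoch would be 52,174 steps, so warmup would cover under 0.3% of training.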

Training results

Metrics: Accuracy, F1, Precision, Recall, AUC

{'eval_loss': 0.3103,
 'eval_accuracy': 0.8642,
 'eval_f1': 0.8612,
 'eval_precision': 0.8805,
 'eval_recall': 0.8427,
 'eval_roc_auc': 0.9408}
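
ROC AUC is the only metric here that cannot be recovered from the confusion matrix: it is computed from the model's continuous NSFW scores rather than from its hard predictions. The card does not show how it was computed; the stdlib sketch below implements the standard rank-based (Mann-Whitney U) definition, assuming label 1 = NSFW and that `scores` holds each prompt's NSFW probability. `roc_auc` is a hypothetical helper name; in practice a library routine such as scikit-learn's `roc_auc_score` does the same job.

```python
def roc_auc(labels, scores) -> float:
    """ROC AUC via the Mann-Whitney U statistic (ties receive average ranks).

    labels: 1 for NSFW (positive class, ASSUMED), 0 for SFW.
    scores: the model's NSFW probability for each prompt.
    """
    n = len(scores)
    order = sorted(range(n), key=lambda i: scores[i])
    ranks = [0.0] * n
    i = 0
    while i < n:
        # Find the run order[i..j] of tied scores and give it the average 1-based rank.
        j = i
        while j + 1 < n and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = average_rank
        i = j + 1

    n_pos = sum(1 for y in labels if y == 1)
    n_neg = n - n_pos
    rank_sum_pos = sum(r for r, y in zip(ranks, labels) if y == 1)
    u = rank_sum_pos - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

On that toy input, three of the four (NSFW, SFW) pairs are ranked correctly, hence 0.75.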

Confusion matrix:

[[184931  23859]
 [ 32820 175780]]
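
The confusion matrix reproduces the reported accuracy, precision, recall and F1 if its rows are the true labels and its columns the predictions (scikit-learn's convention), ordered SFW then NSFW, with NSFW as the positive class. That layout is inferred, not stated in the card; the sketch below recomputes the four metrics under that assumption. The matrix also sums to 417,390 prompts, exactly one fifth of the 2,086,950 balanced samples, which suggests an 80/20 train/evaluation split, although the card does not say so.

```python
# Confusion matrix from this card.
# ASSUMED layout: rows = true label, columns = predicted label,
# order [SFW, NSFW], with NSFW as the positive class.
CONFUSION = [[184931, 23859],
             [32820, 175780]]


def metrics_from_confusion(cm):
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        "total": total,
        "accuracy": (tp + tn) / total,
        "precision": precision,
        "recall": recall,
        "f1": 2 * precision * recall / (precision + recall),
    }


for name, value in metrics_from_confusion(CONFUSION).items():
    print(f"{name}: {value:.4f}" if isinstance(value, float) else f"{name}: {value}")
# total: 417390
# accuracy: 0.8642
# precision: 0.8805
# recall: 0.8427
# f1: 0.8612
```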

Framework versions

  • Transformers 4.36.2
  • Datasets 2.16.1
  • Tokenizers 0.15.0
  • Evaluate 0.4.1
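
To compare a local environment with the versions listed above, the standard-library `importlib.metadata` module can report installed package versions. A small sketch (`compare_versions` and `CARD_VERSIONS` are illustrative names): matching versions are the safest starting point when trying to reproduce the reported results, though newer releases may also work.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this card's "Framework versions" section.
CARD_VERSIONS = {
    "transformers": "4.36.2",
    "datasets": "2.16.1",
    "tokenizers": "0.15.0",
    "evaluate": "0.4.1",
}


def compare_versions(expected=CARD_VERSIONS):
    """Map each package to (installed version or None, version listed in the card)."""
    report = {}
    for package, card_version in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # package not installed in this environment
        report[package] = (installed, card_version)
    return report


if __name__ == "__main__":
    for package, (installed, card_version) in compare_versions().items():
        status = "match" if installed == card_version else "differs"
        print(f"{package}: installed={installed}, card={card_version} ({status})")
```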

If you want to support me, you can do so here.