Self-Play Fine-Tuning of Diffusion Models for Text-to-Image Generation (https://huggingface.co/papers/2402.10210)

SPIN-Diffusion-iter3

This model is the iteration-3 checkpoint of self-play fine-tuning, starting from runwayml/stable-diffusion-v1-5 and trained on synthetic data based on the winner images of the yuvalkirstain/pickapic_v2 dataset. A Gradio demo is available at UCLA-AGI/SPIN-Diffusion-demo-v1.

Model Details

Model Description

  • Model type: A diffusion model with a fine-tuned UNet, based on the Stable Diffusion 1.5 architecture
  • Language(s) (NLP): Primarily English
  • License: Apache-2.0
  • Finetuned from model: runwayml/stable-diffusion-v1-5

Training hyperparameters

The following hyperparameters were used during training:

  • learning_rate: 2.0e-05
  • train_batch_size: 8
  • distributed_type: multi-GPU
  • num_devices: 8
  • train_gradient_accumulation_steps: 32
  • total_train_batch_size: 2048
  • optimizer: AdamW
  • lr_scheduler: "linear"
  • lr_warmup_steps: 200
  • num_training_steps: 500
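
The batch-size figures above are mutually consistent: 8 samples per device, times 8 devices, times 32 gradient-accumulation steps gives 2048. The sketch below checks that arithmetic and illustrates what the learning rate might look like over the 500 steps. It assumes that "linear" means linear warmup followed by linear decay to zero, as in the scheduler of that name in common training scripts; the model card does not spell this out, and the helper names `total_batch_size` and `linear_lr` are hypothetical.

```python
# Values copied from the hyperparameter list above.
LEARNING_RATE = 2.0e-05
TRAIN_BATCH_SIZE = 8      # per device
NUM_DEVICES = 8
GRAD_ACCUM_STEPS = 32
WARMUP_STEPS = 200
TRAINING_STEPS = 500


def total_batch_size(per_device: int, devices: int, accum: int) -> int:
    """Effective number of samples contributing to one optimizer step."""
    return per_device * devices * accum


def linear_lr(step: int,
              base_lr: float = LEARNING_RATE,
              warmup: int = WARMUP_STEPS,
              total: int = TRAINING_STEPS) -> float:
    """Assumed schedule: linear warmup to base_lr, then linear decay to 0."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    return base_lr * max(0.0, (total - step) / max(1, total - warmup))


if __name__ == "__main__":
    print(total_batch_size(TRAIN_BATCH_SIZE, NUM_DEVICES, GRAD_ACCUM_STEPS))
    for step in (0, 100, 200, 350, 500):
        print(step, linear_lr(step))
```

Under this assumption the learning rate peaks at 2.0e-05 exactly when warmup ends (step 200) and reaches zero at step 500.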

Usage

To use the model, you must first load the SD1.5 base model and then substitute its unet with our fine-tuned version.

from diffusers import StableDiffusionPipeline, UNet2DConditionModel
import torch

# Load the SD1.5 base pipeline in half precision.
model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# Load the fine-tuned UNet and swap it into the pipeline.
unet_id = "UCLA-AGI/SPIN-Diffusion-iter3"
unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float16)
pipe.unet = unet

# The rest of your generation code goes here.
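
For completeness, here is one way the generation step could look. This is a minimal sketch, not the authors' evaluation setup: the function name `generate_image`, the seed, the step count of 50, the guidance scale of 7.5, and the example prompt are all illustrative assumptions that do not come from the model card. It also assumes a CUDA device is available when `device="cuda"`.

```python
def generate_image(prompt, seed=0, num_inference_steps=50,
                   guidance_scale=7.5, device="cuda"):
    """Load SD1.5, swap in the SPIN-Diffusion UNet, and sample one image.

    All default values are illustrative assumptions, not settings
    reported in the model card.
    """
    # Imported lazily so that defining this helper needs no heavy dependencies.
    import torch
    from diffusers import StableDiffusionPipeline, UNet2DConditionModel

    # Half precision is only a sensible default on GPU.
    dtype = torch.float16 if device.startswith("cuda") else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=dtype
    )
    pipe.unet = UNet2DConditionModel.from_pretrained(
        "UCLA-AGI/SPIN-Diffusion-iter3", subfolder="unet", torch_dtype=dtype
    )
    pipe = pipe.to(device)

    # A seeded generator makes the sample reproducible on the same hardware.
    generator = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    )
    return result.images[0]  # a PIL.Image


# Example call (downloads both checkpoints on first use):
# image = generate_image("a photo of an astronaut riding a horse")
# image.save("spin_diffusion_sample.png")
```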

Evaluation Results on Pick-a-pic test set

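| Metric         | Best of Five | Mean  | Median |
|----------------|--------------|-------|--------|
| HPS            | 0.28         | 0.27  | 0.27   |
| Aesthetic      | 6.26         | 5.94  | 5.98   |
| Image Reward   | 1.13         | 0.53  | 0.67   |
| Pickapic Score | 22.00        | 21.36 | 21.46  |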
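
The card does not define how the three columns are aggregated. One plausible reading, stated here purely as an assumption, is that five images are generated per prompt, each is scored, and the per-prompt best, mean, and median are then averaged over all prompts. The sketch below implements that reading on made-up toy scores; `aggregate_scores` is a hypothetical helper and the numbers are not evaluation results.

```python
from statistics import mean, median


def aggregate_scores(scores_per_prompt):
    """Average the per-prompt best, mean, and median scores over prompts.

    scores_per_prompt: list of lists, one inner list of scores per prompt
    (assumed to hold five scores each, one per generated image).
    """
    return {
        "best_of_five": mean(max(s) for s in scores_per_prompt),
        "mean": mean(mean(s) for s in scores_per_prompt),
        "median": mean(median(s) for s in scores_per_prompt),
    }


# Toy scores for two prompts, invented for illustration only.
toy_scores = [
    [1.0, 2.0, 3.0, 4.0, 5.0],
    [2.0, 2.0, 2.0, 2.0, 7.0],
]

if __name__ == "__main__":
    print(aggregate_scores(toy_scores))
```

Under this reading the best-of-five column can never fall below the other two columns, which matches the ordering in the table.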

Citation

@misc{yuan2024self,
      title={Self-Play Fine-Tuning of Diffusion Models for Text-to-Image Generation}, 
      author={Yuan, Huizhuo and Chen, Zixiang and Ji, Kaixuan and Gu, Quanquan},
      year={2024},
      eprint={2402.10210},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}