Response Generator for Emotion Chat Bot

Model description

This model is a DPO fine-tuned version of hermeschen1116/response_generator_for_emotion_chat_bot, trained on Shotaro30678/rlhf-RG-trl-style-v3, the author's modified version of daily_dialog.

Intended uses & limitations

The model is trained with DPO (Direct Preference Optimization) as its RLHF step, with the goal of making its responses more precise and consistent.

Model performance

Sentiment score, measured with the emotion classifier Shotaro30678/emotion_text_classifier_on_dd_v1:

Metric      DPO Trained Model   SFT Model (Reference)
Accuracy    0.851               0.788
F1-score    0.8564              0.7975
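The card does not say exactly how these two scores are computed. One plausible reading, and the assumption behind this sketch, is that the classifier predicts an emotion for each generated reply and that prediction is compared with the emotion the reply was supposed to express. The F1 averaging method is not stated either, so the sketch assumes macro averaging. The function name and the labels in the usage line are illustrative only.

```python
def accuracy_and_macro_f1(targets, predictions):
    """Accuracy and macro-averaged F1 over paired emotion labels.

    Assumption: `targets` are the intended emotions and `predictions` are the
    classifier's labels for the generated replies.
    """
    if len(targets) != len(predictions) or not targets:
        raise ValueError("need two non-empty label lists of equal length")
    pairs = list(zip(targets, predictions))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)

    f1_scores = []
    # Macro F1: compute F1 per label, then take the unweighted mean
    for label in sorted(set(targets) | set(predictions)):
        tp = sum(t == label and p == label for t, p in pairs)
        fp = sum(t != label and p == label for t, p in pairs)
        fn = sum(t == label and p != label for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_scores.append(f1)
    return accuracy, sum(f1_scores) / len(f1_scores)
```

For example, targets ["joy", "neutral", "sadness", "neutral"] against predictions ["joy", "neutral", "neutral", "neutral"] give an accuracy of 0.75 and a macro F1 of about 0.6.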

Gibberish distribution, measured with the gibberish detector madhurjindal/autonlp-Gibberish-Detector-492513457:

Category          DPO Trained Model   SFT Model (Reference)
Clean             882                 898
Mild Gibberish    94                  58
Word Salad        21                  33
Noise             3                   11

Cut-off output:

Output Type          DPO Trained Model   SFT Model (Reference)
Complete Output      985                 975
Incomplete Output    15                  25

These results were measured on the hermeschen1116/daily_dialog_for_RG test split.
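Each column of the gibberish and cut-off tables sums to 1,000 responses, so the counts convert directly into rates. This short sketch does that arithmetic using only the numbers from the tables; the dictionary names are illustrative.

```python
# Counts copied from the gibberish and cut-off tables (each column sums to 1,000).
GIBBERISH = {
    "DPO": {"Clean": 882, "Mild Gibberish": 94, "Word Salad": 21, "Noise": 3},
    "SFT": {"Clean": 898, "Mild Gibberish": 58, "Word Salad": 33, "Noise": 11},
}
CUT_OFF = {
    "DPO": {"Complete": 985, "Incomplete": 15},
    "SFT": {"Complete": 975, "Incomplete": 25},
}


def to_percentages(counts):
    """Convert raw category counts into percentages of the column total."""
    total = sum(counts.values())
    return {category: 100.0 * n / total for category, n in counts.items()}


for table_name, table in (("Gibberish", GIBBERISH), ("Cut-off", CUT_OFF)):
    for model, counts in table.items():
        rates = ", ".join(f"{c} {p:.1f}%" for c, p in to_percentages(counts).items())
        print(f"{table_name} / {model}: {rates}")
```

In rate terms, the DPO model has slightly fewer Clean outputs (88.2% vs. 89.8%) but fewer Word Salad (2.1% vs. 3.3%) and Noise (0.3% vs. 1.1%) outputs, and fewer incomplete outputs (1.5% vs. 2.5%) than the SFT reference.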

Generation configuration used for the evaluation:

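  from transformers import GenerationConfig

  # `tokenizer` is the model's tokenizer (see the quick sample below)
  generation_config = GenerationConfig(
      max_new_tokens=150,       # hard cap on the length of each reply
      min_new_tokens=5,
      repetition_penalty=1.1,   # discourage tokens already in the sequence
      top_k=3,                  # sample only among the 3 most likely tokens
      top_p=0.9,                # ...further limited to 90% cumulative probability
      pad_token_id=tokenizer.pad_token_id,
      eos_token_id=tokenizer.eos_token_id,
      temperature=1.0,
      do_sample=True,
      num_beams=1
  )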
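To make the sampling settings concrete, here is a simplified, single-step re-implementation in plain Python of how repetition_penalty, temperature, top_k and top_p reshape the next-token distribution before a token is drawn. It follows the order in which, as far as we know, transformers applies these processors, but it is an illustration rather than the library's code; the function name and the logit values are made up.

```python
import math


def filter_next_token_probs(logits, generated_ids, repetition_penalty=1.1,
                            temperature=1.0, top_k=3, top_p=0.9):
    """Illustrative sketch: turn raw scores {token_id: logit} into the
    distribution that would be sampled from under the given settings."""
    scores = dict(logits)
    # Repetition penalty: weaken tokens that already appear in the sequence
    for tok in set(generated_ids):
        if tok in scores:
            s = scores[tok]
            scores[tok] = s * repetition_penalty if s < 0 else s / repetition_penalty
    # Temperature scaling (1.0 leaves the scores unchanged)
    scores = {t: s / temperature for t, s in scores.items()}
    # Top-k: keep only the k highest-scoring tokens
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:top_k]
    # Softmax over the surviving tokens (subtract the max for stability)
    m = max(s for _, s in ranked)
    exps = [(t, math.exp(s - m)) for t, s in ranked]
    z = sum(e for _, e in exps)
    probs = [(t, e / z) for t, e in exps]
    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for t, p in probs:
        kept.append((t, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalize what is left
    z = sum(p for _, p in kept)
    return {t: p / z for t, p in kept}
```

With logits {0: 5.0, 1: 4.0, 2: 1.0, 3: 0.5, 4: -1.0} and no previous tokens, top_k=3 keeps tokens 0 to 2, and top_p=0.9 then drops token 2, leaving roughly {0: 0.731, 1: 0.269}. With such a small top_k, sampling stays close to greedy decoding while still allowing some variation.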

Training procedure

Training hyperparameters

The following hyperparameters were used during training:

  • beta=0.1
  • remove_unused_columns=False
  • num_train_epochs=3
  • gradient_checkpointing=True

All other hyperparameters were left at their default values.
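For context on beta=0.1: in DPO, beta scales the log-probability ratios between the policy being trained and a frozen reference model (typically the starting SFT model). The sketch below computes the standard sigmoid form of the DPO loss for a single preference pair in plain Python. It is a simplified illustration rather than TRL's DPOTrainer code, and the function name and log-probability values are made up.

```python
import math


def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """Sigmoid DPO loss for one (chosen, rejected) response pair.

    Each argument is the summed log-probability of a full response under
    either the policy being trained or the frozen reference model.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)), computed stably as softplus(-margin)
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

When the policy still matches the reference, dpo_loss(-10, -10, -10, -10) is log 2 (about 0.693). If the policy raises the chosen response's log-probability by 5 and lowers the rejected one's by 5 relative to the reference, dpo_loss(-5, -15, -10, -10) drops to about 0.313. A larger beta makes the loss react more strongly to the same log-ratio gap.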

Framework versions

  • Bitsandbytes 0.43.1
  • Datasets 2.20.0
  • PEFT 0.11.1
  • PyTorch 2.3.0+cu121
  • Transformers 4.42.4
  • Tokenizers 0.19.1
  • TRL 0.8.6
  • unsloth 2024.7 0f2e484
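To reproduce results, it can help to compare your environment against the versions listed above. This is a small sketch using the standard library's importlib.metadata; the PyPI distribution names (for example torch for PyTorch) are assumptions of the sketch, and the unsloth commit hash is not checked.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed on this card, keyed by (assumed) PyPI distribution name.
PINNED = {
    "bitsandbytes": "0.43.1",
    "datasets": "2.20.0",
    "peft": "0.11.1",
    "torch": "2.3.0+cu121",
    "transformers": "4.42.4",
    "tokenizers": "0.19.1",
    "trl": "0.8.6",
    "unsloth": "2024.7",  # the card also lists commit 0f2e484, not checked here
}


def version_mismatches(pinned=PINNED):
    """Return {package: installed version or None} for every package whose
    installed version differs from the pinned one (None = not installed)."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = installed
    return mismatches


if __name__ == "__main__":
    for name, installed in version_mismatches().items():
        print(f"{name}: expected {PINNED[name]}, found {installed}")
```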

Uploaded model

  • Developed by: Shotaro30678
  • Fine-tuned from model: hermeschen1116/response_generator_for_emotion_chat_bot

This Llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.

Quick sample

  # `libs` comes from the project's GitHub repository
  from libs import ResponseGeneratorPipeline
  from transformers import GenerationConfig
  from unsloth import FastLanguageModel

  # Load the DPO-trained model in 4-bit precision
  model, tokenizer = FastLanguageModel.from_pretrained(
      model_name = "Shotaro30678/response_generator_DPO",
      load_in_4bit = True,
  )
  FastLanguageModel.for_inference(model) # Enable native 2x faster inference
  
  bot = ResponseGeneratorPipeline(
      model,
      tokenizer,
      framework="pt",
      task="conversation-generation",
      num_workers=16,
      torch_dtype="auto",
      add_special_tokens=True,
      truncation=False,
      padding=True
  )
  
  # Each turn holds an utterance ('dialog') and its emotion label ('emotion').
  # The final assistant turn is left empty: the pipeline generates its dialog,
  # and its 'emotion' is the target emotion for the reply.
  conversation = [
      {'content': {'dialog': '', 'emotion': ''}, 'role': 'system'},
      {'content': {'dialog': 'Can you do push-ups ?', 'emotion': 'neutral'},
       'role': 'user'},
      {'content': {'dialog': "Of course I can . It's a piece of cake ! Believe it or not , I can do 30 push-ups a minute .",
                   'emotion': 'neutral'},
       'role': 'assistant'},
      {'content': {'dialog': "Really ? I think that's impossible !",
                   'emotion': 'surprise'},
       'role': 'user'},
      {'content': {'dialog': 'You mean 30 push-ups ?', 'emotion': 'neutral'},
       'role': 'assistant'},
      {'content': {'dialog': 'Yeah !', 'emotion': 'neutral'}, 'role': 'user'},
      {'content': {'dialog': '', 'emotion': 'neutral'}, 'role': 'assistant'}
  ]
  
  generation_config = GenerationConfig(
      max_new_tokens=150,
      min_new_tokens=5,
      repetition_penalty=1.1,
      top_k=3,
      top_p=0.9,
      pad_token_id=tokenizer.pad_token_id,
      eos_token_id=tokenizer.eos_token_id,
      temperature=1.0,
      do_sample=True,
      num_beams=1
  )
  
  # Print only the generated reply: the dialog of the last (assistant) turn
  print(bot(conversation, generation_config=generation_config)[0]['generated_text'][-1]["content"]["dialog"])

Example output (sampling is enabled, so the reply varies between runs):

30 push-ups in a row? 
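The conversation list in the quick sample can also be built programmatically. The helper below is hypothetical and not part of the project's libs package. It assumes the history alternates user and assistant turns, starting and ending with the user as in the example, and it appends the empty assistant turn that carries the emotion for the reply.

```python
def build_conversation(turns, reply_emotion):
    """Hypothetical helper: convert (dialog, emotion) pairs into the message
    list used in the quick sample. `turns` alternates user/assistant,
    starting and ending with a user utterance."""
    if len(turns) % 2 == 0:
        raise ValueError("turns must start and end with a user utterance")
    # Empty system turn, as in the quick sample
    conversation = [{"content": {"dialog": "", "emotion": ""}, "role": "system"}]
    for i, (dialog, emotion) in enumerate(turns):
        role = "user" if i % 2 == 0 else "assistant"
        conversation.append({"content": {"dialog": dialog, "emotion": emotion}, "role": role})
    # Empty assistant turn: the model fills in its dialog
    conversation.append({"content": {"dialog": "", "emotion": reply_emotion}, "role": "assistant"})
    return conversation
```

Calling it with the five utterances from the quick sample and reply_emotion="neutral" reproduces the seven-entry conversation list shown there.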
Model size: 3.6B params (Safetensors; tensor types F32, BF16, U8)