
Model Card for MokkaChat

MokkaChat is a fine-tuned T5-based model built for humorous responses.

Model Details

Model Description

MokkaChat is a simple model built for humorous chats.

  • Developed by: Sri Soundararajan
  • Model type: Text2Text Conditional Generation
  • Language(s) (NLP): English
  • License: MIT
  • Finetuned from model: T5-Base

Uses

This model can be used for casual, humorous question answering. Here is an example notebook showing how to run inference with this model:

https://colab.research.google.com/drive/1Z8bJtiNjmk-d3au_3pdjq-ALB76KYHlr

How to Get Started with the Model

Use the code below to get started with the model.

import warnings
import json
import torch
import evaluate  # Bleu
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
warnings.filterwarnings("ignore")

Q_LEN = 100  # maximum tokenized input length

MODEL = AutoModelForSeq2SeqLM.from_pretrained(
    "ssounda1/MokkaChat", return_dict=True)
TOKENIZER = AutoTokenizer.from_pretrained("ssounda1/MokkaChat")
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)


def get_answer(context, question, ref_answer=None):
    inputs = TOKENIZER(question, context, max_length=Q_LEN,
                       padding="max_length", truncation=True, add_special_tokens=True)

    input_ids = torch.tensor(
        inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(
        inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)

    # temperature only takes effect when sampling is enabled
    outputs = MODEL.generate(
        input_ids=input_ids, attention_mask=attention_mask,
        do_sample=True, temperature=0.9)

    predicted_answer = TOKENIZER.decode(
        outputs.flatten(), skip_special_tokens=True)

    if ref_answer:
        # Load the Bleu metric
        bleu = evaluate.load("google_bleu")
        score = bleu.compute(predictions=[predicted_answer],
                             references=[ref_answer])

        return {
            "Question": question,
            "Context": context,
            "Reference Answer": ref_answer,
            "Predicted Answer": predicted_answer,
            "BLEU Score": score
        }
    else:
        return predicted_answer

context = "Keep calm and say ..."
question = "Do you know the answer to this question?"
answer = "Ahaan!"

answer_resp = get_answer(context, question, answer)
print(json.dumps(answer_resp, indent=4))
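
Alternatively, here is a minimal sketch using the high-level transformers pipeline API. It assumes the question and context can be passed as a single concatenated string, mirroring how the tokenizer pairs them in the code above; adjust the input format if the outputs look off.

from transformers import pipeline

# Minimal usage sketch via the text2text-generation pipeline.
chatbot = pipeline("text2text-generation", model="ssounda1/MokkaChat")

question = "Do you know the answer to this question?"
context = "Keep calm and say ..."

# temperature requires do_sample=True, as in the generate() call above
result = chatbot(f"{question} {context}", max_length=100,
                 do_sample=True, temperature=0.9)
print(result[0]["generated_text"])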

Training Details

Training Data

T5-Base was fine-tuned on the SQuAD v2 dataset augmented with a custom Mokka chat dataset. Here are the links to the datasets:

https://huggingface.co/datasets/ssounda1/mokka-chat-ds-v1
https://huggingface.co/datasets/squad_v2
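
The augmentation step itself is not published with the model; the following is a hedged sketch of how the two datasets could be combined with the datasets library. The "train" split names and the assumption that the two schemas share columns are illustrative; check the actual schemas on the Hub.

from datasets import load_dataset, concatenate_datasets

# Load both datasets (split names assumed to be "train").
squad = load_dataset("squad_v2", split="train")
mokka = load_dataset("ssounda1/mokka-chat-ds-v1", split="train")

# Keep only the columns the two datasets share, then concatenate and shuffle.
common_cols = [c for c in squad.column_names if c in mokka.column_names]
combined = concatenate_datasets([
    squad.select_columns(common_cols),
    mokka.select_columns(common_cols),
]).shuffle(seed=42)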

Training Procedure

Training Hyperparameters

  • Training regime: fp32
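
The full training script is not included in this card; below is a minimal sketch of equivalent fp32 settings using Seq2SeqTrainingArguments. Only the epoch count (20, matching the logs below) comes from this card; the batch size and learning rate are illustrative assumptions.

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="mokka-chat-t5-base",
    num_train_epochs=20,              # matches the 20 epochs logged below
    per_device_train_batch_size=8,    # assumption
    learning_rate=3e-4,               # common T5 fine-tuning value; assumption
    evaluation_strategy="epoch",      # evaluate once per epoch
    fp16=False,                       # fp32 training regime
)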

Evaluation

Per-epoch training and validation loss over 20 epochs:

Epoch | Train loss | Validation loss
1/20 | 0.8245184440580875 | 0.4026999438791832
2/20 | 0.703028231633494 | 0.30366039834675435
3/20 | 0.6249609817720345 | 0.24144947223853383
4/20 | 0.5657204371531265 | 0.19916585764708916
5/20 | 0.518096115625194 | 0.16852003234101076
6/20 | 0.47824101336522334 | 0.14573621848088278
7/20 | 0.4446890475844722 | 0.1282667571046452
8/20 | 0.4158546521539049 | 0.11418618139097068
9/20 | 0.39071896244012094 | 0.10286468480848737
10/20 | 0.3685988230877622 | 0.09348667512682264
11/20 | 0.3489853145691834 | 0.0856158411675543
12/20 | 0.3313692257589271 | 0.07894140510740721
13/20 | 0.3154840102660389 | 0.07324570708649529
14/20 | 0.3010822039016147 | 0.06825826695942235
15/20 | 0.28787958101105554 | 0.06392730204562044
16/20 | 0.27582068473036314 | 0.06014419615740111
17/20 | 0.2647442796077156 | 0.0567684230703388
18/20 | 0.25449865650574116 | 0.053749261090770835
19/20 | 0.24506365559240695 | 0.051029498609284206
20/20 | 0.23624430357763543 | 0.04856409976122556
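
To visualize the curves, the logged values can be plotted with a few lines of matplotlib (values rounded to four decimals for brevity):

import matplotlib.pyplot as plt

epochs = list(range(1, 21))
train_loss = [0.8245, 0.7030, 0.6250, 0.5657, 0.5181, 0.4782, 0.4447,
              0.4159, 0.3907, 0.3686, 0.3490, 0.3314, 0.3155, 0.3011,
              0.2879, 0.2758, 0.2647, 0.2545, 0.2451, 0.2362]
val_loss = [0.4027, 0.3037, 0.2414, 0.1992, 0.1685, 0.1457, 0.1283,
            0.1142, 0.1029, 0.0935, 0.0856, 0.0789, 0.0732, 0.0683,
            0.0639, 0.0601, 0.0568, 0.0537, 0.0510, 0.0486]

# Plot both curves on one axis to show convergence over the 20 epochs.
plt.plot(epochs, train_loss, label="Train loss")
plt.plot(epochs, val_loss, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()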

References

Thanks to this article, which helped me build and train this model: https://medium.com/@ajazturki10/simplifying-language-understanding-a-beginners-guide-to-question-answering-with-t5-and-pytorch-253e0d6aac54

Model Card Contact

Sri Soundararajan ssounda1.work@gmail.com

