Model Card for MokkaChat

MokkaChat is a T5-based model fine-tuned to give humorous responses.

Model Details

Model Description

MokkaChat is a simple text-to-text model built for humorous chats.

  • Developed by: Sri Soundararajan
  • Model type: Text2Text Conditional Generation
  • Language(s) (NLP): English
  • License: MIT
  • Finetuned from model: T5-Base

Uses

The model can be loaded and run with the standard Hugging Face transformers seq2seq classes. The following example notebook shows how to run inference with it:

https://colab.research.google.com/drive/1Z8bJtiNjmk-d3au_3pdjq-ALB76KYHlr

How to Get Started with the Model

Use the code below to get started with the model.

import warnings
import json
import torch
import evaluate  # Bleu
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
warnings.filterwarnings("ignore")

Q_LEN = 100  # maximum length, in tokens, of the encoded question + context

MODEL = AutoModelForSeq2SeqLM.from_pretrained(
    "ssounda1/MokkaChat", return_dict=True)
TOKENIZER = AutoTokenizer.from_pretrained("ssounda1/MokkaChat")
DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu")
MODEL = MODEL.to(DEVICE)


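def get_answer(context, question, ref_answer=None):
    """Generate an answer to `question` given `context`.

    If `ref_answer` is provided, also score the prediction against it
    with Google BLEU and return everything as a dict.
    """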
    inputs = TOKENIZER(question, context, max_length=Q_LEN,
                       padding="max_length", truncation=True, add_special_tokens=True)

    input_ids = torch.tensor(
        inputs["input_ids"], dtype=torch.long).to(DEVICE).unsqueeze(0)
    attention_mask = torch.tensor(
        inputs["attention_mask"], dtype=torch.long).to(DEVICE).unsqueeze(0)

    # temperature only takes effect when sampling is enabled
    outputs = MODEL.generate(
        input_ids=input_ids, attention_mask=attention_mask,
        do_sample=True, temperature=0.9)

    predicted_answer = TOKENIZER.decode(
        outputs.flatten(), skip_special_tokens=True)

    if ref_answer:
        # Load the Bleu metric
        bleu = evaluate.load("google_bleu")
        score = bleu.compute(predictions=[predicted_answer],
                             references=[ref_answer])

        return {
            "Question: ": question,
            "Context: ": context,
            "Reference Answer: ": ref_answer,
            "Predicted Answer: ": predicted_answer,
            "BLEU Score: ": score
        }
    else:
        return predicted_answer

context = "Keep calm and say ..."
question = "Do you know the answer to this question?"
answer = "Ahaan!"

answer_resp = get_answer(context, question, answer)
print(json.dumps(answer_resp, indent=4))
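
The "BLEU Score" reported above comes from the `google_bleu` metric (GLEU). As a rough illustration of what that number means, here is a minimal pure-Python sketch of sentence-level GLEU: the minimum of n-gram precision and n-gram recall over 1- to 4-grams. The names `ngram_counts` and `simple_gleu` are made up for this sketch, and it assumes plain whitespace tokenization, so its scores may differ slightly from those of the `evaluate` implementation.

```python
from collections import Counter


def ngram_counts(tokens, min_len=1, max_len=4):
    """Count every n-gram of length min_len..max_len in a token list."""
    counts = Counter()
    for n in range(min_len, max_len + 1):
        for i in range(len(tokens) - n + 1):
            counts[tuple(tokens[i:i + n])] += 1
    return counts


def simple_gleu(prediction, reference):
    """Sentence-level GLEU: min(n-gram precision, n-gram recall)."""
    # Assumption: whitespace tokenization (the real metric uses its own tokenizer).
    pred = ngram_counts(prediction.split())
    ref = ngram_counts(reference.split())
    overlap = sum((pred & ref).values())  # clipped n-gram matches
    pred_total = sum(pred.values())
    ref_total = sum(ref.values())
    if pred_total == 0 or ref_total == 0:
        return 0.0
    return min(overlap / pred_total, overlap / ref_total)


print(simple_gleu("Ahaan!", "Ahaan!"))            # 1.0
print(simple_gleu("the cat", "the cat sat"))      # 0.5
```

Because the score is bounded by recall as well as precision, a short prediction that only partly covers the reference is penalized even when every word it contains is correct.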

Training Details

Training Data

T5-Base was fine-tuned on the SQuAD v2 dataset augmented with a custom Mokka chat dataset. The datasets are available here:

  • https://huggingface.co/datasets/ssounda1/mokka-chat-ds-v1
  • https://huggingface.co/datasets/squad_v2
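
To make "augmenting" concrete, the following is a hedged sketch of one way the two sources could be merged into a single list of (context, question, answer) rows. It is not the author's actual preprocessing. The SQuAD-style record uses the public SQuAD v2 field names, but its content is made up for illustration; the layout of the chat rows and the choice to skip unanswerable SQuAD questions are assumptions, and `flatten_squad` and `merge_datasets` are hypothetical helper names.

```python
def flatten_squad(record):
    """Turn one SQuAD-v2-style record into a flat row.

    Returns None for unanswerable questions (empty answer list).
    Skipping them is an assumption of this sketch.
    """
    texts = record["answers"]["text"]
    if not texts:
        return None
    return {"context": record["context"],
            "question": record["question"],
            "answer": texts[0]}


def merge_datasets(squad_records, chat_records):
    """Flatten the SQuAD records and append the custom chat rows."""
    merged = [row for row in map(flatten_squad, squad_records) if row]
    merged.extend(chat_records)
    return merged


# Illustrative records only (not taken from either dataset).
squad_records = [
    {"context": "Paris is the capital of France.",
     "question": "What is the capital of France?",
     "answers": {"text": ["Paris"], "answer_start": [0]}},
    {"context": "Paris is the capital of France.",
     "question": "Who founded Paris?",
     "answers": {"text": [], "answer_start": []}},  # unanswerable
]
# Assumed row layout for the chat data, reusing the example from this card.
chat_records = [
    {"context": "Keep calm and say ...",
     "question": "Do you know the answer to this question?",
     "answer": "Ahaan!"},
]

rows = merge_datasets(squad_records, chat_records)
print(len(rows))  # 2
```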

Training Procedure

Training Hyperparameters

  • Training regime: fp32

Evaluation

Train and validation loss after each of the 20 training epochs:

1/20 -> Train loss: 0.8245184440580875	Validation loss: 0.4026999438791832
2/20 -> Train loss: 0.703028231633494	Validation loss: 0.30366039834675435
3/20 -> Train loss: 0.6249609817720345	Validation loss: 0.24144947223853383
4/20 -> Train loss: 0.5657204371531265	Validation loss: 0.19916585764708916
5/20 -> Train loss: 0.518096115625194	Validation loss: 0.16852003234101076
6/20 -> Train loss: 0.47824101336522334	Validation loss: 0.14573621848088278
7/20 -> Train loss: 0.4446890475844722	Validation loss: 0.1282667571046452
8/20 -> Train loss: 0.4158546521539049	Validation loss: 0.11418618139097068
9/20 -> Train loss: 0.39071896244012094	Validation loss: 0.10286468480848737
10/20 -> Train loss: 0.3685988230877622	Validation loss: 0.09348667512682264
11/20 -> Train loss: 0.3489853145691834	Validation loss: 0.0856158411675543
12/20 -> Train loss: 0.3313692257589271	Validation loss: 0.07894140510740721
13/20 -> Train loss: 0.3154840102660389	Validation loss: 0.07324570708649529
14/20 -> Train loss: 0.3010822039016147	Validation loss: 0.06825826695942235
15/20 -> Train loss: 0.28787958101105554	Validation loss: 0.06392730204562044
16/20 -> Train loss: 0.27582068473036314	Validation loss: 0.06014419615740111
17/20 -> Train loss: 0.2647442796077156	Validation loss: 0.0567684230703388
18/20 -> Train loss: 0.25449865650574116	Validation loss: 0.053749261090770835
19/20 -> Train loss: 0.24506365559240695	Validation loss: 0.051029498609284206
20/20 -> Train loss: 0.23624430357763543	Validation loss: 0.04856409976122556
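
For readers who want to summarize a log like this, here is a small sketch that parses lines in the format above and reports how much each loss fell between the first and last entries. It assumes each line corresponds to one epoch; `parse_log` and `relative_drop` are names invented for this example, and only the first and last log lines are included to keep it short.

```python
import re

# Matches lines such as "1/20 -> Train loss: 0.82<TAB>Validation loss: 0.40"
LOG_LINE = re.compile(
    r"(\d+)/(\d+) -> Train loss: ([\d.]+)\s+Validation loss: ([\d.]+)")


def parse_log(text):
    """Return a list of (epoch, train_loss, val_loss) tuples."""
    rows = []
    for line in text.splitlines():
        match = LOG_LINE.search(line)
        if match:
            rows.append((int(match.group(1)),
                         float(match.group(3)),
                         float(match.group(4))))
    return rows


def relative_drop(first, last):
    """Fraction by which a value fell from `first` to `last`."""
    return (first - last) / first


log = """\
1/20 -> Train loss: 0.8245184440580875\tValidation loss: 0.4026999438791832
20/20 -> Train loss: 0.23624430357763543\tValidation loss: 0.04856409976122556
"""

rows = parse_log(log)
(_, train_first, val_first), (_, train_last, val_last) = rows[0], rows[-1]
print(f"train loss drop: {relative_drop(train_first, train_last):.1%}")
print(f"validation loss drop: {relative_drop(val_first, val_last):.1%}")
```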

References

Thanks to this article, which helped me build and train this model: https://medium.com/@ajazturki10/simplifying-language-understanding-a-beginners-guide-to-question-answering-with-t5-and-pytorch-253e0d6aac54

Model Card Contact

Sri Soundararajan ssounda1.work@gmail.com
