metadata
datasets:
  - jeanong2/AITA-datasets
language:
  - en
library_name: tf-keras
license: apache-2.0

Model description

Fine-tuned bert-base-uncased model for AITA ("Am I the Asshole?", after the Reddit community of that name) classification tasks. The idea for this AITA classifier came from a suggestion by my friend, Venessa Tan, for our project in module CS5246 during the second semester of AY23/24 at the National University of Singapore. I had the opportunity to build and fine-tune this model end to end. I am thankful to my other group members, Ming Xuan and Hui Khang, who supported the project in valuable ways by scraping data and providing feedback. Find our main project here.

Intended uses & limitations

The model currently performs less well on shorter input sequences, and there are many edge cases it does not handle well. We hope this project inspires more developers to continue advancing this work and to foster greater ethical awareness in AI development.

Training and evaluation data

This model was trained on train.csv and evaluated on test.csv.

Prediction scores:

  • Precision: 0.8123
  • Recall: 1.0000
  • F1 Score: 0.8965
  • Computed Accuracy: 0.9615
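As a quick consistency check on the scores above, F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The sketch below recomputes it from the two rounded values listed on this card; the function name `f1_from_precision_recall` is illustrative only and not part of the project code.

```python
def f1_from_precision_recall(precision: float, recall: float) -> float:
    """Return the F1 score, i.e. the harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0  # avoid division by zero when both scores are zero
    return 2 * precision * recall / (precision + recall)


# Rounded precision and recall as reported on this card
f1 = f1_from_precision_recall(0.8123, 1.0000)
print(f"F1 = {f1:.4f}")  # F1 = 0.8964
```

The result, about 0.8964, agrees with the reported 0.8965 once rounding of the precision value is taken into account.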

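Example run

The following script rebuilds the custom model class, loads the fine-tuned model from the Hugging Face Hub, and prints the predicted class probabilities for a sample post.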

import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from transformers import TFAutoModel, AutoTokenizer


class BERTForClassification(tf.keras.Model):
    """BERT encoder followed by a softmax classification head."""

    def __init__(self, bert_model, num_classes):
        super().__init__()
        self.bert = bert_model
        self.fc = tf.keras.layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        # Index 1 of the BERT output is the pooled [CLS] representation
        x = self.bert(inputs)[1]
        return self.fc(x)


# Register the custom class name so Keras can resolve it when loading the model
bert_model = TFAutoModel.from_pretrained("bert-base-uncased")
custom_objects = {
    'BERTForClassification': BERTForClassification(bert_model, num_classes=2)
}

model = from_pretrained_keras("jeanong2/finetuned-bert-aita-classifier", custom_objects=custom_objects)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


# Inference: tokenize one post and print the softmax probabilities for each class
def inference_analysis(model, text):
    # Pad or truncate to BERT's 512-token limit; longer posts are cut off
    encoding = tokenizer(text, padding='max_length', truncation=True, max_length=512, return_tensors="tf")
    inputs = {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
    }
    if 'token_type_ids' in encoding:
        inputs['token_type_ids'] = encoding['token_type_ids']
    test_dataset = tf.data.Dataset.from_tensor_slices(inputs).batch(1)
    predictions = model.predict(test_dataset)
    print("Probabilities for classes 0 and 1:")
    print(predictions)
    return predictions

text = '''AITA for making out with this dude's ex in front of him? | I play rugby with the guy in question (let's say, "Mark") but I don't usually hang out with him outside of matches and practice. 
He broke up with a woman (Jia) that I'm rather attracted to last week. For the sake of propriety, I had no real intention to make moves or anything.
But last night I'm at a bar, and both Mark and Jia are there. I was at a table with some friends, and he was a couple tables over. 
Jia's with her friends as well but after a time comes over to my table and sits next to me, starts chatting me up. 
We flirt, and eventually she leans in and kisses me, and I reciprocate. I tend to think that PDA of that kind is a bit trashy so after a few seconds I get up with her and we go outside, 
but I can see that Mark has been watching the entire time. He makes a rude comment to both of us as we pass.
Today at practice he picked a fight with me that would have come to blows if the other guys on the team hadn't held him back. 
He's steaming mad. I feel a little sorry for him, but at the moment I can't actually bring myself to feel bad about hooking up with Jia, or the fact that he was there for it. AITA here?'''

inference_analysis(model, text)
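The array that `model.predict` produces inside `inference_analysis` has one row per input post and one column per class (two here, since `num_classes=2`). To reduce each row to a single predicted class index, take its argmax. The helper below is a hypothetical addition, not part of the original card, and it deliberately does not assume which index corresponds to which AITA verdict, because the card does not state the label mapping.

```python
from typing import List, Sequence, Tuple


def predicted_classes(probabilities: Sequence[Sequence[float]]) -> List[Tuple[int, float]]:
    """For each row of class probabilities, return (argmax class index, its probability).

    Accepts nested lists or a 2-D NumPy array such as the output of model.predict.
    """
    results = []
    for row in probabilities:
        row = [float(p) for p in row]
        best = max(range(len(row)), key=lambda i: row[i])  # index of the largest probability
        results.append((best, row[best]))
    return results


# Made-up probabilities for two posts, for illustration only (not model output)
example = [[0.12, 0.88], [0.70, 0.30]]
print(predicted_classes(example))  # [(1, 0.88), (0, 0.7)]
```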

Training hyperparameters

The following hyperparameters were used during training:

Hyperparameter           Value
name                     Adam
weight_decay             None
clipnorm                 None
global_clipnorm          None
clipvalue                None
use_ema                  False
ema_momentum             0.99
ema_overwrite_frequency  None
jit_compile              True
is_legacy_optimizer      False
learning_rate            9.999999747378752e-06
beta_1                   0.9
beta_2                   0.999
epsilon                  1e-07
amsgrad                  False
training_precision       float32
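The learning_rate value 9.999999747378752e-06 in the table above is 1e-5 as it looks after rounding to single precision (float32), which cannot represent 1e-5 exactly; presumably the optimizer stored its hyperparameters as float32 values. The stdlib-only sketch below round-trips 1e-5 through a 32-bit float with `struct` to show this; `as_float32` is an illustrative helper name.

```python
import struct


def as_float32(value: float) -> float:
    """Round a Python float (float64) to the nearest float32, then widen it back."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


lr = as_float32(1e-5)
print(repr(lr))  # 9.999999747378752e-06
```

In other words, the model was trained with a learning rate of 1e-5.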