import datetime import json import logging import os import re import datasets import dateutil.parser logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) # %% # Now, create message groups ('conversations') # The idea is to group messages that are close in time # We'll use a 180 minute threshold MINUTES_THRESHOLD = 180 MIN_MESSAGES_THRESHOLD = 5 def group_messages(messages_iterable): groups = [] current_group = [next(messages_iterable)] for message in messages_iterable: assert len(current_group) > 0 # We should never have an empty group if ( message["timestamp"] - current_group[-1]["timestamp"] < MINUTES_THRESHOLD * 60 ): current_group.append(message) else: groups.append(current_group) current_group = [message] groups.append(current_group) return groups def printable_conversation(conversation): return "\n".join( [f"{message['contact_name']}: {message['message']}" for message in conversation] ) import contextualSpellCheck # %% # Use spacy to spell check the messages import spacy from spellchecker import SpellChecker spell = SpellChecker() # nlp = spacy.load("es_core_news_sm") nlp = spacy.load("en_core_web_sm") def spell_check_conversation(conversation): for i, message in enumerate(conversation["conversations"]): # Use SpaCy to get the words words = spell.split_words(message["message"]) logger.info(f"Words: {words}") corrected_message = [] for word in words: correction = spell.correction(word) if (correction != None) and (correction != word): logger.info(f"Spell check: {word} -> {correction}") corrected_message.append(correction) else: corrected_message.append(word) logger.info(f"Corrected message: {corrected_message}") joined_message = " ".join(corrected_message) conversation["conversations"][i]["message"] = joined_message return conversation def spell_check_conversation_spacy(conversation): nlp.add_pipe( "contextual spellchecker", config={ "model_name": "bert-base-multilingual-uncased", "max_edit_dist": 2, }, ) docs = list(nlp.pipe([msg["message"] for msg in conversation["conversations"]])) for i, doc in enumerate(docs): if doc._.performed_spellCheck: logger.info(f"Spell checked: {doc.text} -> {doc._.outcome_spellCheck}") conversation["conversations"][i]["message"] = doc._.outcome_spellCheck return conversation def remove_whatapp_annotations(conversation): """ Removes the following annotations from the messages: - """ for message in conversation["conversations"]: message["message"] = re.sub( r"", "", message["message"] ) return conversation # %% """ Sometimes, people write concurrently in the same conversation. We'll try to detect that and reorder the messages. For example, if we have a conversation like this: A: Hi A: How are you? B: Hi B: I'm fine, thanks A: I'm fine too We'll reorder it to: A: Hi B: Hi A: How are you? B: I'm fine, thanks A: I'm fine too To do it, we'll use MobileBERT with the next sentence prediction head. We'll use the first message as the first sentence, and the second message as the second sentence. If the model predicts that the second sentence is more likely to be the next sentence, we'll swap the messages. """ import torch from transformers import AutoModelForNextSentencePrediction, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") model = AutoModelForNextSentencePrediction.from_pretrained("bert-base-uncased") if torch.cuda.is_available(): model.cuda() def swap_messages_if_needed(message1, message2): # If the messages have the same contact, we don't swap them if message1["contact_name"] == message2["contact_name"]: return message1, message2 # The timestamp must have a difference of less than 2 minutes. First, convert to datetime datetime1 = datetime.datetime.fromtimestamp(message1["timestamp"]) datetime2 = datetime.datetime.fromtimestamp(message2["timestamp"]) if (datetime2 - datetime1).total_seconds() > 2 * 60: return message1, message2 # If one of the messages has less than 3 words, we don't swap them if len(message1["message"].split()) < 3 or len(message2["message"].split()) < 3: return message1, message2 # We'll use the first message as the first sentence, and the second message as the second sentence inputs = tokenizer(message1["message"], message2["message"], return_tensors="pt") reverse_inputs = tokenizer( message2["message"], message1["message"], return_tensors="pt" ) # Join them in a single batch joined_inputs = torch.cat([inputs["input_ids"], reverse_inputs["input_ids"]], dim=0) if torch.cuda.is_available(): joined_inputs = joined_inputs.cuda() with torch.no_grad(): outputs = model(input_ids=joined_inputs) # The output is a tuple with the logits for each class (next sentence or not) # We'll take the first one (next sentence) logits = outputs[0] # Apply softmax logits = torch.softmax(logits, dim=1) # We have two probabilities: the probability of 1 -> 2, and the probability of 2 -> 1 # We'll take the difference swap = logits[0, 0] - logits[1, 0] < -0.2 if swap: # Swap the messages logger.info( f"Swapping messages: {message1['message']} <-> {message2['message']}" ) return message2, message1 else: # logger.info(f"NOT swapping messages: {message1['message']} <-> {message2['message']}") return message1, message2 def swap_messages_if_needed_in_conversation(conversation): # We'll use the first message as the first sentence, and the second message as the second sentence if len(conversation) <= 2: return conversation new_conversation = [ conversation[0], conversation[1], ] # We'll always keep the first message in the same position for i in range(2, len(conversation)): message1 = new_conversation[-1] message2 = conversation[i] message1, message2 = swap_messages_if_needed(message1, message2) new_conversation[-1] = message1 new_conversation.append(message2) # logger.info(f"\nOriginal conversation:\n{printable_conversation(conversation)}") # logger.info(f"\nNew conversation:\n{printable_conversation(new_conversation)}") return new_conversation test_conversation = [ {"message": "Hola!", "contact_name": "A", "timestamp": 1}, { "message": "Está todo bien, gracias por preguntar!", "contact_name": "B", "timestamp": 2, }, { "message": "Hola, qué tal estás? Espero que vaya todo bien por España.", "contact_name": "A", "timestamp": 3, }, ] # logger.info(swap_messages_if_needed_in_conversation(test_conversation)) # %% # Now, we'll train an mT5 model to generate the next message in a conversation import os # %% def process_chat_file(file, do_spelling_correction, whatsapp_name, datetime_dayfirst, message_line_format, do_reordering=False): """ Process a chat file and return a dataset with the conversations. """ exp = re.compile( # r"(?P.+?) - (?P.+): (?P.+)" # r"\[?(?P\S+,\s\S+?(?:\s[APap][Mm])?)\]? (?:- )?(?P.+): (?P.+)" message_line_format ) def process_line(example): # The lines have this format: dd/mm/yy, hh:mm - : try: groups = exp.match(example["text"]).groupdict() timestamp = dateutil.parser.parse(groups['msg_datetime'], dayfirst=datetime_dayfirst).timestamp() return { "message": groups["message"], "contact_name": groups["contact_name"], "timestamp": timestamp, } except Exception as e: logger.exception(example["text"]) raise e ds = ( datasets.load_dataset("text", data_files=[file])["train"] .filter( # Has to begin by date, time, contact name, and contain at least a ':' symbol lambda x: re.match( r"^\d{1,2}/\d{1,2}/\d{1,2},\s\d{2}:\d{2}\s-\s.+:", x["text"] ) ) .map(process_line, remove_columns=["text"]) ) # Filter out messages that just say '' ds = ds.filter(lambda x: x["message"] != "") groups = group_messages(iter(ds)) # Generate the dataset conversations_ds = datasets.Dataset.from_dict({"conversations": groups}) # Filter out conversations with less than 5 messages conversations_ds = conversations_ds.filter( lambda x: len(x["conversations"]) >= MIN_MESSAGES_THRESHOLD ) conversations_ds_without_whatsapp_annotations = conversations_ds.map( remove_whatapp_annotations, num_proc=os.cpu_count() - 1, ) if do_spelling_correction: spell_checked_conversations_ds = ( conversations_ds_without_whatsapp_annotations.map(spell_check_conversation) ) else: spell_checked_conversations_ds = conversations_ds_without_whatsapp_annotations if do_reordering: reordered_conversations_ds = spell_checked_conversations_ds.map( swap_messages_if_needed_in_conversation ) else: reordered_conversations_ds = spell_checked_conversations_ds # For the contact_name, rewrite everything that is not 'my_whatsapp_name' to 'Other' def rewrite_contact_name(conversation): for message in conversation["conversations"]: if message["contact_name"] != whatsapp_name: message["contact_name"] = "Other" return conversation changed_contact_name_ds = reordered_conversations_ds.map( rewrite_contact_name ) # , num_proc=os.cpu_count() - 1) # Filter out conversations with only one contact changed_contact_name_ds = changed_contact_name_ds.filter( lambda x: len(set([msg["contact_name"] for msg in x["conversations"]])) > 1 ) return changed_contact_name_ds SPLIT_CONVERSATION_THRESHOLD = 40 MAX_CHARACTERS_PER_MESSAGE = 10000 # Max is 8,192 tokens (https://cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-supervised-tuning-about#sample-datasets) def transform_conversations_dataset_into_training_examples( conversations_ds, system_prompt, user_role, model_role, whatsapp_name ): """ Takes in a dataset with conversations and returns a dataset with training examples. The input dataset contains a single column (conversations), with each row being a list of messages with this format: ``` [{'contact_name': 'Aldi', 'message': , 'timestamp':