# Fine-tuning a custom Sentence Transformers model using synthetic data

This notebook shows, at a high level, how to define a pipeline that uses an LLM to generate synthetic datasets for training or fine-tuning Sentence Transformers models on a custom domain.

## Why fine-tune?

There are already many good open-source embedding models you can use, but you may:

- work in a specific domain where existing embedding models might not perform very well
- have a specific concept of similarity you want to capture
- want to optimize for a particular task

In all of these cases, even a little fine-tuning might help.

## How to get custom data?

One of the main barriers to fine-tuning a custom model has been the cost and effort involved in creating the training datasets. Recently, LLMs have increasingly been used to generate synthetic datasets. In this series of notebooks, we'll see how to use an LLM to create training datasets for fine-tuning a sentence similarity model.

Before we start creating our dataset, we'll do some initial exploration and preparation of the dataset we're working with.


<div style="border-left: 4px solid #00A000; background-color: #F0FFF0; padding: 10px; margin: 10px 0;">
    <strong>Tip:</strong> We focus on a particular dataset in this case, but you should be able to adapt the notebook fairly easily to any other dataset on the Hugging Face Hub.
</div>

If you are running this notebook in Colab, you can use the following command to install the necessary libraries. If you are running it in the Synthetic datasets workshop Space, everything is already installed.

In [None]:
#%pip install "datasets>=2.18.0" llama_index rich

## 01. Preparing the data

In this notebook, we'll focus on exploring the dataset and preparing it for generating our synthetic data. Depending on how well you already know your dataset, you might spend less time on this step. However, it's always worth looking at the data before generating synthetic data, since the right approach may depend on the data you have.

In [1]:
import random
import uuid
from multiprocessing import cpu_count
from typing import Any, Dict, Optional

from datasets import load_dataset
from huggingface_hub import login
from llama_index.core import Document
from llama_index.core.node_parser import SentenceSplitter
from rich import print as rich_print

In [2]:
NUM_PROC = cpu_count()

## Authenticate with the Hub

You will need to authenticate with the Hub to be able to push datasets to the Hub. You can create a token by going to [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) and creating a new token. You will need a token with write access. It's suggested to create a new token for this workshop (you can always revoke it later).

In [3]:
login()


## The dataset 

For this example, we'll use [`dreamproit/bill_summary_us`](https://huggingface.co/datasets/dreamproit/bill_summary_us). This dataset "collects the text of bills, some metadata, as well as the CRS (Congressional Research Service) summaries". The dataset is originally focused on helping to develop models for summarization but we'll use it to generate synthetic data for training a sentence similarity model.

For datasets like this, fine-tuning a custom sentence similarity model may bring some benefits. A general-purpose sentence similarity model may do a good job of finding similar sentences, but it might not capture the specific kind of similarity we're interested in. For example, it will likely be able to tell US bills apart from recipes, but if you want to find bills that cover similar topics, you might need a more domain-specific model. Alongside working with a specific domain, you may also want more control over the type of similarity the model captures: semantic similarity, topic similarity, or something else. Creating our own dataset and fine-tuning a model gives us that control.

Let's start by loading the data and having a look at it.

In [4]:
ds = load_dataset("dreamproit/bill_summary_us", split="train")
ds

Dataset({
    features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title'],
    num_rows: 125246
})

Let's see a few examples of the data we have in the dataset.

In [5]:
ds[:2]

{'id': ['108hconres408ih', '108hconres449ih'],
 'congress': [108, 108],
 'bill_type': ['hconres', 'hconres'],
 'bill_number': [408, 449],
 'bill_version': ['ih', 'ih'],
 'sections': [[{'text': 'That Congress— (1) congratulates the University of Denver men’s hockey team for winning the 2004 NCAA men’s hockey national championship; (2) recognizes the achievements of all the team’s players, coaches, and support staff and invites them to the United States Capitol Building to be honored; (3) requests that the President recognize the achievements of the University of Denver men’s hockey team and invite the team members to the White House for an appropriate ceremony honoring a national championship team; and (4) directs the Clerk of the House of Representatives to make available to the University of Denver enrolled copies of this resolution for appropriate display and to transmit an enrolled copy of this resolution to each coach and member of the 2004 NCAA men’s hockey national championship t

Since we're going to create an embedding dataset from the texts in the `text` column, let's take a closer look at them.

In [6]:
rich_print(ds[4:6]['text'])

These texts are relatively short, but other examples in this dataset are much longer. For most datasets that haven't already been preprocessed in some way, we'll need to split the texts into smaller segments.

## Chunking our text

We'll need to split our text into smaller chunks to be able to use it for training a sentence similarity model. There are two main reasons for this:

- Sentence Transformers models have a maximum sequence length, measured in tokens, that they can process. This limit depends on the model you're using.
- Longer sections of text are more likely to be about multiple topics which can make it harder for the model to learn a specific type of similarity.

Whilst the maximum sequence length of many open-source models has grown recently, we may still want to split our text into smaller chunks so that each one is a logical unit of text.
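To make the first point concrete: Sentence Transformers models typically truncate input that exceeds their maximum sequence length rather than raising an error, so anything past the limit never reaches the embedding. The sketch below mimics this with whitespace "tokens" and a hypothetical limit of 8. Real models use subword tokenizers and much larger limits, so treat the numbers and the toy bill text as illustrative only.

```python
# Hypothetical limit for illustration only. In sentence-transformers, a
# loaded model exposes its real limit as `model.max_seq_length`.
MAX_TOKENS = 8


def visible_to_model(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """Return the part of `text` a model would see if it truncated the
    input to `max_tokens` whitespace-separated tokens."""
    tokens = text.split()
    return " ".join(tokens[:max_tokens])


def dropped_by_truncation(text: str, max_tokens: int = MAX_TOKENS) -> str:
    """Return the part of `text` that truncation would silently discard."""
    tokens = text.split()
    return " ".join(tokens[max_tokens:])


# Toy text: the substantive topic only appears after the limit.
bill = (
    "This Act may be cited as the Example Act. "
    "Sec. 2 amends the tax code to create a clean coal credit."
)
print(visible_to_model(bill))       # This Act may be cited as the Example
print(dropped_by_truncation(bill))  # Act. Sec. 2 amends the tax code ...
```

Here the part of the text that says what the bill actually does (the clean coal credit) would never be embedded, which is exactly what chunking is meant to avoid.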

### How to decide on the right chunk size

Deciding on the right chunk size can be a bit of a balancing act and can depend on the specific dataset you're working with and the end application for your embedding model. One of the main applications of a custom sentence similarity model is to help improve the performance of a Retrieval Augmented Generation (RAG) application. In this case, you might want to split your text into chunks that are similar in length to the passages you'll be working with in your RAG application. 


### Splitting with LlamaIndex

Many libraries developed for RAG applications can also help us split our text into chunks. One of these is LlamaIndex, which we'll use in this notebook.

LlamaIndex offers many different approaches for splitting texts (see [node_parsers](https://docs.llamaindex.ai/en/stable/api_reference/node_parsers/)). In this notebook, we'll use the relatively simple `SentenceSplitter`, which splits text into chunks while trying to respect sentence boundaries:

>In general, this class tries to keep sentences and paragraphs together. Therefore compared to the original TokenTextSplitter, there are less likely to be hanging sentences or parts of sentences at the end of the node chunk.
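As a rough mental model of what "keeping sentences together" means, the sketch below splits on sentence-ending punctuation and then greedily packs whole sentences into chunks up to a token budget. This is a simplified approximation, not LlamaIndex's implementation: `SentenceSplitter` counts tokens with a real tokenizer, also takes paragraph separators into account, and breaks up sentences that exceed the budget. Here, tokens are approximated by whitespace-separated words, the sentence regex is naive (it would split after abbreviations like "Sec."), and an over-long sentence simply becomes its own chunk.

```python
import re
from typing import List


def split_sentences(text: str) -> List[str]:
    # Naive sentence boundary: ".", "!" or "?" followed by whitespace.
    return [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]


def pack_sentences(text: str, chunk_size: int) -> List[str]:
    """Greedily group whole sentences into chunks of at most `chunk_size`
    whitespace tokens. A sentence longer than `chunk_size` is kept whole
    as its own chunk (a real splitter would break it up further)."""
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for sentence in split_sentences(text):
        n_tokens = len(sentence.split())
        # Close the current chunk if this sentence would overflow it.
        if current and current_len + n_tokens > chunk_size:
            chunks.append(" ".join(current))
            current, current_len = [], 0
        current.append(sentence)
        current_len += n_tokens
    if current:
        chunks.append(" ".join(current))
    return chunks


# Toy text for illustration.
text = (
    "This Act may be cited as the Example Act. "
    "Congress finds that camps serve youth. "
    "The Secretary shall set fees. "
    "Fees shall be deposited in a special account."
)
for chunk in pack_sentences(text, chunk_size=12):
    print(chunk)
```

With a budget of 12 words, this produces three chunks, and no chunk ends partway through a sentence.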

If your data is in a format like HTML or Markdown, other parsers are likely worth exploring. There is also a `SemanticSplitterNodeParser`, which "splits a document into Nodes, with each node being a group of semantically related sentences". This could be worth exploring, but it is more computationally expensive to use and, depending on the text you are working with, might not lead to much better results.
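To see where the extra cost of semantic splitting comes from, here is a toy sketch of the idea: embed every sentence, then start a new chunk wherever the similarity between neighbouring sentences drops below a threshold. The bag-of-words "embedding" and the fixed `threshold` are stand-ins chosen for illustration. As far as we know, `SemanticSplitterNodeParser` uses a real embedding model and places breakpoints using a percentile of the distances between adjacent sentence groups (its `breakpoint_percentile_threshold` parameter), so this is not its exact algorithm. Even in the sketch, every sentence needs its own embedding.

```python
import math
import re
from collections import Counter
from typing import List


def embed(sentence: str) -> Counter:
    # Toy "embedding": lowercase word counts. A real semantic splitter
    # calls an embedding model here, once per sentence.
    return Counter(re.findall(r"[a-z]+", sentence.lower()))


def cosine(a: Counter, b: Counter) -> float:
    dot = sum(a[w] * b[w] for w in a)
    norm = math.sqrt(sum(v * v for v in a.values())) * math.sqrt(
        sum(v * v for v in b.values())
    )
    return dot / norm if norm else 0.0


def semantic_chunks(sentences: List[str], threshold: float = 0.2) -> List[str]:
    """Group consecutive sentences, breaking wherever the similarity
    between neighbouring sentences falls below `threshold`."""
    if not sentences:
        return []
    vectors = [embed(s) for s in sentences]
    groups = [[sentences[0]]]
    for prev_vec, vec, sentence in zip(vectors, vectors[1:], sentences[1:]):
        if cosine(prev_vec, vec) < threshold:
            groups.append([sentence])  # topic shift: start a new chunk
        else:
            groups[-1].append(sentence)
    return [" ".join(g) for g in groups]


# Toy sentences covering two topics.
sentences = [
    "The coal credit applies to clean coal projects.",
    "A clean coal project must be certified.",
    "H-1B visa fees are increased.",
    "Visa fraud penalties are added.",
]
for chunk in semantic_chunks(sentences):
    print(chunk)
```

The two coal sentences share vocabulary and end up in one chunk, while the visa sentences form a second chunk.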

### What size should we split our text into?

If we look at the docstring for `SentenceSplitter`, we can see that the default value for `chunk_size` is `1024` tokens, with a `chunk_overlap` of `200`. We might want to adjust these to see what size makes sense for our data.


In [7]:
?SentenceSplitter

Init signature:
SentenceSplitter(
    separator: str = ' ',
    chunk_size: int = 1024,
    chunk_overlap: int = 200,
    tokenizer: Optional[Callable] = None,
    paragraph_separator: str = '\n\n\n',
    chunking_tokenizer_fn: Optional[Callable[[str], List[str]]] = None,
    secondary_chun

In [8]:
splitter = SentenceSplitter()
splitter.chunk_size, splitter.chunk_overlap

(1024, 200)
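These two defaults interact. With `chunk_size=1024` and `chunk_overlap=200`, consecutive chunks may share text, so in a simple fixed-window scheme each new chunk would advance by 1024 - 200 = 824 tokens. The sketch below illustrates what overlap means, using a plain fixed-size window over a list of placeholder tokens. It does not reproduce how `SentenceSplitter` chooses its boundaries: that splitter also tries to end chunks on sentence boundaries, so its actual overlaps vary. Setting `chunk_overlap=0`, as we do below, means consecutive chunks share no text.

```python
from typing import List


def sliding_windows(
    tokens: List[str], chunk_size: int, chunk_overlap: int
) -> List[List[str]]:
    """Fixed-size windows in which consecutive windows share
    `chunk_overlap` tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    stride = chunk_size - chunk_overlap
    windows = []
    for start in range(0, len(tokens), stride):
        windows.append(tokens[start:start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # this window already reaches the end of the text
    return windows


tokens = [f"t{i}" for i in range(10)]  # placeholder tokens t0..t9
for window in sliding_windows(tokens, chunk_size=4, chunk_overlap=1):
    print(window)
```

Each window starts with the last token of the previous one. With `chunk_overlap=0`, the same tokens would be cut into disjoint windows.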

Let's load an example text and see how different sizes of chunks look.

In [9]:
doc = Document.from_dict({"text": ds[200]['text']})

In [10]:
splits = splitter.get_nodes_from_documents([doc])
splits

[TextNode(id_='b77fc036-8f99-415c-8142-f9a672f450bc', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='e59781d0-ad81-4411-9415-697b650e3d7e', node_type=<ObjectType.DOCUMENT: '4'>, metadata={}, hash='19c0d72fca6cf938248ff70d34f7af953ce06be4b57f1748652d7fa878b87185'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='5f7e3f35-a6dd-481b-968a-c70b2717c3d5', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='45627c584bb296bcb3c49c428092399194c37c6726210acedd5b22530af6d5a8')}, text='1. Short title; table of contents \n(a) Short title \nThis Act may be cited as the National Forest Organizational Camp Fee Improvement Act of 2003. (b) Table of contents \nThe table of contents for this Act is as follows: Sec. 1. Short title; table of contents Sec. 2. Findings, purpose, and definitions Sec. 3. Fees for occupancy and use of National Forest System lands and facilities by organi

In [11]:
rich_print(splits[0].text)

In [12]:
splitter = SentenceSplitter(chunk_size=128, chunk_overlap=0)

In [13]:
splits = splitter.get_nodes_from_documents([doc])
rich_print(splits[0].text)

In [14]:
sample_size = 10
documents = [Document.from_dict({"text": ds[i]['text']}) for i in range(sample_size)]
splits = splitter.get_nodes_from_documents(documents)
len(splits)
# uncomment to see the output
# for split in splits:
#     rich_print(split.text)

20

For this particular dataset, the texts cover several topics in a short space, so it seems to make sense to aim for a smaller chunk size, such as 128 tokens. This helps ensure that each chunk captures a specific topic from the text. If you are using a different dataset, you might want to experiment with different chunk sizes to see what works best for your data.
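One way to make this experimentation less of a judgement call is to compare summary statistics of the chunk lengths for a few candidate sizes: how many chunks you get, and how short the shortest ones are (very short fragments rarely make useful training examples). The helper below is a hypothetical sketch that works on any list of chunk strings, such as `[s.text for s in splits]` from the cells above. It approximates length by counting whitespace-separated words rather than model tokens.

```python
import statistics
from typing import Dict, List


def chunk_length_summary(chunks: List[str]) -> Dict[str, float]:
    """Summarise chunk lengths, measured in whitespace-separated words."""
    if not chunks:
        raise ValueError("need at least one chunk")
    lengths = [len(c.split()) for c in chunks]
    return {
        "n_chunks": len(lengths),
        "min": min(lengths),
        "median": statistics.median(lengths),
        "max": max(lengths),
    }


# Toy input; in the notebook you might pass [s.text for s in splits].
example_chunks = [
    "Short title.",
    "This Act may be cited as the Example Act.",
    "The Secretary shall establish a fee schedule for organizational camps.",
]
print(chunk_length_summary(example_chunks))
```

Running this for each candidate `chunk_size` on the same sample gives a quick side-by-side comparison before you process the full dataset.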

## Process our full dataset

Now that we've decided on a chunk size, let's process our full dataset. We'll split each text into chunks and save these to a new dataset.

In [15]:
def split_texts(
    examples: Dict[str, Any],
    text_column_name: str = "text",
    id_column_name: Optional[str] = None,
    splitter: Optional[SentenceSplitter] = None,
):
    if splitter is None:
        # if not provided, use the default splitter
        splitter = SentenceSplitter()
    texts = examples[text_column_name]
    if id_column_name is None:
        # Generate random ids if not provided
        ids = [str(uuid.uuid4()) for _ in range(len(texts))]
    else:
        ids = examples[id_column_name]
    sections = []
    ids_ = []
    for text, id_ in zip(texts, ids):
        # Create a document for each text
        document = Document(text=text)
        # Split the document into nodes
        nodes = splitter.get_nodes_from_documents([document])
        # Extract the text from each node
        sentences = [n.text for n in nodes]
        # Extend the sections list with these sentences
        sections.extend(sentences)
        # Extend the ids_ list with the corresponding id, repeated for each sentence
        ids_.extend([id_] * len(sentences))
    return {"section": sections, "id": ids_}

We can now split the full dataset.

If you are using a different dataset, remember to adjust `text_column_name` if the column containing the text has a different name. If there is an ID column, you can pass its name as `id_column_name`; otherwise, set this to `None` and the function will generate a random ID for each row.
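Because each input row can produce several chunks, `split_texts` returns more rows than it receives. `Dataset.map` only allows this with `batched=True`, and the original columns must be dropped via `remove_columns` because their lengths would no longer match the new ones. The dependency-free sketch below mirrors the same flattening logic on a plain dict batch, so you can check the ID bookkeeping without loading LlamaIndex. `split_batch` and its `naive_split` helper are hypothetical stand-ins for `split_texts` and `SentenceSplitter`.

```python
import uuid
from typing import Any, Callable, Dict, List, Optional


def split_batch(
    examples: Dict[str, List[Any]],
    split_fn: Callable[[str], List[str]],
    text_column_name: str = "text",
    id_column_name: Optional[str] = None,
) -> Dict[str, List[Any]]:
    """Flatten a batch of texts into chunks, repeating each source id
    once per chunk produced from that text."""
    texts = examples[text_column_name]
    if id_column_name is None:
        # Same fallback as split_texts: one random id per source text.
        ids = [str(uuid.uuid4()) for _ in texts]
    else:
        ids = examples[id_column_name]
    sections: List[str] = []
    section_ids: List[Any] = []
    for text, id_ in zip(texts, ids):
        pieces = split_fn(text)
        sections.extend(pieces)
        section_ids.extend([id_] * len(pieces))
    return {"section": sections, "id": section_ids}


def naive_split(text: str) -> List[str]:
    # Hypothetical stand-in for SentenceSplitter: split on ". ".
    return [p for p in text.split(". ") if p]


batch = {"id": ["bill-a", "bill-b"], "text": ["One. Two. Three", "Only one"]}
out = split_batch(batch, naive_split, id_column_name="id")
print(out)
```

Two input rows become four output rows, and every chunk keeps the ID of the bill it came from, which is what lets you trace chunks back to their source later.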

In [16]:
splitter = SentenceSplitter(chunk_size=128, chunk_overlap=0)

In [17]:
chunked_ds = ds.map(
    split_texts,
    batched=True,
    num_proc=NUM_PROC,
    remove_columns=list(ds.column_names),
    fn_kwargs={"text_column_name": "text", "id_column_name": "id", "splitter": splitter},
)
ds


Dataset({
    features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title'],
    num_rows: 125246
})

In [18]:
chunked_ds

Dataset({
    features: ['id', 'section'],
    num_rows: 3446013
})
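Going from 125,246 bills to 3,446,013 chunks means each bill yields about 3,446,013 / 125,246 ≈ 27.5 chunks on average. An average can hide a few very long bills that dominate the chunk count, so it may be worth looking at the distribution of chunks per `id`. A minimal sketch using `collections.Counter` is shown below on hypothetical toy IDs. On the real data, you could pass `chunked_ds["id"]`, although materialising 3.4M IDs in Python takes some memory.

```python
from collections import Counter
from typing import Iterable, List, Tuple


def chunks_per_document(ids: Iterable[str]) -> Counter:
    """Count how many chunks each source document id produced."""
    return Counter(ids)


def heaviest_documents(ids: Iterable[str], k: int = 3) -> List[Tuple[str, int]]:
    """Return the k ids that produced the most chunks."""
    return chunks_per_document(ids).most_common(k)


# Hypothetical toy ids: one long "bill" dominates the chunk count.
toy_ids = ["bill-a"] * 2 + ["bill-b"] * 40 + ["bill-c"] * 5
counts = chunks_per_document(toy_ids)
print(sum(counts.values()) / len(counts))  # mean chunks per document
print(heaviest_documents(toy_ids, k=1))
```

If a handful of IDs account for a large share of the chunks, you may want to cap or sample chunks per document before generating synthetic data, so that a few long bills don't dominate the training set.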

In [19]:
sample_idx = random.sample(range(len(chunked_ds)), k=3)
chunked_ds.select(sample_idx)[:]['section']

['(4) Coal \nThe term coal means bituminous\t\t\t\tcoal, subbituminous coal, and lignite. (d) Aggregate\t\t\t\tcredits \n(1) In\t\t\t\tgeneral \nNo credit shall be allowed under this section with\t\t\t\trespect to any qualifying clean coal project unless such project is certified\t\t\t\tby the Secretary under subsection (e).',
 '1. Short\t\t\t title; table of contents \n(a) Short\t\t\t title \nThis Act may be cited\t\t\t as the Skilled Worker Immigration and\t\t\t Fairness Act. (b) Table of\t\t\t contents \nThe table of contents for this Act is as follows: Sec. 1. Short title; table of\t\t\t\tcontents. Sec. 2. H–1B visas. Sec. 3. Employment-based immigration. Sec. 4. H–1B visa fraud and abuse protections. 2.',
 'Remote control locomotive use \n(a) Prohibition \nNo railroad carrier shall operate or cause to be operated on the general system of railroad transportation a remote control locomotive to carry hazardous materials. (b) Penalty \n(1) A railroad carrier that knowingly violates th

## Pushing the data to the hub

We could save the data locally to use in the next notebook, but it's often easier to push it to the Hub, since we can then load it directly from there in the next notebook.

In [21]:
chunked_ds.push_to_hub("davanstrien/bill_summary_us_chunks")


CommitInfo(commit_url='https://huggingface.co/datasets/davanstrien/bill_summary_us_chunks/commit/e9c23f8e002cda39422c1a39bc95c8e5cd37213b', commit_message='Upload dataset', commit_description='', oid='e9c23f8e002cda39422c1a39bc95c8e5cd37213b', pr_url=None, pr_revision=None, pr_num=None)

## Next steps

In the next notebook, we'll look at how to use an LLM to generate synthetic data for fine-tuning our custom Sentence Transformers model. If you are running this notebook in the Synthetic Dataset Workshop Space, you can find the next notebook in the workspace. If you are running it locally, you can find the next notebook in the Hugging Face repository.