Query Generation with LoRA Fine-Tuning
This project fine-tunes a language model with supervised fine-tuning (SFT) and LoRA adapters to generate queries from documents. The model was trained on the prdev/qtack-gq-embeddings-unsupervised dataset using an A100 GPU.
Overview
Objective:
The goal is to train a model that, given a document, generates a relevant query. Each training example is formatted with custom markers:
- `<|document|>\n` precedes the document text.
- `<|query|>\n` precedes the query text.
- An EOS token is appended at the end to signal termination.
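A minimal sketch of this format in Python follows. It assumes the same layout as the Quick Usage prompt (a newline between the document text and `<|query|>`) and takes the EOS string as a parameter, since the exact token depends on the tokenizer. `format_example` and `extract_query` are hypothetical helpers for illustration, not part of the published training code; the second one shows how the same markers can be used to pull the query back out of a decoded sequence.

```python
DOC_MARKER = "<|document|>\n"
QUERY_MARKER = "<|query|>\n"


def format_example(document: str, query: str, eos_token: str) -> str:
    """Build one SFT training string: markers, document, query, then EOS.

    Assumed layout, mirroring the Quick Usage prompt (not confirmed for training).
    """
    return DOC_MARKER + document + "\n" + QUERY_MARKER + query + eos_token


def extract_query(decoded: str, eos_token: str) -> str:
    """Return the text after the last query marker, cut at the first EOS.

    Assumes `decoded` was produced with skip_special_tokens=False, so the
    markers and the EOS string are still present in the text.
    """
    _, sep, tail = decoded.rpartition(QUERY_MARKER)
    if not sep:
        raise ValueError("query marker not found in decoded text")
    return tail.split(eos_token, 1)[0].strip()


if __name__ == "__main__":
    eos = "<EOS>"  # placeholder; use tokenizer.eos_token in practice
    example = format_example(
        "liberal arts. 1. the academic course of instruction at a college",
        "what are the liberal arts",  # illustrative query, not from the dataset
        eos,
    )
    print(repr(example))
    print(extract_query(example, eos))
```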
Text Chunking:
For best results, split your text into smaller, coherent chunks before passing it to the model. Long documents can cause the model to focus on specific details rather than the overall context.
Training Setup:
The model is fine-tuned with the Unsloth framework and LoRA adapters on an A100 GPU. The training loss curve is available on Weights & Biases: https://wandb.ai/prdev/lora_model_training/panel/jp2r24xk7?nw=nwuserprdev
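LoRA keeps the base weights frozen and trains, for each adapted weight matrix W of shape d_out × d_in, a low-rank update ΔW = (α/r)·B·A, where B is d_out × r and A is r × d_in. Each adapter therefore adds r·(d_in + d_out) trainable parameters instead of d_in·d_out. The sketch below does that arithmetic. The rank r = 16 and the choice of target modules (all attention and MLP projections) are hypothetical, since this card does not state the LoRA hyperparameters actually used; the layer shapes are assumed from the public Llama 3.2 1B configuration (hidden size 2048, MLP size 8192, 16 layers, 8 key/value heads of dimension 64).

```python
# Hypothetical LoRA setup: rank and target modules are NOT taken from this model card.
RANK = 16
NUM_LAYERS = 16      # assumed: Llama 3.2 1B decoder layers
HIDDEN = 2048        # assumed: model hidden size
MLP = 8192           # assumed: MLP intermediate size
KV_DIM = 8 * 64      # assumed: 8 key/value heads x head_dim 64 (grouped-query attention)

# (d_in, d_out) of each projection matrix in one decoder layer.
PROJECTIONS = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_params(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters of one adapter: A is rank x d_in, B is d_out x rank."""
    return rank * d_in + d_out * rank


def full_params(d_in: int, d_out: int) -> int:
    """Parameters of the frozen dense weight matrix being adapted."""
    return d_in * d_out


per_layer_lora = sum(lora_params(i, o, RANK) for i, o in PROJECTIONS.values())
per_layer_full = sum(full_params(i, o) for i, o in PROJECTIONS.values())

total_lora = per_layer_lora * NUM_LAYERS
total_full = per_layer_full * NUM_LAYERS

print(f"LoRA trainable params: {total_lora:,}")
print(f"Adapted dense params:  {total_full:,}")
print(f"Ratio: {total_lora / total_full:.2%}")
```

Under these assumptions the adapters hold about 11.3 million trainable parameters, roughly 1.2% of the ~973 million weights in the matrices they adapt, which is why LoRA training fits comfortably on a single GPU.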
Quick Usage
Below is an example snippet that loads the fine-tuned model and generates a query for a chunked document:
```python
from unsloth import FastLanguageModel
from transformers import TextStreamer

# Load the fine-tuned model and tokenizer from the Hugging Face Hub.
model, tokenizer = FastLanguageModel.from_pretrained("prdev/query-gen", load_in_4bit=True)
# Switch the model into Unsloth's faster inference mode.
FastLanguageModel.for_inference(model)

# Example document chunk (ensure text is appropriately chunked).
document_chunk = (
    "liberal arts. 1. the academic course of instruction at a college intended to provide general knowledge "
    "and comprising the arts, humanities, natural sciences, and social sciences, as opposed to professional or technical subjects."
)

# Build the prompt with the same custom markers used during training.
prompt = "<|document|>\n" + document_chunk + "\n<|query|>\n"

# Tokenize the prompt (produces input_ids and attention_mask).
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# Stream generated tokens to stdout, without echoing the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

# Generate a query from the document.
_ = model.generate(
    **inputs,  # passes both input_ids and attention_mask
    streamer=streamer,
    max_new_tokens=100,
    do_sample=True,  # enables sampling so temperature and min_p are applied
    temperature=0.7,
    min_p=0.1,
    eos_token_id=tokenizer.eos_token_id,  # stop once the EOS token is generated
)
```
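The snippet assumes the document has already been chunked. One simple way to do that, sketched below, is to split on sentence-ending punctuation and greedily pack sentences into chunks of at most `max_words` words. This is an assumption-laden illustration, not the author's preprocessing: `chunk_text`, `build_prompts`, and the default of 120 words are hypothetical choices, and the word count is only a rough proxy for tokens.

```python
import re


def chunk_text(text: str, max_words: int = 120) -> list[str]:
    """Greedily pack sentences into chunks of at most max_words words.

    Naive splitter: it also breaks after abbreviations or numbered items
    such as "1.". A single sentence longer than max_words becomes its own chunk.
    """
    sentences = [s for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s]
    chunks: list[str] = []
    current: list[str] = []
    count = 0
    for sentence in sentences:
        n = len(sentence.split())
        # Close the current chunk if adding this sentence would exceed the limit.
        if current and count + n > max_words:
            chunks.append(" ".join(current))
            current, count = [], 0
        current.append(sentence)
        count += n
    if current:
        chunks.append(" ".join(current))
    return chunks


def build_prompts(text: str, max_words: int = 120) -> list[str]:
    """Wrap each chunk in the same markers as the Quick Usage prompt."""
    return ["<|document|>\n" + chunk + "\n<|query|>\n" for chunk in chunk_text(text, max_words)]
```

Each resulting prompt can then be tokenized and passed to `model.generate` in the same way as the single prompt in the Quick Usage snippet.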
Uploaded model
- Developed by: prdev
- License: apache-2.0
- Fine-tuned from model: unsloth/llama-3.2-1b-instruct-unsloth-bnb-4bit
This Llama model was trained 2x faster with Unsloth and Hugging Face's TRL library.