LlamaIndex

Jerry Liu 2023-08-25

Fine-Tuning Embeddings for RAG with Synthetic Data

UPDATE 9/10/2023: We’ve added embedding finetuning abstractions to the LlamaIndex repo, so this repo is technically outdated! Please check out our embedding fine-tuning guides in the core documentation.

We’ve created a comprehensive, end-to-end guide showing you how to fine-tune an embedding model to improve performance of Retrieval Augmented Generation (RAG) systems over any unstructured text corpus (no labels required!).

The result is a 5–10% performance increase in retrieval evaluation metrics — our finetuned bge model almost reaches text-embedding-ada-002 levels of retrieval performance in terms of hit rate. This enables more accurate retrieval which leads to better RAG systems as a whole.

This tutorial is helpful to anyone building RAG systems:

  • If you’re new to finetuning, no problem! We have step-by-step notebooks walking through the key steps. Simply substitute the file links for your own data and run every cell.
  • Finetuning embedding models is lightweight and doesn’t require a GPU. These notebooks were tested on an M2 MacBook Pro.

Resources

Background/Context

The Current RAG Stack

RAG is a popular paradigm for connecting Large Language Models (LLMs) with an external source of data that was not present in their training corpus. It pairs a retrieval model over a knowledge bank with the LLM through its input prompt space. RAG stacks typically look like the following:

  • Indexing: Prepare a corpus of unstructured text and parse/chunk it. Then embed each chunk and put it in a vector database.
  • Query-time: Retrieve context from the vector database using top-k embedding similarity lookup, and stuff that context into the LLM input space.

(Of course RAG can be much more advanced than this, and LlamaIndex provides tools for both simple and advanced RAG)
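To make the simple stack above concrete, here is a minimal sketch of the query-time step in plain Python. It is not LlamaIndex code: the function names, the toy 3-dimensional embeddings, and the prompt layout are all assumptions made for illustration, and a real system would get its embeddings from a model and store them in a vector database.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def retrieve_top_k(query_embedding, chunk_embeddings, k=2):
    """Return the ids of the k chunks most similar to the query."""
    scored = [
        (cosine_similarity(query_embedding, embedding), chunk_id)
        for chunk_id, embedding in chunk_embeddings.items()
    ]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return [chunk_id for _, chunk_id in scored[:k]]

def build_prompt(query, context_chunks):
    """Stuff the retrieved chunks into the LLM input (hypothetical layout)."""
    context = "\n".join(context_chunks)
    return f"Context:\n{context}\n\nQuestion: {query}"

# Toy "index": hand-written embeddings standing in for model outputs.
chunk_texts = {
    "chunk_a": "Text of chunk A.",
    "chunk_b": "Text of chunk B.",
    "chunk_c": "Text of chunk C.",
}
chunk_embeddings = {
    "chunk_a": [1.0, 0.0, 0.0],
    "chunk_b": [0.0, 1.0, 0.0],
    "chunk_c": [0.7, 0.7, 0.0],
}

query_embedding = [0.9, 0.1, 0.0]
top_ids = retrieve_top_k(query_embedding, chunk_embeddings, k=2)
prompt = build_prompt("What does chunk A say?", [chunk_texts[i] for i in top_ids])
```

If the embeddings place an irrelevant chunk closest to the query, the prompt carries the wrong context no matter how capable the LLM is, which is why the rest of this post focuses on the embedding model.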

RAG is easy to prototype by cobbling together the different components, but hard to productionize. The simple stack has many failure modes, and the issue often lies with bad retrieval: if the returned context is irrelevant to the query, the capability of the LLM does not matter and the answer will be bad.

How Can We Make Retrieval Better?

We can try more sophisticated retrieval algorithms (e.g. hybrid search, reranking).

An insight from our recent production RAG webinar, however, is that the embeddings themselves may not live in an optimal latent space for your data. Embeddings generated by pre-trained models may be close to or far from each other based on the pre-training objective, but may not completely align with your own retrieval objective. For instance, if you’re building search over ML ArXiv papers, you may want the embeddings to align semantically with specific ML concepts (e.g. “LLMs”, “NLP”) and not with filler words (e.g. “This paper is…”).

Finetuning is a way to solve that. The concept of finetuning has become increasingly popular in the LLM space, with technological advancements as well as easy-to-use services.

In this tutorial, we focus on finetuning the embedding model. We show how finetuning the embedding model can lead to better retrieval performance.

Challenges/Considerations

When you finetune embeddings, you need training examples. For embeddings, this typically means both “positive” and “negative” examples: pairs of texts that should be close to each other, and pairs that should be far apart.

An issue is that we don’t have these positive or negative examples a priori. Given a dataset of unstructured text, is it possible to automatically generate these example pairs?

With LlamaIndex you can! We use LlamaIndex modules to automatically generate a set of questions from unstructured text chunks. These (question, chunk) pairs then serve as positive examples, the training signal for the model (negative examples are randomly sampled across other chunks).
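As a rough illustration of the pairing idea, the sketch below builds (question, positive chunk, negative chunk) triplets in plain Python. The function name and toy data are assumptions for illustration only. In the actual walkthrough the negatives are never materialized like this: the MultipleNegativesRankingLoss used later draws them from the other examples in each training batch.

```python
import random

def build_training_triplets(questions_by_chunk, corpus, seed=0):
    """Pair each question with its source chunk (positive) and one
    randomly sampled other chunk (negative)."""
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    chunk_ids = list(corpus)
    triplets = []
    for chunk_id, questions in questions_by_chunk.items():
        # any chunk other than the source chunk can serve as a negative
        other_ids = [cid for cid in chunk_ids if cid != chunk_id]
        for question in questions:
            negative_id = rng.choice(other_ids)
            triplets.append((question, corpus[chunk_id], corpus[negative_id]))
    return triplets

# Toy data standing in for real chunks and LLM-generated questions.
corpus = {
    "chunk_1": "Toy text of chunk 1.",
    "chunk_2": "Toy text of chunk 2.",
    "chunk_3": "Toy text of chunk 3.",
}
questions_by_chunk = {
    "chunk_1": ["Question about chunk 1?", "Another question about chunk 1?"],
    "chunk_2": ["Question about chunk 2?"],
}

triplets = build_training_triplets(questions_by_chunk, corpus)
```

Random negatives are cheap but can occasionally be relevant to the question too; with a large corpus this is rare enough that the approach still works as a training signal.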

The next section shows a full walkthrough across all of our modules.

Walkthrough

At a high level, we do the following:

  1. Generating a synthetic dataset for training and evaluation (Notebook)
  2. Finetuning an opensource embedding model (Notebook)
  3. Evaluating the embedding model (Notebook)

Generating a synthetic dataset for training and evaluation

The key idea here is that we can leverage an LLM to generate hypothetical questions that are best answered by a given piece of context. This allows us to generate synthetic positive pairs of (query, relevant documents) in a scalable way without requiring human labellers.

More concretely, we first process the given documents into a corpus of text chunks. We do this with the SimpleNodeParser module in LlamaIndex:

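# import paths as of the LlamaIndex release current at the time of writing;
# they may differ in newer versions
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode

# split the loaded documents (`docs`) into text chunks ("nodes")
parser = SimpleNodeParser()
nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)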
corpus = {
  node.node_id: node.get_content(metadata_mode=MetadataMode.NONE) 
  for node in nodes
}

Then, for each text chunk, we use an LLM to generate a few hypothetical questions that can be answered with information from that text chunk. The example prompt is shown below as well.

import re

prompt_template = prompt_template or """\
  Context information is below.
  
  ---------------------
  {context_str}
  ---------------------
  
  Given the context information and not prior knowledge.
  generate only questions based on the below query.
  
  You are a Teacher/ Professor. Your task is to setup \
  {num_questions_per_chunk} questions for an upcoming \
  quiz/examination. The questions should be diverse in nature \
  across the document. Restrict the questions to the \
  context information provided.
  """

# for a given node, extract questions (do this over all nodes in outer loop)
query = prompt_template.format(context_str=text, num_questions_per_chunk=num_questions_per_chunk)
response = llm.complete(query)

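# split the response into lines and strip leading numbering such as "1." or "2)"
result = str(response).strip().split("\n")
questions = [
    re.sub(r"^\d+[\).\s]", "", question).strip() for question in result
]
# drop lines that are empty after cleaning
questions = [question for question in questions if len(question) > 0]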

Finally, we collect all pairs of questions and text chunks as the dataset. An example query, corpus entry, and mapping are shown below.


# example query
f331640a-b407-4028-8db8-4b8db691dd34: "What is the market value of Lyft's common stock held by non-affiliates as of June 30, 2021, based on the closing sales price of the Class A common stock on that date?"

# example corpus
d5554f3e-cdaf-41d7-ac49-8f0ffe3f5759: "UNITED STATESSECURITIES AND..."

# example mapping
f331640a-b407-4028-8db8-4b8db691dd34: d5554f3e-cdaf-41d7-ac49-8f0ffe3f5759
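The sketch below shows one way the three pieces above (queries, corpus, mapping) could be assembled once the LLM responses are in hand. It reuses the question-cleaning regex from the snippet above, but the helper names, the use of uuid4 for query ids, and the toy response text are assumptions for illustration, not the exact notebook code.

```python
import re
import uuid

def parse_questions(response_text):
    """Split an LLM response into questions, dropping numbering and blanks."""
    lines = response_text.strip().split("\n")
    questions = [re.sub(r"^\d+[\).\s]", "", line).strip() for line in lines]
    return [question for question in questions if len(question) > 0]

def build_dataset(corpus, responses_by_node):
    """Collect queries and a query-id -> node-id mapping for every chunk."""
    queries = {}
    relevant_docs = {}
    for node_id, response_text in responses_by_node.items():
        for question in parse_questions(response_text):
            question_id = str(uuid.uuid4())  # assumed id scheme
            queries[question_id] = question
            # one relevant chunk per query, as in the example mapping above
            relevant_docs[question_id] = node_id
    return {"queries": queries, "corpus": corpus, "relevant_docs": relevant_docs}

# Toy stand-ins for a real chunk and a real LLM response.
corpus = {"node-1": "Toy chunk about revenue and risks."}
responses_by_node = {
    "node-1": "1. What does the chunk say about revenue?\n\n"
              "2) Which risks does the chunk list?",
}

dataset = build_dataset(corpus, responses_by_node)
```

Note that sentence-transformers' InformationRetrievalEvaluator expects each query id to map to a set of relevant document ids, so a mapping like this one would need its values wrapped in sets before being passed to it.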

Finetuning an opensource embedding model

We leverage the high-level model fitting API from sentence-transformers to very easily set up a training process.

We use MultipleNegativesRankingLoss as the training objective and InformationRetrievalEvaluator as the evaluator during training. Also, we use BAAI/bge-small-en on Hugging Face as the base model and train for a small number of epochs.

from sentence_transformers import SentenceTransformer

# define model
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_id)

...

# define loss
from sentence_transformers import losses
loss = losses.MultipleNegativesRankingLoss(model)

# define evaluator
from sentence_transformers.evaluation import InformationRetrievalEvaluator
# define over validation dataset
...
evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)

# run training
...
model.fit(
    train_objectives=[(loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='exp_finetune',
    show_progress_bar=True,
    evaluator=evaluator, 
    evaluation_steps=50,
)
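For intuition about what MultipleNegativesRankingLoss optimizes, here is a plain-Python sketch of the computation on toy 2-dimensional embeddings. It is a simplified re-implementation for illustration, not the library code: for each query in a batch, the paired chunk is the positive, every other chunk in the batch acts as a negative, and the loss is the cross-entropy over scaled cosine similarities. The scale of 20 mirrors the library default; the function names and the toy vectors are assumptions.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def multiple_negatives_ranking_loss(query_embs, chunk_embs, scale=20.0):
    """Mean cross-entropy where chunk i is the positive for query i and
    all other chunks in the batch are treated as negatives."""
    losses = []
    for i, query in enumerate(query_embs):
        scores = [scale * cosine_similarity(query, chunk) for chunk in chunk_embs]
        # numerically stable log-sum-exp over the batch
        max_score = max(scores)
        log_sum_exp = max_score + math.log(
            sum(math.exp(score - max_score) for score in scores)
        )
        # negative log-probability assigned to the true pair
        losses.append(log_sum_exp - scores[i])
    return sum(losses) / len(losses)

# Batch where each query already points at its own chunk: loss near zero.
aligned_loss = multiple_negatives_ranking_loss(
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, 1.0]],
)

# Same batch with the chunks swapped: each query is closest to a negative.
misaligned_loss = multiple_negatives_ranking_loss(
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
)
```

Because the negatives come for free from the rest of the batch, only positive (question, chunk) pairs need to be generated, which is what makes the synthetic dataset above sufficient for training.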

Evaluating the embedding model

We compare the finetuned model against the base model, as well as the OpenAI embedding model text-embedding-ada-002.

We evaluate with two main metrics:

  • Hit-rate metric: For each (query, relevant_doc) pair, we retrieve the top-k documents with the query. It’s a hit if the results contain relevant_doc.
  • InformationRetrievalEvaluator from sentence_transformers. This provides a comprehensive suite of metrics such as cosine similarity accuracy, precision, and recall at different top-k values.
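The hit-rate metric described above can be sketched in a few lines of plain Python. The function name and the toy retrieval results are assumptions for illustration; in practice the retrieved ids would come from running top-k retrieval with the embedding model under evaluation.

```python
def hit_rate(retrieved_by_query, relevant_docs):
    """Fraction of queries whose relevant doc appears in the top-k results."""
    hits = 0
    for query_id, expected_id in relevant_docs.items():
        if expected_id in retrieved_by_query[query_id]:
            hits += 1
    return hits / len(relevant_docs)

# Toy top-2 retrieval results for four queries.
retrieved_by_query = {
    "q1": ["doc_a", "doc_b"],
    "q2": ["doc_c", "doc_a"],
    "q3": ["doc_b", "doc_d"],
    "q4": ["doc_a", "doc_c"],
}
# One relevant document per query.
relevant_docs = {"q1": "doc_a", "q2": "doc_a", "q3": "doc_c", "q4": "doc_c"}

score = hit_rate(retrieved_by_query, relevant_docs)
```

Here q3 is the only miss, so three of the four queries count as hits. Hit rate ignores where in the top-k the relevant document lands, which is one reason to also look at the rank-aware metrics from InformationRetrievalEvaluator.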

Results

In terms of the hit-rate metric, the base model gets a 78% hit rate on the validation dataset, and the fine-tuned model gets 84%. text-embedding-ada-002 gets 87%, which means that our fine-tuned model is only 3 percentage points behind!

Hit-rate for `text-embedding-ada-002`, base model, finetuned model

The InformationRetrievalEvaluator shows a similar improvement across an entire suite of metrics. The fine-tuned model increases evaluation metrics by 5–10% compared to the base model.

Evaluation suite from `InformationRetrievalEvaluator`

Conclusion

We successfully finetuned an embedding model over unlabeled, unstructured data to give better retrieval performance for downstream RAG systems. We show a 5–10% improvement across all metrics!
