Jerry Liu • 2023-09-06
Fine-Tuning a Linear Adapter for Any Embedding Model
We’ve added capabilities in LlamaIndex allowing you to fine-tune a linear adapter on top of embeddings produced from any model (sentence_transformers
, OpenAI, and more).
This allows you to transform your embedding representations into a new latent space that’s optimized for retrieval over your specific data and queries. This can lead to small increases in retrieval performance that in turn translate to better performing RAG systems.
A nice bonus: you do not need to re-embed your documents by using this adapter! Simply transform the query instead.
We have a full end-to-end guide showing how you can generate a synthetic dataset, fine-tune the linear adapter, and evaluate its performance.
Context
The concept of fine-tuning your embedding model is powerful. In fact, we were inspired to both add a full example repository / blog post as well as native abstractions in LlamaIndex showing how you can fine-tune a sentence_transformers model over any unstructured text corpus (with our SentenceTransformersFinetuneEngine
).
However, this approach has some limitations:
- The
SentenceTransformersFinetuneEngine
is limited to fine-tuningsentence_transformers
models. - After finetuning the embedding model, you will need to re-embed your document corpus.
During our Finetuning + RAG webinar last Friday, Jo (Vespa) mentioned the exact same problem: fine-tuning the embeddings model requires you to reindex your documents. However, his work with Vespa explored the concept of “freezing” document embeddings using a foundation model, and instead training a transformation on the query embedding.
This inspired us to explore a similar embedding fine-tuning approach that was simultaneously more general but also allowed us to freeze existing document embeddings.
Approach
Our brand-new EmbeddingAdapterFinetuneEngine
fine-tunes a linear adapter on top of query embeddings produced by any model. The linear adapter is simply a linear transformation that specifically transforms the query embedding while keeping document embeddings fixed.
The linear adapter can be used on top of any existing embeddings model: SBERT embeddings, OpenAI embeddings, Cohere embeddings, and more. As a result you can just plug this in on top of any embedding model that you’re already using!
Since document embeddings are unchanged, this means that you can always fine-tune this linear adapter after you’ve generated embeddings for your documents. You can choose to arbitrarily re-train this adapter on top of changing data distributions, without needing to re-embed all your documents.
Technical Details
As mentioned above, the linear adapter simply performs a linear transformation on top of the query embedding while keeping the Document embeddings fixed (with a weight matrix W + bias term b):
And that’s it! If document embeddings can be represented as a (n x d) matrix D, where n is number of documents and d is the embedding dimension, then embedding similarity is just measured by
The linear adapter is trained using a similar loss term as the MultipleNegativesRankingLoss
function in sentence_transformers
— given a batch of positive (question, context) examples, the function uses cross-entropy loss under the hood to penalize the ground-truth (question, context) pairs for being far apart and swapped pairs for being too close.
Additional Notes: We ended up writing the bulk of this fine-tuning logic in plain PyTorch, but taking heavy inspiration from the sentence_transformers
source code. We couldn’t use sentence_transformers directly since we take in embeddings as inputs rather than raw text. You can take a look at some of our training code here.
Notebook Walkthrough
In this notebook walkthrough, we follow a similar set of steps as our previous blog post on embedding fine-tuning:
- Generate a synthetic question-context dataset for both training and evaluation.
- Fine-tuning our linear adapter on top of an existing model (e.g. SBERT)
- Getting the embedding model, and evaluating it.
As with the previous post, we use the UBER and LYFT 10K as example data. We use Lyft to generate our training dataset and Uber to generate our evaluation dataset.
The full guide is here: https://gpt-index.readthedocs.io/en/latest/examples/finetuning/embeddings/finetune_embedding_adapter.html
Generate a Synthetic Dataset for Trraining and Evaluation
We use our helper abstractions, generate_qa_embedding_pairs
, to generate our training and evaluation dataset. This function takes in any set of text nodes (chunks) and generates a structured dataset containing (question, context) pairs.
from llama_index.finetuning import (
generate_qa_embedding_pairs,
EmbeddingQAFinetuneDataset,
)
# generate
train_dataset = generate_qa_embedding_pairs(train_nodes)
val_dataset = generate_qa_embedding_pairs(val_nodes)
# save
train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")
# load
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
Fine-tuning our Linear Adapter
We then fine-tune our linear adapter on top of an existing embedding model. We import our newEmbeddingAdapterFinetuneEngine
abstraction, which takes in an existing embedding model and a set of training parameters.
In this example we use the bge-small-en
sentence-transformers model, but we can also use any embedding model in LlamaIndex/LangChain.
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.embeddings import resolve_embed_model
import torch
base_embed_model = resolve_embed_model("local:BAAI/bge-small-en")
# alternative: use OpenAI
# from llama_index.embeddings import OpenAIEmbedding
# openai = OpenAIEmbedding()
finetune_engine = EmbeddingAdapterFinetuneEngine(
train_dataset,
base_embed_model,
model_output_path="<model_output_path>",
epochs=4,
verbose=True,
# can optionally pass along any parameters that go into `train_model`
# optimizer_class=torch.optim.SGD,
# optimizer_params={"lr": 0.01}
)
We can then call fine-tune to kick off the fine-tuning job. Training a linear model is quite straightforward and doesn’t require heavy machinery — this can easily run on a Macbook.
finetune_engine.finetune()
Getting the Embedding Model, and Evaluating it
Once the fine-tuning job is then, we can then fetch our embedding model.
We can either directly fetch it from our finetune_engine
, or import our new LinearAdapterEmbeddingModel
and construct it in a more manual fashion.
Option 1:
embed_model = finetune_engine.get_finetuned_model()
Option 2:
from llama_index.embeddings import LinearAdapterEmbeddingModel
embed_model = LinearAdapterEmbeddingModel(base_embed_model, "<model_output_path>")
The next step is to evaluate it. We compare the fine-tuned model against the base model, as well as against text-embedding-ada-002
.
We evaluate with two ranking metrics:
- Hit-rate metric: For each (query, context) pair, we retrieve the top-k documents with the query. It’s a hit if the results contain the ground-truth context.
- Mean Reciprocal Rank: A slightly more granular ranking metric that looks at the “reciprocal rank” of the ground-truth context in the top-k retrieved set. The reciprocal rank is defined as 1/rank. Of course, if the results don’t contain the context, then the reciprocal rank is 0.
Some additional comments:
- We ran with 4 epochs over the Lyft documents
- We used Adam as an optimizer with the default learning rate (we tried SGD and it didn’t work as well)
Results
In terms of hit-rate, the base model gets 78.7% hit-rate on the validation dataset, and the fine-tuned model gets 79.8%. In the meantime text-embedding-ada-002
gets 87.0%.
In terms of MRR, the base model gets 64.3%, and the fine-tuned model gets 66%. text-embedding-ada-002
gets 68.4%.
There is some performance bump from the fine-tuned model, though admittedly it is small — it is smaller than the performance bump gained through fine-tuning sentence_transformers directly on the latest dataset.
That said, a performance bump is still a performance bump, and it’s very cheap for you to spin up and try yourself! So you can decide whether or not this would make sense for you.
Conclusion
We created a brand-new module in LlamaIndex that allows you fine-tune a linear adapter on top of any embedding model.
It can help you eke out some marginal improvement in retrieval metrics; importantly, it allows you to keep document embeddings fixed and only transform the query.
Resources
Training code (if you want to take a look for yourself): https://github.com/jerryjliu/llama_index/blob/main/llama_index/finetuning/embeddings/adapter_utils.py