Training Your Own Text Embedding Model
Explore how to train your own text embedding model using the `sentence-transformers` library, and how to generate training data by leveraging a pre-trained LLM.
Read the entire series
- Natural Language Processing Fundamentals: Tokens, N-Grams, and Bag-of-Words Models
- Primer on Neural Networks and Embeddings for Language Models
- Sparse and Dense Embeddings
- Sentence Transformers for Long-Form Text
- Training Your Own Text Embedding Model
- Evaluating Your Embedding Model
- Class Activation Mapping (CAM): Better Interpretability in Deep Learning Models
- CLIP Object Detection: Merging AI Vision with Language Understanding
- Discover SPLADE: Revolutionizing Sparse Data Processing
- Exploring BERTopic: An Advanced Neural Topic Modeling Technique
- Streamlining Data: Effective Strategies for Reducing Dimensionality
- All-Mpnet-Base-V2: Enhancing Sentence Embedding with AI
- Time Series Embedding in Data Analysis
- Enhancing Information Retrieval with Sparse Embeddings
- What is BERT (Bidirectional Encoder Representations from Transformers)?
- What is Mixture of Experts (MoE)?
Introduction to training your own text embedding models
In the previous post, we explored the inner workings of Sentence-BERT, a widely used text embedding model, focusing on its Siamese network architecture and how it excels at converting long-form text into dense vector embeddings. This foundational understanding is critical, as it paves the way for more advanced techniques in natural language processing (NLP) and information retrieval.
Building on that foundation, in this post, we will guide you through the process of training your own transformer-based text embedding models using the sentence-transformers library. This tutorial is designed to be practical, taking you from the initial setup to the fine-tuning phase. Our objective is to help you create embedding models that are finely tuned to your specific needs, using your own data corpus.
We'll start by selecting a corpus, in this case, the Milvus documentation, which will serve as the base for our embedding model. The Milvus documentation is ideal for this exercise because it contains a wealth of technical information that can be transformed into valuable embeddings. The challenge lies in generating high-quality query-document pairs, which are essential for effective model training.
To tackle this, we’ll get creative by leveraging a large language model (LLM) to help us generate these pairs. The idea is to use the LLM to simulate various user queries that could be related to the content in the Milvus documentation. By doing so, we create a "labeled" dataset where each query is paired with a relevant document or passage. This dataset will be used to fine-tune our embedding model, improving its ability to understand and represent the specific types of queries and documents that are most relevant to our use case.
Once our dataset is ready, we’ll walk you through the fine-tuning process. Fine-tuning involves adjusting the pre-trained model on your specific dataset to enhance its performance in your domain. This step is crucial for tailoring the model’s output to the unique characteristics of your data, ensuring that the embeddings it generates are highly relevant and accurate.
Throughout this tutorial, we’ll cover best practices, potential pitfalls, and tips for optimizing the training process. Whether you’re looking to improve search functionality, enhance recommendation systems, or develop more sophisticated NLP applications, training your own text embedding model can provide significant advantages.
By the end of this post, you’ll have a trained embedding model that is well-suited to your specific domain, and you’ll be equipped with the knowledge to apply these techniques to other projects in the future. Let's dive in and start building a custom embedding model that meets your exact needs.
Sentence-BERT: a quick recap
BERT marked a significant advancement in natural language understanding, setting a new standard for how models process and comprehend text. One of the key innovations of BERT is its bidirectional nature, which allows it to read an entire sequence of words simultaneously. This bidirectionality enables BERT to capture context from both preceding and succeeding words in a sentence, leading to a more nuanced understanding of language compared to previous models that processed text in a unidirectional manner.
BERT models are pre-trained on an extensive corpus of text, covering a wide range of topics and language structures. This extensive pre-training allows BERT to perform well across various natural language processing tasks, even with limited labeled data. The model’s ability to generalize from its vast pre-training makes it highly effective for tasks like text classification, question answering, and named entity recognition.
However, despite its many strengths, BERT has limitations, particularly when it comes to efficiently vectorizing long-form text. BERT generates an embedding for each token in the input sequence, and these token-level embeddings are then pooled to create a single embedding representing the entire text. While this approach works well for shorter texts, it can fall short when dealing with longer, more complex documents. The pooling process can dilute the richness of the individual token embeddings, leading to less semantically meaningful representations of the text.
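To make that pooling step concrete, here is a minimal sketch of mean pooling over BERT's token embeddings using Hugging Face `transformers`; the checkpoint name and the choice of mean pooling are illustrative assumptions, not the only options:

import torch
from transformers import AutoTokenizer, AutoModel

# Illustrative base checkpoint; any BERT-style model works the same way
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")

def mean_pool(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        token_embeddings = model(**inputs).last_hidden_state  # (1, seq_len, 768)
    mask = inputs["attention_mask"].unsqueeze(-1)              # ignore padding tokens
    # Average the token embeddings into one vector for the whole text
    return (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)

print(mean_pool("Milvus is a vector database.").shape)  # torch.Size([1, 768])

However convenient, averaging in this way is exactly the step that can wash out token-level nuance on longer inputs.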
To address these limitations, SBERT (Sentence-BERT) was developed as an extension of BERT. SBERT utilizes a different training strategy that focuses on generating embeddings that are more semantically rich and meaningful. Instead of pooling token embeddings, SBERT employs a Siamese network architecture that allows it to directly compare sentences and generate embeddings that better capture the semantic relationships between them.
This training strategy enables SBERT to produce embeddings that are more suitable for tasks like semantic search, clustering, and sentence similarity, where understanding the meaning of entire sentences or documents is crucial. SBERT’s embeddings are designed to preserve the contextual richness of the input text, making it more effective for applications that require high-quality text representations, especially when dealing with longer texts.
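As a rough sketch of what this looks like in practice, `sentence-transformers` lets you assemble such a model from a transformer backbone plus a pooling layer; the same weights encode both sides of a sentence pair, which is what makes the setup "Siamese" (the checkpoint name below is just an example):

from sentence_transformers import SentenceTransformer, models

# A transformer backbone followed by a pooling layer: the same model
# instance encodes both sentences of a training pair
word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=256)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
sbert_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

embeddings = sbert_model.encode([
    "How do I create a collection in Milvus?",
    "Creating a collection requires defining a schema.",
])
print(embeddings.shape)  # (2, 768)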
Generating a "labeled" dataset using an LLM
Using an LLM to synthesize training pairs is one of the most practical ways to generate a labeled dataset when you don't already have one. Let's dive right into the code. In this example, we'll use OpenAI's GPT-4; the same approach can be adapted to other LLMs or corpora.
import openai
import csv
import re

# Set your OpenAI API key (this script uses the pre-1.0 openai SDK interface)
openai.api_key = 'YOUR_API_KEY'

def generate_chatgpt_responses(prompts, max_tokens=100):
    data = []
    for prompt in prompts:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens
            )
            generated_text = response['choices'][0]['message']['content']
            data.append((prompt, generated_text))
        except Exception as e:
            print(f"An error occurred with prompt '{prompt}': {e}")
    return data

def preprocess_text(text):
    # Remove extra whitespace and special characters if needed
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def write_to_csv(data, filename="synthetic_data.csv"):
    with open(filename, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Prompt', 'Response'])
        for prompt, response in data:
            writer.writerow([prompt, preprocess_text(response)])

def main():
    # Example prompts
    prompts = [
        "Explain the concept of machine learning",
        "Summarize the plot of Romeo and Juliet",
        "What are the benefits of renewable energy?"
        # Add more prompts as needed
    ]

    # Generate responses
    generated_data = generate_chatgpt_responses(prompts)

    # Write to CSV
    write_to_csv(generated_data)

if __name__ == "__main__":
    main()
To turn this into training data, we feed passages from our corpus (here, the Milvus documentation) into prompts that ask the LLM to write the kind of query each passage would answer. Running the script over the whole corpus then produces a large CSV of query/document pairs.
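One possible way to wire this up, assuming the corpus is available as a list of plain-text passages, is sketched below; `load_corpus` and `build_query_prompt` are hypothetical helpers, and the sketch reuses `generate_chatgpt_responses` and `write_to_csv` from the script above:

def load_corpus(path="milvus_docs.txt"):
    # Hypothetical helper: assumes one documentation passage per line
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

def build_query_prompt(passage):
    # Ask the LLM to act as a user searching for this passage
    return ("Write a short question a user might type into a search box "
            "that the following passage would answer:\n\n" + passage)

passages = load_corpus()
prompt_to_passage = {build_query_prompt(p): p for p in passages}
results = generate_chatgpt_responses(list(prompt_to_passage))

# Pair each generated query with the passage its prompt was built from
pairs = [(query, prompt_to_passage[prompt]) for prompt, query in results]
write_to_csv(pairs, filename="query_document_pairs.csv")

Note that `write_to_csv` will still label the columns Prompt and Response; rename them, and add a similarity score column (for example, 1.0 for every generated pair), if you want the CSV to match the three-column training format used in the next section.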
Training your model
Now, let's load a model for training. The first step is the same as in the previous tutorial:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('bert-base-nli-mean-tokens')
Now, we'll create a data loader for our dataset. This data loader lets the `sentence-transformers` library train directly on examples read from a CSV:
import csv
from torch.utils.data import DataLoader
from sentence_transformers import InputExample

examples = []
with open('your_dataset.csv', 'r') as file:
    reader = csv.reader(file)
    next(reader, None)  # skip the header row (remove this line if your CSV has no header)
    for row in reader:
        # Each row holds a sentence pair and a similarity score in [0, 1]
        sentence1, sentence2, score = row
        examples.append(InputExample(texts=[sentence1, sentence2], label=float(score)))

train_dataloader = DataLoader(examples, batch_size=16, shuffle=True)
We'll also specify a loss function. `CosineSimilarityLoss` trains the model so that the cosine similarity between the two embeddings of each pair matches the provided score:
from sentence_transformers import losses
train_loss = losses.CosineSimilarityLoss(model)
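`CosineSimilarityLoss` assumes each pair comes with a numeric similarity score. If your LLM-generated dataset only contains positive query/document pairs with no scores, one common alternative, sketched here under that assumption, is `MultipleNegativesRankingLoss`, which treats the other documents in each batch as negatives:

from sentence_transformers import InputExample, losses

# llm_generated_pairs is a hypothetical list of (query, relevant_document) tuples;
# no score label is needed for this loss
pair_examples = [InputExample(texts=[query, document])
                 for query, document in llm_generated_pairs]
train_loss = losses.MultipleNegativesRankingLoss(model)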
All that's left to do now is train the model and save it:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=4,
          warmup_steps=100)

model.save('path/to/your-model')
Putting all of this code together, we get:
import csv

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer('bert-base-nli-mean-tokens')

examples = []
with open('your_dataset.csv', 'r') as file:
    reader = csv.reader(file)
    next(reader, None)  # skip the header row (remove this line if your CSV has no header)
    for row in reader:
        sentence1, sentence2, score = row
        examples.append(InputExample(texts=[sentence1, sentence2], label=float(score)))

train_dataloader = DataLoader(examples, batch_size=16, shuffle=True)
train_loss = losses.CosineSimilarityLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=4,
          warmup_steps=100)

model.save('path/to/your-model')
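Once training finishes, a quick sanity check is to load the saved model back and compare a query against a couple of candidate passages; the texts below are just illustrative examples:

from sentence_transformers import SentenceTransformer, util

tuned_model = SentenceTransformer('path/to/your-model')
query_embedding = tuned_model.encode("How do I create a collection in Milvus?")
doc_embeddings = tuned_model.encode([
    "Creating a collection requires defining a schema first.",
    "Milvus supports several index types for vector search.",
])
# The more relevant passage should receive the higher cosine similarity
print(util.cos_sim(query_embedding, doc_embeddings))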
Wrapping up
In this post, we trained our own transformer-based text embedding model using the sentence-transformers library. We also showed how to generate our own training data by leveraging a pre-trained LLM.
In the following tutorial, we'll talk about different ways to evaluate text embeddings. Stay tuned!