Training Your Own Text Embedding Model
Explore how to train your own text embedding model using the `sentence-transformers` library and how to generate training data by leveraging a pre-trained LLM.
Read the entire series
- Natural Language Processing Fundamentals: Tokens, N-Grams, and Bag-of-Words Models
- Primer on Neural Networks and Embeddings for Language Models
- Sparse and Dense Embeddings
- Sentence Transformers for Long-Form Text
- Training Your Own Text Embedding Model
- Evaluating Your Embedding Model
- Class Activation Mapping: Unveiling The Visual Story
- CLIP Object Detection: Merging AI Vision with Language Understanding
- Discover SPLADE: Revolutionizing Sparse Data Processing
- Exploring BERTopic: A New Era of Neural Topic Modeling
- Streamlining Data: Effective Strategies for Reducing Dimensionality
- All-Mpnet-Base-V2: Enhancing Sentence Embedding with AI
- Time Series Embedding in Data Analysis
- Enhancing Information Retrieval with Sparse Embeddings
Introduction to training your own text embedding models
In the previous post, we discussed how Sentence-BERT works and showed how its Siamese training methodology enables it to effectively turn long-form text into embeddings.
In this post, we'll build on that knowledge by training our own transformer-based text embedding model using the `sentence-transformers` library. We'll start with our own corpus of data (the Milvus documentation) and get creative by generating query-document pairs with an LLM. We'll then use this "labeled" dataset to fine-tune our own embedding model.
Let's dive in.
Sentence-BERT: a quick recap
BERT was a significant leap forward for natural language understanding. As a model, BERT reads the entire sequence of words simultaneously, allowing it to capture context from both directions. BERT models are also pre-trained on a massive corpus of text, enabling them to achieve strong results with very little labeled data.
Despite its strengths, vanilla BERT is not well suited to vectorizing long-form text: it produces an embedding for every token, and those token embeddings must be pooled into a single vector, which dilutes much of the semantic information. SBERT addresses this with a Siamese training strategy that produces more semantically rich and meaningful sentence embeddings.
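To make the pooling step concrete, here is a minimal sketch of mean-pooling BERT's per-token embeddings into a single sentence vector. It uses the Hugging Face transformers library and bert-base-uncased, an illustrative choice rather than part of the training pipeline below:

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = AutoModel.from_pretrained("bert-base-uncased")

sentence = "Vector databases store embeddings for similarity search."
inputs = tokenizer(sentence, return_tensors="pt")

with torch.no_grad():
    token_embeddings = bert(**inputs).last_hidden_state   # (1, num_tokens, 768)

# Mean-pool the per-token embeddings into one fixed-size sentence vector,
# ignoring padding positions via the attention mask.
mask = inputs["attention_mask"].unsqueeze(-1).float()      # (1, num_tokens, 1)
sentence_embedding = (token_embeddings * mask).sum(1) / mask.sum(1)
print(sentence_embedding.shape)                            # torch.Size([1, 768])

Pooling like this works, but without SBERT-style fine-tuning the resulting vectors are often poor at capturing sentence-level semantics, which is exactly what the training below fixes.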
Generating a "labeled" dataset using an LLM
Using an LLM to synthesize training pairs is one of the most creative ways to generate a labeled dataset when you don't already have one. Let's dive right into the code. In this example, we'll use OpenAI's GPT-4, but the same approach works with other LLMs and other corpora.
import csv
import re

import openai

# Set your OpenAI API key (this example targets the pre-1.0 openai library,
# which exposes the openai.ChatCompletion interface).
openai.api_key = 'YOUR_API_KEY'

def generate_chatgpt_responses(prompts, max_tokens=100):
    """Send each prompt to GPT-4 and collect (prompt, response) pairs."""
    data = []
    for prompt in prompts:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-4",
                messages=[{"role": "user", "content": prompt}],
                max_tokens=max_tokens
            )
            generated_text = response['choices'][0]['message']['content']
            data.append((prompt, generated_text))
        except Exception as e:
            print(f"An error occurred with prompt '{prompt}': {e}")
    return data

def preprocess_text(text):
    # Collapse extra whitespace and strip leading/trailing spaces
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def write_to_csv(data, filename="synthetic_data.csv"):
    with open(filename, 'w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow(['Prompt', 'Response'])
        for prompt, response in data:
            writer.writerow([prompt, preprocess_text(response)])

def main():
    # Example prompts
    prompts = [
        "Explain the concept of machine learning",
        "Summarize the plot of Romeo and Juliet",
        "What are the benefits of renewable energy?",
        # Add more prompts as needed
    ]

    # Generate responses
    generated_data = generate_chatgpt_responses(prompts)

    # Write to CSV
    write_to_csv(generated_data)

if __name__ == "__main__":
    main()
We can feed documents from our corpus into this script as prompts and have the LLM generate a matching piece of text for each one. Running it over the whole corpus produces a large CSV of prompt/document pairs that we can use directly as training data.
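As an illustration, here is one way to turn corpus documents into prompts that ask the LLM for a matching query. The prompt template and the example documents are hypothetical, not the exact ones used for the Milvus documentation:

# Hypothetical prompt template: ask the LLM to invent a query for each document.
documents = [
    "Milvus supports several index types, including IVF_FLAT and HNSW.",
    "Collections in Milvus are created with a schema that defines field types.",
]
prompts = [
    f"Write a short search query that the following passage answers:\n\n{doc}"
    for doc in documents
]
generated_data = generate_chatgpt_responses(prompts)  # (prompt, generated query) pairs

Each generated query can then be paired with its source document to form a query-document training example.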
Training your model
Now, let's load a model for training. The first step is the same as in the previous tutorial:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('bert-base-nli-mean-tokens')
Now, we'll create a dataloader for our dataset. We read the CSV into `InputExample` objects, which the `sentence-transformers` library can consume as training data via a standard PyTorch `DataLoader`:
import csv

from torch.utils.data import DataLoader
from sentence_transformers import InputExample

examples = []
with open('your_dataset.csv', 'r') as file:
    reader = csv.reader(file)
    # Each row holds two text snippets and a similarity score.
    # If your CSV has a header row, skip it first with next(reader).
    for row in reader:
        sentence1, sentence2, score = row
        examples.append(InputExample(texts=[sentence1, sentence2], label=float(score)))

train_dataloader = DataLoader(examples, batch_size=16, shuffle=True)
We'll also specify a loss function. `CosineSimilarityLoss` trains the model so that the cosine similarity between the two sentence embeddings matches the provided score:
from sentence_transformers import losses
train_loss = losses.CosineSimilarityLoss(model)
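Note that `CosineSimilarityLoss` expects a numeric similarity score for each pair. If your LLM-generated data only contains query/document pairs without scores, one common alternative is `MultipleNegativesRankingLoss`, which treats each pair as a positive example and the other documents in the batch as negatives. The sketch below assumes a hypothetical `query_document_pairs` list holding your (query, document) tuples:

from sentence_transformers import losses, InputExample
from torch.utils.data import DataLoader

# query_document_pairs is a placeholder for your own list of (query, document) tuples.
# Each example is just a text pair; no similarity score is needed.
pair_examples = [InputExample(texts=[query, document])
                 for query, document in query_document_pairs]
pair_dataloader = DataLoader(pair_examples, batch_size=16, shuffle=True)
pair_loss = losses.MultipleNegativesRankingLoss(model)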
All that's left to do now is train the model and save it:
model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=4,
          warmup_steps=100)
model.save('path/to/your-model')
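Once training finishes, the saved model can be loaded like any other sentence-transformers model and used to embed text. A quick sanity check might look like this (the example sentences are placeholders):

from sentence_transformers import SentenceTransformer, util

finetuned = SentenceTransformer('path/to/your-model')
embeddings = finetuned.encode([
    "How do I create a collection in Milvus?",
    "Collections are created with a schema that defines field types.",
])
print(util.cos_sim(embeddings[0], embeddings[1]))  # cosine similarity of the two embeddings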
Putting all of this code together, we get:
import csv

from torch.utils.data import DataLoader
from sentence_transformers import InputExample, SentenceTransformer, losses

model = SentenceTransformer('bert-base-nli-mean-tokens')

examples = []
with open('your_dataset.csv', 'r') as file:
    reader = csv.reader(file)
    for row in reader:
        sentence1, sentence2, score = row
        examples.append(InputExample(texts=[sentence1, sentence2], label=float(score)))

train_dataloader = DataLoader(examples, batch_size=16, shuffle=True)
train_loss = losses.CosineSimilarityLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          epochs=4,
          warmup_steps=100)
model.save('path/to/your-model')
Wrapping up
In this post, we trained our own transformer-based text embedding model using the `sentence-transformers` library. We also showed how to generate our own training data by leveraging a pre-trained LLM.
In the following tutorial, we'll talk about different ways to evaluate text embeddings. Stay tuned!