To use a GPU for faster embedding generation with Sentence Transformers, you need to ensure the model and data are moved to the GPU and leverage its parallel processing capabilities. The primary code changes involve specifying the device (e.g., 'cuda') when initializing the model and processing input data in batches. Here’s how it works:
1. Model Initialization on GPU
Sentence Transformers uses PyTorch under the hood, so you can move the model to the GPU either by calling .to('cuda') or by setting the device parameter during initialization. For example:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('all-MiniLM-L6-v2', device='cuda') # or model.to('cuda')
This loads the model onto the GPU, enabling it to use CUDA cores for matrix operations. If a GPU may be unavailable, fall back to the CPU with a check:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('model_name').to(device)
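As a quick sanity check, you can confirm where the model actually landed. A minimal sketch, relying only on the fact that the model behaves like a standard PyTorch module:

import torch
from sentence_transformers import SentenceTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Any PyTorch module's parameters report which device they live on
print(next(model.parameters()).device)  # e.g., cuda:0 or cpu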
2. Data Handling and Batching
The model.encode() method automatically moves input data to the same device as the model. For example:
embeddings = model.encode(["Your text here"], batch_size=128)
The batch_size parameter controls how many sentences are processed in parallel. Larger batches maximize GPU utilization but require more VRAM; adjust it based on your GPU’s memory (e.g., reduce the batch size for long sequences). One way to pick a value empirically is shown in the sketch below.
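A minimal sketch for timing a fixed workload at several batch sizes; the sentence list and the candidate sizes here are illustrative, not recommendations:

import time
import torch
from sentence_transformers import SentenceTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
sentences = ["A short example sentence."] * 2000  # illustrative workload

# Note: the first iteration includes one-time warm-up overhead
for bs in (32, 64, 128, 256):
    start = time.perf_counter()
    model.encode(sentences, batch_size=bs)
    print(f"batch_size={bs}: {time.perf_counter() - start:.2f}s")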
3. Output Handling and Pitfalls
By default, model.encode() returns NumPy arrays, which live on the CPU. If you need PyTorch tensors (e.g., for further GPU computation), use:
embeddings = model.encode(texts, convert_to_tensor=True)
Ensure your environment has CUDA drivers and a GPU-enabled PyTorch build installed (verify with torch.cuda.is_available()). Common issues include out-of-memory errors (fixed by reducing the batch size) and mismatched-device errors (ensure the model and data are on the same device).
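One defensive pattern for out-of-memory errors is to retry with a smaller batch. A rough sketch, assuming the OOM surfaces as a RuntimeError whose message mentions "out of memory" (as it typically does in PyTorch); encode_with_backoff is a hypothetical helper, not part of the library:

import torch

def encode_with_backoff(model, texts, batch_size=256, min_batch_size=8):
    """Retry encoding with a halved batch size on GPU out-of-memory errors."""
    while batch_size >= min_batch_size:
        try:
            return model.encode(texts, batch_size=batch_size)
        except RuntimeError as e:
            if 'out of memory' not in str(e).lower():
                raise  # unrelated error, re-raise
            torch.cuda.empty_cache()  # release cached blocks before retrying
            batch_size //= 2
    raise RuntimeError("Could not encode even at the minimum batch size")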
Example Workflow
import torch
from sentence_transformers import SentenceTransformer

# Pick the GPU when available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2').to(device)

sentences = ["This is a sample sentence.", "Another example text."]
# Encode in batches; the returned tensors stay on the model's device
embeddings = model.encode(sentences, batch_size=64, convert_to_tensor=True)
This code dynamically uses a GPU if available, processes data in batches, and returns tensors for further GPU-based operations. No other code changes are required; Sentence Transformers handles device alignment internally.
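Because the returned tensors stay on the model’s device, downstream operations such as cosine similarity also run on the GPU. A small follow-up using the library’s util helpers, assuming the embeddings tensor from the workflow above:

from sentence_transformers import util

# Pairwise cosine similarities, computed on the same device as the embeddings
scores = util.cos_sim(embeddings, embeddings)
print(scores)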