PyTorch and Zilliz Cloud Integration
PyTorch and Zilliz Cloud can be combined to build advanced AI applications such as image search and recommendation systems: PyTorch supplies the deep learning framework, GPU acceleration, and pre-trained models, while Zilliz Cloud provides a high-performance vector database for scalable similarity search.
What is PyTorch
PyTorch is an open-source machine learning library developed by Meta AI Research lab and now part of the Linux Foundation. It offers flexible computational capabilities for deep learning with intuitive APIs, GPU acceleration, and a comprehensive ecosystem. The library supports various machine learning domains including computer vision, NLP, and reinforcement learning, with a key feature being its dynamic computation graphs that can be modified during runtime, enabling rapid experimentation and prototyping.
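As a quick illustration of dynamic computation graphs (a minimal sketch, not part of the tutorial below): because the graph is built on the fly at runtime, ordinary Python control flow can change the computation per input, and autograd still tracks gradients through whichever branch runs.

```python
import torch

# The computation graph is built as the code executes, so a plain
# Python `if` can select a different computation for each input.
def f(x):
    if x.sum() > 0:
        return (x * 2).sum()
    return (x * 3).sum()

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = f(x)        # takes the x * 2 branch, since x.sum() == 3 > 0
y.backward()    # autograd follows the branch that actually ran
print(x.grad)   # tensor([2., 2.])
```

This is what enables the rapid experimentation mentioned above: the model's structure can depend on the data itself.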
By integrating with Zilliz Cloud (fully managed Milvus), PyTorch-generated embeddings from multiple data types (images, text, audio) can be stored, managed, and queried through a scalable and efficient vector database. High-performance vector similarity search then powers applications such as image search engines, recommendation systems, and anomaly detection tools.
Benefits of the PyTorch + Zilliz Cloud Integration
- Flexible embedding generation with scalable storage: PyTorch's pre-trained models (e.g., ResNet50) generate high-quality feature vectors from diverse data types, while Zilliz Cloud stores and indexes these embeddings for fast similarity search at scale.
- GPU-accelerated pipeline: PyTorch leverages GPU acceleration for efficient embedding generation, and Zilliz Cloud delivers low-latency vector retrieval, creating a high-performance end-to-end pipeline.
- Rich model ecosystem: PyTorch's Torchvision library provides pre-trained models for computer vision tasks, making it straightforward to build image search and visual similarity applications backed by Zilliz Cloud's vector storage.
- Batch processing support: The integration supports batch embedding generation and insertion, enabling efficient processing of large-scale datasets into Zilliz Cloud.
How the Integration Works
PyTorch serves as the deep learning framework, generating feature vector embeddings from raw data using pre-trained models. For image search, it uses models like ResNet50 from the Torchvision library to extract high-dimensional feature representations from images, with the final classification layer removed to produce pure embedding vectors.
Zilliz Cloud serves as the vector database layer, storing and indexing the embeddings generated by PyTorch. It provides high-performance similarity search using metrics like L2 distance, enabling fast retrieval of the most similar items from large collections.
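To make the L2 metric concrete, here is a tiny toy illustration (plain NumPy, not the Milvus API) of ranking stored vectors by Euclidean distance to a query:

```python
import numpy as np

# Three stored embeddings and one query vector (toy 2-D example;
# real image embeddings would be 2048-dimensional)
stored = np.array([[0.0, 0.0],
                   [3.0, 4.0],
                   [1.0, 1.0]])
query = np.array([0.5, 0.0])

# L2 (Euclidean) distance from the query to each stored vector
dists = np.linalg.norm(stored - query, axis=1)
ranked = np.argsort(dists)  # indices of the nearest vectors, closest first
print(ranked)               # index 0 is the nearest stored vector
```

Zilliz Cloud performs the same kind of ranking, but over millions of vectors using an approximate index rather than a brute-force scan.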
Together, PyTorch and Zilliz Cloud create a complete similarity search solution: PyTorch processes raw data (images, text, audio) into vector embeddings using pre-trained models, and Zilliz Cloud stores and indexes these embeddings. When a query comes in, PyTorch embeds the query data, and Zilliz Cloud performs similarity search to find the closest matching items, enabling applications like visual search, recommendations, and anomaly detection.
Step-by-Step Guide
1. Install Required Packages
pip install pymilvus torch gdown torchvision tqdm

2. Download the Dataset
Use gdown to download the Impressionist-Classifier Dataset and extract it:

import gdown
import zipfile

url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'
gdown.download(url, output)

with zipfile.ZipFile("./paintings.zip", "r") as zip_ref:
    zip_ref.extractall("./paintings")

3. Set Global Parameters
COLLECTION_NAME = 'image_search'  # Milvus collection name
DIMENSION = 2048                  # Embedding size produced by ResNet50
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
BATCH_SIZE = 128                  # Number of images embedded per batch
TOP_K = 3                         # Number of search results to return

4. Set Up Milvus Collection
Connect to Milvus, create a collection with the appropriate schema, and build an index:
from pymilvus import connections, utility, FieldSchema, CollectionSchema, DataType, Collection

connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Drop any existing collection with the same name
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

# Schema: an auto-generated ID, the image file path, and the embedding vector
fields = [
    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
    FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
    FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# Build an IVF_FLAT index using L2 distance for similarity search
index_params = {
    'metric_type': 'L2',
    'index_type': "IVF_FLAT",
    'params': {'nlist': 16384}
}
collection.create_index(field_name="image_embedding", index_params=index_params)
collection.load()

5. Generate Embeddings and Insert Data
Load the images, generate embeddings with a pre-trained ResNet50 model, and insert them into Milvus:
import glob
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

paths = glob.glob('./paintings/paintings/**/*.jpg', recursive=True)

# Load ResNet50 and strip the final classification layer to get 2048-dim embeddings
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
model.eval()

# Standard ImageNet preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def embed(data):
    with torch.no_grad():
        # Squeeze only the trailing spatial dims so a final batch of one
        # image keeps its batch dimension
        output = model(torch.stack(data[0])).squeeze(-1).squeeze(-1)
    collection.insert([data[1], output.tolist()])

data_batch = [[], []]
for path in tqdm(paths):
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)
    if len(data_batch[0]) % BATCH_SIZE == 0:
        embed(data_batch)
        data_batch = [[], []]

# Embed and insert any remaining images, then flush to persist them
if len(data_batch[0]) != 0:
    embed(data_batch)
collection.flush()

6. Perform Image Similarity Search
Embed test images and search for the most similar paintings in the collection:
import time
from matplotlib import pyplot as plt

search_paths = glob.glob('./paintings/test_paintings/**/*.jpg', recursive=True)

def embed(data):
    with torch.no_grad():
        ret = model(torch.stack(data))
        # Return a list of 2048-dim vectors, handling single-image batches
        if len(ret) > 1:
            return ret.squeeze().tolist()
        else:
            return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[], []]
for path in search_paths:
    im = Image.open(path).convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)

embeds = embed(data_batch[0])

# Search for the TOP_K nearest embeddings by L2 distance
start = time.time()
res = collection.search(embeds, anns_field='image_embedding',
                        param={'nprobe': 128}, limit=TOP_K,
                        output_fields=['filepath'])
finish = time.time()

# Plot each query image alongside its top matches
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
for hits_i, hits in enumerate(res):
    axarr[hits_i][0].imshow(Image.open(data_batch[1][hits_i]))
    axarr[hits_i][0].set_axis_off()
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        axarr[hits_i][hit_i + 1].imshow(Image.open(hit.entity.get('filepath')))
        axarr[hits_i][hit_i + 1].set_axis_off()
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))
plt.savefig('search_result.png')

Learn More
- Image Search with PyTorch and Milvus — Official Milvus tutorial for building image search with PyTorch
- Image Search with Zilliz Cloud and PyTorch — Zilliz Cloud documentation for PyTorch integration
- Zilliz Partnership with PyTorch — Zilliz and PyTorch partnership page with tutorials
- Elevating User Experience with Image-based Fashion Recommendations — Zilliz blog on image-based recommendations
- Getting Started with Hybrid Search with Milvus — Zilliz blog on hybrid search techniques