KI nutzen, um deinen Promi-Stylisten zu finden (Teil II)
In meinem vorherigen Blogbeitrag „Using AI to Find Your Celebrity Stylist“ habe ich erklärt, wie man Technologien der künstlichen Intelligenz (KI), wie Milvus, eine quelloffene, KI-native Vektordatenbank, und Hugging Face-Modelle nutzt, um Stilentscheidungen von Prominenten zu finden, die zu den eigenen passen. In diesem Folgebeitrag gehen wir einen Schritt weiter und zeigen, wie man durch einige Codeänderungen am vorherigen Projekt detailliertere und genauere Ergebnisse erhält. Außerdem gebe ich Vorschläge, wie du dieses Projekt selbst erweitern kannst.
Wenn du dieses Projekt direkt ausprobieren möchtest, lade die Fotos und das fertige Notebook herunter. Wenn du dich für das Projekt interessierst, das in meinem vorherigen Blog besprochen wurde, kannst du dir die Fotos, die es verwendet, und sein Tutorial ansehen.
Zusammenfassung des Tutorials für mein vorheriges Fashion-AI-Projekt
Bevor wir tief in dieses Projekt eintauchen, möchte ich das Tutorial, das wir in meinem vorherigen Beitrag besprochen haben, kurz zusammenfassen. So musst du diese Seite nicht verlassen, um den Kontext zu verstehen.
Importieren aller notwendigen Bibliotheken für die Bildbearbeitung
Wir beginnen den Code mit dem Import aller notwendigen Bibliotheken für die Bildbearbeitung, darunter torch für die Feature-Extraktion, das segformer-Objekt aus transformers, matplotlib und einige torchvision-Importe wie Resize, masks_to_boxes und crop.
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
Vorverarbeitung der Prominentenbilder
Nachdem du alle notwendigen Pakete für die Bildbearbeitung importiert hast, kannst du mit der Verarbeitung deiner Bilder beginnen. Die folgenden drei Funktionen (get_segmentation, get_masks und crop_images) werden verwendet, um Kleidungsstücke zu segmentieren und sie für die weitere Aufnahme zuzuschneiden.
def get_segmentation(extractor, model, image):
inputs = extractor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits.cpu()
upsampled_logits = nn.functional.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
pred_seg = upsampled_logits.argmax(dim=1)[0]
return pred_seg
# returns two lists masks (tensor) and obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" from hugging face
def get_masks(segmentation):
obj_ids = torch.unique(segmentation)
obj_ids = obj_ids[1:]
masks = segmentation == obj_ids[:, None, None]
return masks, obj_ids
def crop_images(masks, obj_ids, img):
boxes = masks_to_boxes(masks)
crop_boxes = []
for box in boxes:
crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
crop_boxes.append(crop_box)
preprocess = T.Compose([
T.Resize(size=(256, 256)),
T.ToTensor()
])
cropped_images = {}
for i in range(len(crop_boxes)):
crop_box = crop_boxes[i]
cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
cropped_images[obj_ids[i].item()] = preprocess(cropped)
return cropped_images
Speichern der Bilddaten in einer Vektordatenbank
Wir verwenden Milvus, eine quelloffene KI-native Vektordatenbank, um Bilddaten zu speichern. Um zu beginnen, entpacken Sie die ZIP-Datei photos für dieses Projekt und fügen Sie den Ordner in dasselbe Stammverzeichnis wie das Notebook ein. Sobald dieser Schritt abgeschlossen ist, können Sie den untenstehenden Code ausführen, um die Bilder zu verarbeiten und die Daten in Milvus zu speichern.
import os
image_paths = []
for celeb in os.listdir("./photos"):
for image in os.listdir(f"./photos/{celeb}/"):
image_paths.append(f"./photos/{celeb}/{image}")
from milvus import default_server
from pymilvus import utility, connections
default_server.start()
connections.connect(host="127.0.0.1", port=default_server.listen_port)
DIMENSION = 2048
BATCH_SIZE = 128
COLLECTION_NAME = "fashion"
TOP_K = 3
from pymilvus import FieldSchema, CollectionSchema, Collection, DataType
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
FieldSchema(name="name", dtype=DataType.VARCHAR, max_length=200),
FieldSchema(name="seg_id", dtype=DataType.INT64),
FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)
]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)
index_params = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()
Als Nächstes können Sie den untenstehenden Code ausführen, um Embeddings mit dem Nvidia ResNet 50-Modell von Hugging Face zu generieren.
# run this before importing th resnet50 model if you run into an SSL certificate URLError
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# Load the embedding model with the last layer removed
embeddings_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
embeddings_model = torch.nn.Sequential(*(list(embeddings_model.children())[:-1]))
embeddings_model.eval()
Die folgende Funktion definiert, wie Daten eingebettet und eingefügt werden. Anschließend durchläuft der Code alle Bilder und bettet sie ein und fügt sie in Milvus ein.
Hinweis: Viele der untenstehenden Komponenten werden sich ändern oder entfernt werden, wenn die neue dynamische Schema-Funktion von Milvus genutzt wird.
def embed_insert(data, collection, model):
with torch.no_grad():
output = model(torch.stack(data[0])).squeeze()
collection.insert([data[1], data[2], data[3], output.tolist()])
from PIL import Image
data_batch = [[], [], [], []]
for path in image_paths:
image = Image.open(path)
path_split = path.split("/")
name = " ".join(path_split[2].split("_"))
segmentation = get_segmentation(extractor, model, image)
masks, ids = get_masks(segmentation)
cropped_images = crop_images(masks, ids, image)
for key, image in cropped_images.items():
data_batch[0].append(image)
data_batch[1].append(path)
data_batch[2].append(name)
data_batch[3].append(key)
if len(data_batch[0]) % BATCH_SIZE == 0:
embed_insert(data_batch, collection, embeddings_model)
data_batch = [[], [], [], []]
if len(data_batch[0]) != 0:
embed_insert(data_batch, collection, embeddings_model)
collection.flush()
Die Vektordatenbank abfragen
Der folgende Code demonstriert, wie Milvus mit Eingabebildern abgefragt wird und die drei besten Ergebnisse für jedes Kleidungsstück abgerufen werden.
def embed_search_images(data, model):
with torch.no_grad():
output = model(torch.stack(data))
if len(output) > 1:
return output.squeeze().tolist()
else:
return torch.flatten(output, start_dim=1).tolist()
# data_batch[0] is a list of tensors
# data_batch[1] is a list of filepaths to the images (string)
# data_batch[2] is a list of the names of the people in the images (string)
# data_batch[3] is a list of segmentation keys (int)
data_batch = [[], [], [], []]
search_paths = ["./photos/Taylor_Swift/Taylor_Swift_3.jpg", "./photos/Taylor_Swift/Taylor_Swift_8.jpg"]
for path in search_paths:
image = Image.open(path)
path_split = path.split("/")
name = " ".join(path_split[2].split("_"))
segmentation = get_segmentation(extractor, model, image)
masks, ids = get_masks(segmentation)
cropped_images = crop_images(masks, ids, image)
for key, image in cropped_images.items():
data_batch[0].append(image)
data_batch[1].append(path)
data_batch[2].append(name)
data_batch[3].append(key)
embeds = embed_search_images(data_batch[0], embeddings_model)
import time
start = time.time()
res = collection.search(embeds,
anns_field='embedding',
param={"metric_type": "L2",
"params": {"nprobe": 10}},
limit=TOP_K,
output_fields=['filepath'])
finish = time.time()
print(finish - start)
for index, result in enumerate(res):
print(index)
print(result)
Mehr Musterabgleich: Auswählen von Objekten aus jedem Bild
Wenn du dem oben zusammengefassten Tutorial folgst, kannst du die drei besten Promi-Style-Übereinstimmungen für jedes Kleidungsstück entdecken, nach dem du suchst. Du kannst auch ein Bild wie das untenstehende erstellen, ohne die Bounding Boxes der gefundenen Elemente. In diesem Abschnitt erkläre ich, wie du mit ein paar Codeänderungen gegenüber dem im vorherigen Tutorial verwendeten Code Modestile findest, deren Muster deinen eigenen näherkommen.
Bild
Importieren aller notwendigen Bibliotheken für die Bildbearbeitung
Importiere zunächst alle notwendigen Bibliotheken für die Bildbearbeitung in deinen Code. Wenn du dies bereits getan hast, kannst du diesen Schritt überspringen.
import torch
from torch import nn, tensor
from transformers import AutoFeatureExtractor, SegformerForSemanticSegmentation
import matplotlib.pyplot as plt
from torchvision.transforms import Resize
import torchvision.transforms as T
from torchvision.ops import masks_to_boxes
from torchvision.transforms.functional import crop
Vorverarbeiten deiner Bilder
Wenn du alle notwendigen Pakete für die Bildbearbeitung importiert hast, fahre mit dem Bildsegmentierungsprozess fort, der drei Funktionen umfasst: get_segmentation, get_masks und crop_images.
An der Funktion get_segmentation müssen wir keine Codeänderungen vornehmen.
Für die Funktion get_masks müssen wir nur die Segmentierungen abrufen, die den Segmentierungs-IDs in der Liste wanted entsprechen. Dies ist eine neue Ergänzung, die Segmentierungs-IDs für Kleidungsstücke enthält, wie in der Model Card auf Hugging Face angegeben.
Die meisten Codeänderungen nehmen wir an der Funktion crop_image vor. In meinem vorherigen Tutorial gab diese Funktion eine Liste zugeschnittener Bilder zurück. Nachdem wir etwas Code geändert haben, gibt sie nun drei Objekte zurück: Embeddings der zugeschnittenen Bilder, eine Liste der Koordinaten der Boxen auf dem Originalbild und eine Liste der Segmentierungs-IDs. Diese neue Einrichtung verschiebt das Embedding vom Batch-Insert in den Transformationsschritt.
wanted = [1, 3, 4, 5, 6, 7, 8, 9, 10, 16, 17]
def get_segmentation(image):
inputs = extractor(images=image, return_tensors="pt")
outputs = segmentation_model(**inputs)
logits = outputs.logits.cpu()
upsampled_logits = nn.functional.interpolate(
logits,
size=image.size[::-1],
mode="bilinear",
align_corners=False,
)
pred_seg = upsampled_logits.argmax(dim=1)[0]
return pred_seg
# gibt zwei Listen zurück: masks (Tensor) und obj_ids (int)
# "mattmdjaga/segformer_b2_clothes" von Hugging Face
def get_masks(segmentation):
obj_ids = torch.unique(segmentation)
obj_ids = obj_ids[1:]
wanted_ids = [x.item() for x in obj_ids if x in wanted]
wanted_ids = torch.Tensor(wanted_ids)
masks = segmentation == wanted_ids[:, None, None]
return masks, obj_ids
def crop_images(masks, obj_ids, img):
boxes = masks_to_boxes(masks)
crop_boxes = []
for box in boxes:
crop_box = tensor([box[0], box[1], box[2]-box[0], box[3]-box[1]])
crop_boxes.append(crop_box)
preprocess = T.Compose([
T.Resize(size=(256, 256)),
T.ToTensor()
])
cropped_images = []
seg_ids = []
for i in range(len(crop_boxes)):
crop_box = crop_boxes[i]
cropped = crop(img, crop_box[1].item(), crop_box[0].item(), crop_box[3].item(), crop_box[2].item())
cropped_images.append(preprocess(cropped))
seg_ids.append(obj_ids[i].item())
with torch.no_grad():
embeddings = embeddings_model(torch.stack(cropped_images)).squeeze().tolist()
return embeddings, boxes.tolist(), seg_ids
Da wir nun die Bilder haben, ist es an der Zeit, sie zu laden. Dieser Schritt umfasst das Einfügen in Batches, das wir in meinem vorherigen Tutorial behandelt haben. In diesem Tutorial fügen wir alle unsere Daten auf einmal als Liste von Dictionaries statt als Liste von Listen ein. Ich finde diese Einfügemethode viel übersichtlicher, und sie ermöglicht es uns, zur Einfügezeit ein neues Feld zum Schema hinzuzufügen. In diesem Fall fügen wir eine Liste von Crop-Ecken hinzu.
for path in image_paths:
image = Image.open(path)
path_split = path.split("/")
name = " ".join(path_split[2].split("_"))
segmentation = get_segmentation(image)
masks, ids = get_masks(segmentation)
embeddings, crop_corners, seg_ids = crop_images(masks, ids, image)
inserts = [{"embedding": embeddings[x], "seg_id": seg_ids[x], "name": name, "filepath": path, "crop_corner": crop_corners[x]} for x in range(len(embeddings))]
collection.insert(inserts)
collection.flush()
Die Vektordatenbank abfragen
Jetzt ist es an der Zeit, Abfragen in Milvus, unserer Vektordatenbank, durchzuführen. Im Vergleich zu den Schritten, die wir im vorherigen Tutorial verwendet haben, gibt es hier einige Unterschiede:
- Erstens begrenzen wir die Anzahl der „Treffer“, die uns in einem Bild interessieren, auf fünf.
- Zweitens zeigen wir die drei am besten passenden Bilder.
- Drittens fügen wir eine Funktion hinzu, um eine Farbkarte zum Zeichnen von Bounding Boxes in verschiedenen Farben zu erhalten.
Nun richten wir die matplotlib-Figure und -Axes ein. Dann durchlaufen wir alle unsere Bilder und wenden die drei oben genannten Verarbeitungsfunktionen an, um die Segmentierungen und Bounding Boxes zu erhalten.
Nachdem wir die Bilder vorverarbeitet haben, können wir in Milvus nach ihnen suchen. Wir erhalten die drei besten Antworten für jedes Bild basierend auf der Anzahl der „passenden“ Artikel, die sie enthalten. Schließlich geben wir die Ergebnisse zusammen mit den Bounding Boxes aus, die Treffer zurückgegeben haben.
from pprint import pprint
from PIL import ImageDraw
from collections import Counter
import matplotlib.patches as patches
LIMIT = 5 # Wie viele nächste Treffer pro Kleidungsstück analysiert werden sollen
CLOSEST = 3 # Wie viele nächste Bilder angezeigt werden sollen. CLOSEST <= Limit
search_paths = ["./photos/Taylor_Swift/Taylor_Swift_2.jpg", "./photos/Jenna_Ortega/Jenna_Ortega_6.jpg"] # Bilder, nach denen gesucht werden soll
def get_cmap(n, name='hsv'):
'''Gibt eine Funktion zurück, die jeden Index in 0, 1, ..., n-1 auf eine eindeutige
RGB-Farbe abbildet; das Schlüsselwortargument name muss ein standardmäßiger mpl-Colormap-Name sein.
Quelle: https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib'''
return plt.cm.get_cmap(name, n)
# Ergebnis-Subplots erstellen
f, axarr = plt.subplots(max(len(search_paths), 2), CLOSEST + 1)
for search_i, path in enumerate(search_paths):
# Zuschnitte und Einbettungen für alle gefundenen Elemente erzeugen
image = Image.open(path)
segmentation = get_segmentation(image)
masks, ids = get_masks(segmentation)
embeddings, crop_corners, _ = crop_images(masks, ids, image)
# Farbkarte erzeugen
cmap = get_cmap(len(crop_corners))
# Das erste Feld mit dem gesuchten Bild anzeigen
axarr[search_i][0].imshow(image)
axarr[search_i][0].set_title('Suchbild')
axarr[search_i][0].axis('off')
for i, (x0, y0, x1, y1) in enumerate(crop_corners):
rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(i), facecolor='none')
axarr[search_i][0].add_patch(rect)
# Die Datenbank nach allen Zuschnitten durchsuchen
start = time.time()
res = collection.search(embeddings,
anns_field='embedding',
param={"metric_type": "L2",
"params": {"nprobe": 10}, "offset": 0},
limit=LIMIT,
output_fields=['filepath', 'crop_corner'])
finish = time.time()
print("Gesamte Suchzeit: ", finish - start)
# Die wichtigsten eindeutigen Ergebnisse zusammenfassen und basierend auf ihrer Position in den Ergebnissen gewichten
filepaths = []
for hits in res:
seen = set()
for i, hit in enumerate(hits):
if hit.entity.get("filepath") not in seen:
seen.add(hit.entity.get("filepath"))
filepaths.extend([hit.entity.get("filepath") for _ in range(len(hits) - i)])
# Das am häufigsten eingestufte Ergebnisbild finden
counts = Counter(filepaths)
most_common = [path for path, _ in counts.most_common(CLOSEST)]
# Für jedes Bild das entsprechende gefundene Element extrahieren, das mit den Suchbildern korreliert
matches = {}
for i, hits in enumerate(res):
matches[i] = {}
tracker = set(most_common)
for hit in hits:
if hit.entity.get("filepath") in tracker:
matches[i][hit.entity.get("filepath")] = hit.entity.get("crop_corner")
tracker.remove( hit.entity.get("filepath"))
# Die häufigsten Bilder in den Ergebnissen anzeigen
for res_i, res_path in enumerate(most_common):
# Jedes der Bilder neben dem Suchbild anzeigen
image = Image.open(res_path)
axarr[search_i][res_i+1].imshow(image)
axarr[search_i][res_i+1].set_title(" ".join(res_path.split("/")[2].split("_")))
axarr[search_i][res_i+1].axis('off')
# Begrenzungsrahmen für alle übereinstimmenden Elemente hinzufügen
for key, value in matches.items():
if res_path in value:
x0, y0, x1, y1 = value[res_path]
rect = patches.Rectangle((x0, y0), x1-x0, y1-y0, linewidth=1, edgecolor=cmap(key), facecolor='none')
axarr[search_i][res_i+1].add_patch(rect)
Sobald Sie die obigen Schritte abgeschlossen haben, sollten Sie ein Ergebnis ähnlich dem unten gezeigten oder dem zu Beginn dieser Sitzung erhalten.
image
Was kommt als Nächstes? mögliche Projekterweiterungen
Ich lege dieses Projekt vorerst auf Eis, um an einigen anderen Projekten zu arbeiten, aber Sie können es gerne erweitern, wenn Sie möchten! Hier sind drei mögliche Erweiterungen.
Erstens können Sie das Vergleichsspiel etwas weiter ausarbeiten. Zum Beispiel können Sie aufgeteilte Elemente gruppieren, etwa indem Sie beide Schuhe als ein einzelnes Element markieren. Sie können auch weitere Bilder von Prominenten oder Freunden hinzufügen, um mehr Vergleiche zu ermöglichen.
Zweitens können Sie dieses Projekt in einen Mode-Identifikator oder ein Empfehlungssystem verwandeln. Anstatt Bilder von Prominenten zu verwenden, können Sie Bilder von Kleidung verwenden, die online gekauft werden kann. Wenn ein Benutzer ein Bild hochlädt, können Sie es mit den Bildern in Ihrer Vektordatenbank vergleichen und dem Benutzer die ähnlichsten Kleidungsstücke vorschlagen.
Drittens können Sie einen Style-Generator erstellen, was anspruchsvoller sein kann. Es gibt verschiedene Möglichkeiten, dies zu tun, aber eine Idee ist, mehrere Bilder eines Benutzers zu nehmen und darauf basierende Vorschläge zu generieren. Dieser Ansatz beinhaltet die Verwendung eines generativen Bildmodells, um Style-Vorschläge bereitzustellen und sie zum Vergleich mit den ähnlichsten Bildern von Benutzern abzugleichen. Anschließend können wir auf Grundlage dieses Vergleichs etwas Sinnvolles vorschlagen.
Diese drei Erweiterungen sind nur einige Beispiele dafür, wie sich mein einfaches Projekt mithilfe eines Bildmodells und einer Vektordatenbank wie Milvus verbessern lässt. Die Verwendung einer Vektordatenbank ermöglicht eine Vielzahl von Ähnlichkeitssuchaufgaben, was besonders wertvoll für den Vergleich von Bildern ist.
Zusammenfassung
In diesem Tutorial haben wir unser erstes Celebrity-Style-Projekt erweitert, indem wir Milvus' neues dynamisches Schema verwendet, bestimmte Segmentierungs-IDs herausgefiltert und die Bounding Boxes unserer Treffer nachverfolgt haben. Außerdem haben wir unsere Suchergebnisse sortiert, um die besten drei Ergebnisse basierend auf der Anzahl der Treffer zurückzugeben.
Milvus' neues dynamisches Schema ermöglicht es uns, zusätzliche Felder hinzuzufügen, wenn wir Daten im Wörterbuchformat hochladen, wodurch sich die Art und Weise ändert, wie wir ursprünglich eine Liste von Listen im Batch hochgeladen haben. Es erleichterte außerdem das Hinzufügen von Zuschneidekoordinaten, ohne das Schema zu ändern.
Als neuen Vorverarbeitungsschritt haben wir bestimmte IDs herausgefiltert, die laut Model Card in Hugging Face nicht kleidungsbezogen sind. Wir filtern diese IDs in der Funktion get_masks heraus. Interessante Tatsache: Das Objekt obj_ids in dieser Funktion ist tatsächlich ein Tensor.
Wir haben außerdem die Bounding Boxes nachverfolgt. Wir haben den Einbettungsschritt in die Bildzuschneidefunktion verschoben und die Einbettungen zusammen mit den Bounding Boxes und Segmentierungs-IDs zurückgegeben. Anschließend haben wir diese Einbettungen mithilfe eines dynamischen Schemas in Milvus gespeichert.
Zur Abfragezeit haben wir alle zurückgegebenen Bilder nach der Anzahl der enthaltenen Bounding Boxes aggregiert, sodass wir das am besten passende Celebrity-Bild über verschiedene Kleidungsstücke hinweg finden konnten. Jetzt liegt es an Ihnen. Sie können meine Vorschläge aufgreifen und etwas anderes daraus machen, zum Beispiel ein Fashion-Recommender-System, ein besseres Style-Vergleichssystem für Sie und Ihre Freunde oder eine generative Fashion-KI-App.
Weiterlesen

Why Not All VectorDBs Are Agent-Ready
Explore why choosing the right vector database is critical for scaling AI agents, and why traditional solutions fall short in production.

Building RAG Pipelines for Real-Time Data with Cloudera and Milvus
explore how Cloudera can be integrated with Milvus to effectively implement some of the key functionalities of RAG pipelines.

DeepSeek-VL2: Mixture-of-Experts Vision-Language Models for Advanced Multimodal Understanding
Explore DeepSeek-VL2, the open-source MoE vision-language model. Discover its architecture, efficient training pipeline, and top-tier performance.



