GraphSAGE: Nutzerverhalten mit ML-basierten Empfehlungsdiensten vorhersagen

Autor*in:
Anna
Höpfner

„Gefällt mir!“ Tagtäglich erhalten wir personalisierte Empfehlungen zu Produkten oder Inhalten – egal ob beim Shopping auf diversen Online-Marktplätzen, beim Streaming von Filmen und Musik oder auf Social-Media-Plattformen. Verantwortlich dafür sind sogenannte Empfehlungsdienste, welche auf Methoden des Maschinellen Lernens basieren. Im Bereich der Empfehlungsdienste finden Graph Neuronal Networks (GNNs) immer häufiger Anwendung. Diese Art von künstlichen neuronalen Netzen arbeitet auf Daten mit Graphstruktur und eignet sich daher, um als Graph codierte Nutzer- Produkt-Interaktionen zu analysieren und damit Vorhersagen über das Nutzerverhalten zu ermöglichen. Die entstehenden Graphen bestehen in der Regel aus Billionen von Knoten und Kanten und entwickeln sich zudem dynamisch, da sich täglich neue Nutzer*innen anmelden und ständig weitere Produkte hinzukommen. GraphSAGE erweitert die Idee von klassischen Convolutional Neural Networks auf Graphen durch die Einführung einer Sampling-Strategie und Aggregationsfunktionen, um das Lernen von Knoteneigenschaften in besonders großen und dynamisch wachsenden Graphen zu ermöglichen. Ein bekanntes Beispiel für den Einsatz von GraphSAGE ist das Empfehlungssystem von Pinterest, welches den Algorithmus nutzt, um personalisierte Vorschläge für Nutzer*innen zu erstellen. Auch UberEats verwendet eine modifizierte Version von GraphSAGE, um Verbraucher*innen passende Essensvorschläge zu machen. In diesem Beitrag stellen wir den GraphSAGE Algorithmus genauer vor. Für einen Einstieg in das Thema Graph Neural Networks empfehlen wir den Blogbeitrag ”Graph Neural Networks”.

Abbildung 1: Auswahl der Nachbarn für einen Zielknoten: In Suchtiefe 1 werden vier Nachbarn ausgewählt. In Suchtiefe 2 werden für jeden dieser Knoten je zwei weitere Nachbarn gewählt.

Auswahl der Nachbarn – “Neighborhoodsampling”

Möchte man ein GNN-Modell auf einem großen Graphen trainieren, stößt man in der Praxis häufig auf zwei Probleme, welche die Genauigkeit und Effizienz des Modells beeinträchtigen können. Zum einen ist die Verarbeitung von Graphen mit Milliarden von Knoten sehr rechenaufwendig und erfordert viel Speicherplatz. Zum anderen kann die Nachbarschaftsgröße von Knoten stark variieren, wodurch Knoten mit überdurchschnittlich vielen Nachbarn existieren können. Diese können das Modell stark beeinflussen, da sie Informationen von sehr vielen Nachbarn erhalten. Dies führt zu einem enormen Rechenaufwand und kann darüber hinaus die Repräsentation anderer Knoten verzerren, was die Genauigkeit des Modells beeinträchtigen kann. Um die genannten Probleme zu lösen, wird für jeden Knoten eine konstante Anzahl von Nachbarn ausgewählt, mit denen Nachrichten ausgetauscht werden. Dieser Neighborhoodsampling-Prozess umfasst folgende Schritte:

  1. Festlegen der Suchtiefe k: Es werden k Iterationen des Message Passing durchgeführt.
  2. Auswahl der Nachbarn in jeder Suchtiefe: Die Auswahl kann uniform zufällig (Ziehen mit Zurücklegen) oder mit einer, für den Anwendungsfall geeigneten, Methode erfolgen.

Diese Schritte werden für jeden Knoten durchgeführt, um schließlich einen Teilgraph mit allen Knoten zu erhalten, die für weitere Berechnungen benötigt werden.

Abbildung 2: Die Aggregation am Beispiel eines Lieferservices verdeutlicht, wie GraphSAGE die Informationen der Nachbarschaft eines Knotens mithilfe der Mittelwertfunktion aggregiert. Hierbei repräsentieren die Knoten die verschiedenen Restaurants und die Kanten zwischen den Knoten zeigen an, ob ein*e Nutzer*in bereits bei diesem Restaurant bestellt hat oder nicht. Jedes Restaurant hat verschiedene Attribute, wie beispielsweise die Option vegetarisches Essen zu bestellen (1) oder nicht (0).

Informationsverarbeitung – Aggregationsfunktionen

Die Aggregationsfunktion beschreibt, wie die Informationen der oben beschriebenen Nachbarschaft zusammengeführt und verarbeitet werden, um dadurch Informationen über den Zielknoten abzuleiten. Um die Aggregationsfunktionen und damit die Vorhersagen des Modells zu verbessern, hat jede Aggregationsfunktion eine Gewichtsmatrix als Parameter, welche die Gewichtung der Nachbarknoten bestimmt. Die Anpassung der Gewichtungen, und somit der Relevanz einzelner Nachbarknoten, erfolgt durch das Training des Modells mit Verfahren wie dem Gradientenabstieg [Link zu entsprechendem Beitrag]. Die Wahl der Aggregationsfunktion hängt von der spezifischen Aufgabe ab und kann die Leistung des Algorithmus deutlich beeinflussen. Einige Beispiele für Aggregationsfunktionen sind die Mittelwertfunktion, welche den Durchschnitt der Merkmale in der Nachbarschaft berechnet oder die Maximumsfunktion, welche das am meisten repräsentative Merkmal herausfiltert.

Wie lernt GraphSAGE?

Da wir nun mit den Grundlagen vertraut sind, können wir uns abschließend den gesamten GraphSAGE Lernalgorithmus anschauen. Er nimmt einen Graphen als Eingabe und gibt eine Vektordarstellung jedes Knotens im Graph aus. Dabei werden folgende Schritte angewendet:

  1. Initialisierung: Merkmale der Knoten werden für eine initiale Darstellung verwendet.
  2. Neighborhoodsampling
  3. Aggregation: Durch eine Aggregationsfunktion wird eine Repräsentation der Nachbarschaft jedes Knotens generiert.
  4. Konkatenation: Die aggregierten Merkmale der Nachbarn werden mit der aktuellen Repräsentation des Knotens verknüpft, um zu einer verbesserten Darstellung beizutragen.
  5. Vollvernetzte Schicht mit Aktivierungsfunktion: Die vollvernetzte Schicht besteht aus einer Matrix von Gewichten, die während des Trainings angepasst werden, um die bestmögliche Kombination von Merkmalen zu erzielen. Die Aktivierungsfunktion dient dazu, auch nicht- lineare Beziehungen zwischen den Knoten zu erfassen.
  6. Normalisierung: Um die Skalierung der Repräsentationen der Knoten zu standardisieren und damit die Stabilität des Modells zu verbessern, werden die neuen Merkmalsvektoren der Knoten schließlich normalisiert, indem der Vektor durch seine Euklidische Norm geteilt wird.
  7. Optimierung der Modellparameter mithilfe des Gradientenabstiegsverfahren

Schritte 2.-7. werden wiederholt, um die Aggregationsfunktionen und damit die Repräsentationen der Knoten iterativ zu verbessern.

Die finalen Darstellungen der Knoten können als Eingabe für Machine-Learning-Modelle verwendet werden, um verschiedene Aufgaben wie Klassifikation oder Vorhersage von Beziehungen zu lösen. Durch die Nutzung der Knotenrepräsentationen können diese Modelle das Wissen über die Struktur und Beziehungen des Graphen effektiver nutzen und somit bessere Vorhersagen treffen.

Fazit

GraphSAGE ist ein leistungsfähiges maschinelles Lernverfahren für die Analyse von Daten mit Graphstruktur. Es findet Anwendung in Empfehlungsdiensten, um Nutzer*innenverhalten basierend auf der Interaktion mit Produkten oder Inhalten vorherzusagen. Die Methode von GraphSAGE erweitert die Idee des Message Passing Modell, indem es eine Sampling-Strategie und Aggregationsfunktionen hinzufugt und erlaubt damit das Lernen von Knoteneigenschaften in besonders großen und dynamisch wachsenden Graphen. GraphSAGE ist ein vielversprechender Ansatz für die Verarbeitung von großen Graphdaten, der neben Empfehlungsdiensten auch in anderen Bereichen, wie der sozialen Netzwerkanalyse, der Molekularbiologie und Fraud Detection Anwendung findet.

Auch wenn das Semester noch in vollem Gange ist, verabschiedet sich unsere Reihe ML Classroom erst einmal in die Semesterferien. Weiter geht es hier nächsten Monat mit spannenden Beiträgen unserer Lamarr-Forscher*innen und Einblicken in Anwendungsfälle aus der Praxis. Also bleiben Sie dran! Sie wollen keinen Beitrag mehr verpassen? Dann melden Sie sich jetzt zu unserem Newsletter an und folgen Sie uns auf Twitter und LinkedIn.

Autor*in

Anna
Höpfner

Anna Höpfner studiert Informatik an der Rheinischen Friedrich-Wilhelms-Universität Bonn und interessiert sich besonders für Maschinelles Lernen und Künstliche Intelligenz.