GraphSAGE: Inductive Representation Learning on Large Graphs
Graph representation learning has emerged as a significant area of research, offering methods to analyze abstract graph structures in information networks. These methods enhance the performance of machine learning solutions for real-world applications such as social recommendations, targeted advertising, and user search.
Introduction to Graph Representation Learning
Graphs are fundamental data structures consisting of vertices (nodes) and edges (relationships) that model connections between objects. The terms "graph" and "network" are often used interchangeably. Extracting insights from graph data, known as network mining, has applications in social computing, biology, recommender systems, and language modeling.
Graph representation learning encodes graphs into a low-dimensional vector space to facilitate efficient graph data analysis. Graph embeddings are low-dimensional representations that reduce redundancy and noise, so downstream analysis can work with compact vectors instead of running complex algorithms on the original graph. These vectors serve as features for downstream tasks like node classification, link prediction, clustering, and recommendation, so the quality of the learned vectors directly affects task performance. Embeddings can be generated for nodes, edges, subgraphs, or entire graphs, with node embeddings being the most common focus.
Many approaches to graph representation learning leverage connections between nodes, exploiting homophily: the tendency for nodes to share attributes with their neighbors. These methods assign similar embeddings to neighboring nodes, reflecting the principle that similar nodes are often connected.
The goal of representation learning is to incorporate graph structural information into machine learning algorithms. The challenge lies in encoding high-dimensional, non-Euclidean graph information into feature vectors, ensuring that similar vertices in the original graph are represented similarly in the learned vector space. This similarity should reflect both the local community structure and the global network structure.
Key Considerations in Graph Representation Learning
Preserving Content
Graph vertices often possess rich attribute content that provides attribute-level node similarity information.
Real-World Networks
In real-world networks, parts of the network structure and node attributes are often missing due to privacy or legal restrictions.
Enterprise Networks
Enterprise networks often contain millions or billions of nodes, posing scalability challenges.
Graph Definition
Formally, a graph G can be defined as a tuple G = (V, E), where V is the set of nodes/vertices and E is the set of edges. A graph can be directed or undirected.
Homogeneous and Heterogeneous Graphs
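Let T_V denote the set of node types and T_E denote the set of edge types. A homogeneous graph has only one node type and one edge type (i.e., |T_V| = 1 and |T_E| = 1), whereas a heterogeneous graph has more than one node type or edge type (|T_V| + |T_E| > 2). Real-world networks also tend to evolve over time, with nodes or edges being added or removed between snapshots; such graphs are called dynamic graphs.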
In a multi-relational graph G = (V, E), each edge is defined as a tuple e_k = (v_i, τ, v_j), where τ ∈ T denotes the type of relationship between nodes v_i and v_j.
First-Order Proximity
Nodes connected by an edge have first-order proximity.
k-Partite Graphs
In many real-world graphs, nodes fall into groups, and edges are allowed only between groups, not within the same group. A bipartite graph has two groups of nodes, where nodes of one group can only connect to nodes of the other group. In k-partite graphs, "partite" refers to the partitions of node groups, and "k" refers to the number of those partitions.
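As an illustration of the bipartite case, the sketch below checks whether an undirected graph can be split into two groups with edges only between them, using breadth-first 2-coloring. The adjacency-dict format, the `is_bipartite` helper, and the example graphs are assumptions made for illustration.

```python
from collections import deque

def is_bipartite(adj):
    """Return True if an undirected graph {node: [neighbors]} can be split into
    two groups with edges only between the groups (i.e., it is 2-colorable)."""
    color = {}
    for start in adj:
        if start in color:
            continue
        color[start] = 0
        queue = deque([start])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if v not in color:
                    color[v] = 1 - color[u]   # put neighbors in the other group
                    queue.append(v)
                elif color[v] == color[u]:
                    return False              # an edge inside one group
    return True

# Hypothetical user-item graph: edges only connect users to items.
users_items = {"u1": ["i1", "i2"], "u2": ["i1"], "i1": ["u1", "u2"], "i2": ["u1"]}
triangle = {0: [1, 2], 1: [0, 2], 2: [0, 1]}
print(is_bipartite(users_items), is_bipartite(triangle))  # True False
```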
Transductive vs. Inductive Methods
Embedding methods can be broadly classified as transductive or inductive. Transductive methods assume the complete graph is known beforehand: they use both labeled and unlabeled nodes during training and produce an embedding lookup table. To generate an embedding for a new node, a transductive method must be retrained on the entire graph.
Inductive techniques, like conventional supervised learning, do not assume the input graph is complete. A model trained on the training data can be applied to newly observed data outside the training set, and it generalizes well as long as the new data follows the same distribution as the training data.
Laplacian Matrix
The Laplacian matrix is L = D - A: the degree matrix D (a diagonal matrix holding each node's degree) minus the adjacency matrix A.
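A minimal sketch of building L = D - A from an edge list, assuming an undirected, unweighted graph with nodes numbered 0..n-1. The `laplacian` helper is hypothetical and written in plain Python rather than with a linear-algebra library.

```python
def laplacian(n, edges):
    """Build L = D - A for an undirected, unweighted graph on nodes 0..n-1."""
    A = [[0] * n for _ in range(n)]
    for u, v in edges:
        A[u][v] = A[v][u] = 1          # undirected: symmetric adjacency
    degrees = [sum(row) for row in A]
    # D is diagonal, so L[i][i] = deg(i) and L[i][j] = -A[i][j] off the diagonal.
    return [[(degrees[i] if i == j else 0) - A[i][j] for j in range(n)]
            for i in range(n)]

L = laplacian(3, [(0, 1), (1, 2)])     # path graph 0-1-2
for row in L:
    print(row)
# [1, -1, 0]
# [-1, 2, -1]
# [0, -1, 1]
```

Every row of L sums to zero, because each diagonal degree cancels the adjacency entries in its row.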
Structural Similarity
Computing the structural similarity between nodes is a central component in graph representation learning. This similarity conveys the relative positioning of nodes in the learned embedding space.
Common Neighbors
This metric counts the neighbors shared by a given pair of nodes. It is commonly used for link prediction in networks whose number of nodes increases over time.
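The sketch below scores every currently unlinked pair by its number of common neighbors, which is how the metric is typically used to rank candidate links. The adjacency dict and node names are hypothetical.

```python
from itertools import combinations

def common_neighbors(adj, u, v):
    """Number of nodes adjacent to both u and v."""
    return len(set(adj[u]) & set(adj[v]))

adj = {"a": ["b", "c", "d"], "b": ["a", "c"], "c": ["a", "b", "d"], "d": ["a", "c"]}

# Score every unlinked pair; higher scores suggest a likelier future link.
candidates = [(u, v) for u, v in combinations(sorted(adj), 2) if v not in adj[u]]
scores = {pair: common_neighbors(adj, *pair) for pair in candidates}
print(scores)  # {('b', 'd'): 2}
```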
PageRank
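This metric is based on PageRank, which scores a node by the long-run probability that a random walk over the graph, which occasionally jumps to a random node, is at that node. Personalized variants always restart the walk from a chosen source node, so the resulting score of another node measures its proximity to that source.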
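A short power-iteration sketch of PageRank, assuming a graph stored as {node: [out-neighbors]} in which every node has at least one outgoing edge. The damping factor of 0.85 and the iteration count are conventional but arbitrary choices. Passing the hypothetical `source` argument switches to the personalized variant, where every restart returns to that node.

```python
def pagerank(adj, damping=0.85, iters=100, source=None):
    """Power iteration for (personalized) PageRank; assumes no dangling nodes."""
    nodes = list(adj)
    if source is None:
        restart = {n: 1.0 / len(nodes) for n in nodes}             # jump to any node uniformly
    else:
        restart = {n: 1.0 if n == source else 0.0 for n in nodes}  # always jump back to source
    rank = dict(restart)
    for _ in range(iters):
        new_rank = {n: (1 - damping) * restart[n] for n in nodes}
        for u in nodes:
            share = damping * rank[u] / len(adj[u])  # spread u's score over its out-links
            for v in adj[u]:
                new_rank[v] += share
        rank = new_rank
    return rank

# Undirected star graph (each edge listed in both directions), center node 0.
star = {0: [1, 2, 3], 1: [0], 2: [0], 3: [0]}
global_rank = pagerank(star)          # the center node gets the highest score
personal = pagerank(star, source=1)   # node 1 scores highest relative to itself
```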
Graph Statistics
Before deep learning, machine learning methods on graphs relied on graph statistics such as size (number of edges), order (number of nodes), degree distribution, and connectedness. Nodes have structural properties like degree and centrality, and non-structural properties like node attributes. Eigenvector centrality ranks nodes by the likelihood that they are visited on a random walk of infinite length on the graph. Betweenness centrality measures how often a node lies on the shortest path between two other nodes, while closeness centrality is the inverse of the average shortest-path distance between a node and all other nodes. Edges may also have associated edge weights.
These hand-crafted features are inflexible: they are fixed in advance rather than learned, so they cannot adapt to the task at hand, and finding appropriate features for a given application can be time-consuming.
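To make these statistics concrete, the sketch below computes order, size, node degrees, and closeness centrality, taken here as (n - 1) divided by the sum of shortest-path distances. It assumes an unweighted, connected, undirected graph stored as an adjacency dict; the helper names are hypothetical.

```python
from collections import deque

def bfs_distances(adj, source):
    """Shortest-path hop counts from source in an unweighted graph."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def closeness(adj, node):
    """(n - 1) / sum of distances to all other nodes; assumes a connected graph."""
    return (len(adj) - 1) / sum(bfs_distances(adj, node).values())

adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}            # path graph 0-1-2-3
order = len(adj)                                         # number of nodes
size = sum(len(nbrs) for nbrs in adj.values()) // 2      # number of undirected edges
degree = {v: len(nbrs) for v, nbrs in adj.items()}
print(order, size, degree)                     # 4 3 {0: 1, 1: 2, 2: 2, 3: 1}
print(closeness(adj, 0), closeness(adj, 1))    # 0.5 0.75
```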
Encoder-Decoder Framework
Hamilton proposed an encoder-decoder framework for graph representation learning based on neighborhood reconstruction. An encoder maps each graph node to a low-dimensional vector, and a decoder uses these embeddings to reconstruct information about each node's neighborhood. In many systems, the encoder is simply a lookup table indexed by node ID; such methods are known as shallow embedding methods.
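A minimal sketch of a shallow encoder-decoder, under stated assumptions: the encoder is a plain lookup table of per-node vectors, the decoder is an inner product, and training reconstructs adjacency entries with a squared-error loss and plain gradient descent. The loss, learning rate, dimensions, and graph are illustrative choices, not those of any particular published method.

```python
import random
from itertools import combinations

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def train_shallow_embeddings(nodes, edges, dim=4, epochs=200, lr=0.05, seed=0):
    """Encoder: lookup table Z = {node_id: vector}. Decoder: z_u . z_v, trained to
    reconstruct A[u][v] (1 for an edge, 0 otherwise) under a squared-error loss."""
    rng = random.Random(seed)
    Z = {v: [rng.gauss(0, 0.1) for _ in range(dim)] for v in nodes}
    edge_set = {frozenset(e) for e in edges}
    for _ in range(epochs):
        for u, v in combinations(nodes, 2):
            target = 1.0 if frozenset((u, v)) in edge_set else 0.0
            grad = 2 * (dot(Z[u], Z[v]) - target)   # derivative of (score - target)^2
            for i in range(dim):
                zu_i, zv_i = Z[u][i], Z[v][i]
                Z[u][i] -= lr * grad * zv_i
                Z[v][i] -= lr * grad * zu_i
    return Z

Z = train_shallow_embeddings(nodes=[0, 1, 2, 3], edges=[(0, 1), (2, 3)])
# Expected: the linked pair (0, 1) scores far higher than the unlinked pair (0, 2).
print(round(dot(Z[0], Z[1]), 2), round(dot(Z[0], Z[2]), 2))
# Node 4 was never seen, so the lookup table has no entry for it: Z[4] raises KeyError.
```

Because the table only has rows for nodes present at training time, a node added later has no embedding until the table is retrained, which is the transductive limitation described earlier.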
Adjacency Matrix
One way to represent a graph is its adjacency matrix, a square matrix whose elements indicate whether pairs of vertices are adjacent. The matrix is symmetric for undirected graphs, and for weighted graphs its entries hold edge weights instead of 0/1 values. Alternatives include the Laplacian matrix and the node transition probability matrix.
Using these matrices directly as representations is of limited use, because the simple matrix format is not informative enough and does not explicitly capture complex relationships such as paths and frequent subgraphs.
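For reference, the sketch below builds a (possibly weighted) adjacency matrix from an edge list and row-normalizes it into a node transition probability matrix. The function names are hypothetical, and nodes are assumed to be numbered 0..n-1.

```python
def adjacency_matrix(n, edges, directed=False, weights=None):
    """Square matrix with A[u][v] = edge weight (1.0 if unweighted), 0.0 otherwise."""
    A = [[0.0] * n for _ in range(n)]
    for k, (u, v) in enumerate(edges):
        w = 1.0 if weights is None else weights[k]
        A[u][v] = w
        if not directed:
            A[v][u] = w              # undirected graphs give a symmetric matrix
    return A

def transition_matrix(A):
    """Row-normalize A: P[i][j] is the probability of stepping from node i to node j."""
    P = []
    for row in A:
        total = sum(row)
        P.append([x / total if total else 0.0 for x in row])
    return P

A = adjacency_matrix(3, [(0, 1), (1, 2)], weights=[2.0, 1.0])
P = transition_matrix(A)
print(A)  # [[0.0, 2.0, 0.0], [2.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
print(P[1])  # approximately [2/3, 0, 1/3]
```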
Random-Walk Based Models
A random walk is a sequence of nodes in which each next node is selected from the current node's neighbors according to a probability distribution, with no restriction on backtracking or revisiting nodes. The goal is to extract rich structural information from the local structure of the graph. Truncated (fixed-length) random walks turn the graph into node sequences, and an NLP-inspired skip-gram model treats these sequences like sentences, maximizing the conditional probability of each node's surrounding context nodes given that node. DeepWalk was one of the first models based on this idea.
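A DeepWalk-style sketch of the first half of this pipeline: generating truncated uniform random walks and turning them into (center, context) pairs for a skip-gram model. The walk length, window size, seed, and graph are arbitrary, and the skip-gram training itself (for example with a word2vec implementation) is not shown.

```python
import random

def random_walk(adj, start, length, rng):
    """Truncated uniform random walk; backtracking and revisits are allowed."""
    walk = [start]
    while len(walk) < length:
        neighbors = adj[walk[-1]]
        if not neighbors:
            break                    # dead end: stop early
        walk.append(rng.choice(neighbors))
    return walk

def skipgram_pairs(walk, window):
    """(center, context) pairs a skip-gram model would be trained to predict."""
    pairs = []
    for i, center in enumerate(walk):
        lo, hi = max(0, i - window), min(len(walk), i + window + 1)
        pairs.extend((center, walk[j]) for j in range(lo, hi) if j != i)
    return pairs

adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
rng = random.Random(42)
walks = [random_walk(adj, v, length=5, rng=rng) for v in adj for _ in range(2)]
pairs = [p for w in walks for p in skipgram_pairs(w, window=2)]
```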
End-to-End Deep Learning
The methods discussed above use a piecemeal approach: one model learns node embeddings, which are then fed to a separate downstream model that performs the task. Deep learning enables end-to-end solutions for these graph-based tasks, learning the embeddings inside the architecture itself. One naive way to apply deep learning to graphs is to feed the adjacency matrix directly into a neural network. However, this approach is not permutation invariant: its output depends on the arbitrary ordering of the nodes in the input matrix. Hence, inductive biases must be built into deep architectures for graph data to effectively model this non-Euclidean, non-i.i.d. data.
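The permutation problem can be seen in a few lines: relabeling the nodes of the same path graph changes the flattened adjacency matrix that a naive network would receive, while a readout that sums a per-node quantity does not change. The squared-degree readout is a hypothetical choice made only for illustration.

```python
def adjacency(n, edges):
    A = [[0] * n for _ in range(n)]
    for u, v in edges:
        A[u][v] = A[v][u] = 1
    return A

def relabel(edges, perm):
    """Rename node i to perm[i]; the graph itself is unchanged."""
    return [(perm[u], perm[v]) for u, v in edges]

def readout(A):
    """Permutation-invariant summary: sum over nodes of (degree squared)."""
    return sum(sum(row) ** 2 for row in A)

edges = [(0, 1), (1, 2)]   # path 0-1-2
perm = [1, 0, 2]           # swap the labels of nodes 0 and 1
A1 = adjacency(3, edges)
A2 = adjacency(3, relabel(edges, perm))

flat1 = [x for row in A1 for x in row]
flat2 = [x for row in A2 for x in row]
print(flat1 == flat2)                # False: a naive network sees a different input
print(readout(A1) == readout(A2))    # True: the summed readout ignores node order
```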
Graph Neural Networks (GNNs)
There has been a surge in research on deep learning for graph representation learning, including techniques that generalize convolutional neural networks to graph-structured data and neural message-passing approaches inspired by belief propagation. Graph Neural Networks (GNNs) are among the most popular deep-learning-based solutions. GNNs use message passing to encode information about the graph: in each iteration, every node aggregates messages from its neighbors, so information propagates one hop further. Stacking more message-passing layers brings diminishing returns, so GNNs are usually much shallower than conventional deep neural networks. Different GNN models use different message-passing schemes.
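The sketch below tracks only which nodes can influence each node after each round of message passing, assuming every node aggregates from itself and its immediate neighbors. It shows the receptive field growing by one hop per layer on a path graph; it does not compute actual feature vectors.

```python
def receptive_fields(adj, num_layers):
    """history[k][v] = set of nodes whose input features can reach v after k rounds."""
    field = {v: {v} for v in adj}      # round 0: a node only knows its own features
    history = [field]
    for _ in range(num_layers):
        # Each node merges what it already knows with what its neighbors know.
        field = {v: set(field[v]).union(*(field[u] for u in adj[v])) for v in adj}
        history.append(field)
    return history

adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}   # path graph 0-1-2-3
for k, field in enumerate(receptive_fields(adj, 3)):
    print(k, sorted(field[0]))
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2, 3]
```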
Evaluation of Learned Network Representations
Various researchers have proposed different taxonomies to navigate the graph representation learning domain; one intuitive taxonomy is proposed by Khoshraftar et al. Learned network representations are commonly evaluated on downstream network analytic tasks:
- Network reconstruction: Use the learned embeddings to reconstruct the network and examine the differences between the reconstructed and original networks.
- Node classification: Use node embeddings as features to train a classifier on labeled nodes.
- Link prediction: Remove a small number of edges from the original network, learn embeddings on the remaining graph, and use them to predict the removed edges.
- Clustering: Apply k-means clustering to the learned embeddings and evaluate the detected communities.
- Visualization: Reduce the embedding space to 2 or 3 dimensions with t-SNE or PCA and inspect the result.
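As one example of these evaluations, the sketch below sets up the link-prediction protocol: hold out some edges, sample an equal number of non-edges, score pairs by the inner product of their embeddings, and report AUC. The helper names and the hand-made embeddings are hypothetical; in a real evaluation the embeddings would be learned on `train_edges` only.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def split_edges(edges, test_fraction, rng):
    """Hold out a fraction of edges; embeddings should then be learned on the rest."""
    shuffled = list(edges)
    rng.shuffle(shuffled)
    k = max(1, int(len(shuffled) * test_fraction))
    return shuffled[k:], shuffled[:k]            # (train_edges, test_edges)

def sample_non_edges(nodes, edges, k, rng):
    """Sample k distinct node pairs that are not edges (assumes a sparse graph)."""
    existing = {frozenset(e) for e in edges}
    found = set()
    while len(found) < k:
        pair = frozenset(rng.sample(nodes, 2))
        if pair not in existing:
            found.add(pair)
    return [tuple(p) for p in found]

def auc(pos_scores, neg_scores):
    """Probability that a held-out edge outscores a non-edge; ties count as 1/2."""
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos_scores for n in neg_scores)
    return wins / (len(pos_scores) * len(neg_scores))

# Two 4-node cliques; hypothetical embeddings place each clique on its own axis.
rng = random.Random(0)
nodes = list(range(8))
edges = [(u, v) for group in ([0, 1, 2, 3], [4, 5, 6, 7])
         for i, u in enumerate(group) for v in group[i + 1:]]
emb = {v: [1.0, 0.0] if v < 4 else [0.0, 1.0] for v in nodes}

train_edges, test_edges = split_edges(edges, 0.25, rng)
negatives = sample_non_edges(nodes, edges, len(test_edges), rng)
pos = [dot(emb[u], emb[v]) for u, v in test_edges]
neg = [dot(emb[u], emb[v]) for u, v in negatives]
print(auc(pos, neg))  # 1.0
```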
Challenges and Future Directions
The vast majority of the literature in graph representation learning focuses on non-attributed, static, unsigned homogeneous networks. As a result, several challenges need more extensive research before these techniques can be applied to real-world graphs. For example, the propagation of information between nodes and links in heterogeneous networks is much more complex than in homogeneous networks, and applying graph learning to enterprise graphs with billions of nodes and edges is also very challenging. Fortunately, various proposals have been made to learn effective graph representations from heterogeneous and dynamic networks, and many recent works have proposed adjustments to training and inference algorithms and to storage techniques that enable learning from large-scale graphs.
GraphSAGE: An Inductive Approach
GraphSAGE (Graph SAmple and aggreGatE) is a framework for inductive representation learning on large graphs. It generates low-dimensional vector representations for nodes and is especially useful for graphs with rich node attribute information.
The Need for Inductive Node Embeddings
Prior to GraphSAGE, most node embedding models were based on spectral decomposition or matrix factorization methods. These methods are inherently transductive: they expect the entire graph structure to be present at training time and cannot naturally generate embeddings for nodes unseen during training. If a new node is added to the graph later, the model must be retrained.
An inductive approach, by contrast, can generalize to unseen data, such as nodes added after training, which makes it better suited to evolving real-world graphs.
The Main Idea Behind GraphSAGE
Rather than learning a separate embedding for each node, GraphSAGE learns functions that generate a node's representation by combining information from its neighboring nodes. Every node can have its own feature vector, and the algorithm runs for K iterations, producing a representation h_v^k for every node v at each iteration k.
The GraphSAGE Algorithm
The GraphSAGE algorithm follows a two-step process:
- Aggregate: Aggregate the representations of neighboring nodes for the target node. The f_aggregate function is a placeholder for any differentiable function, such as an averaging function or a neural network. This step aggregates the embedding vectors of the nodes u in the immediate neighborhood of the target node v (in GraphSAGE, a fixed-size uniform sample of that neighborhood), resulting in the aggregated node representation a_v.
- Update: After obtaining the aggregated representation a_v from v's neighbors, update node v's representation by combining its previous representation with a_v. The f_update function is likewise a placeholder for any differentiable function, such as an averaging function or a neural network.
The number of iterations K determines how many hops of neighborhood information are used to compute the representation of node v.
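For reference, one iteration of the two steps can be written as below, following the forward pass in the GraphSAGE paper with the concatenation-based update and normalization described under Implementation Details. Here x_v (the input features of v), N(v) (its, possibly sampled, neighbors), and σ (a non-linearity) are notation introduced for this sketch.

```latex
\begin{aligned}
\mathbf{a}_v^{k} &= f_{\text{aggregate}}^{k}\left(\{\mathbf{h}_u^{k-1} : u \in \mathcal{N}(v)\}\right) \\
\mathbf{h}_v^{k} &= \sigma\left(\mathbf{W}^{k} \cdot \operatorname{CONCAT}\left(\mathbf{h}_v^{k-1}, \mathbf{a}_v^{k}\right)\right) \\
\mathbf{h}_v^{k} &\leftarrow \mathbf{h}_v^{k} / \lVert \mathbf{h}_v^{k} \rVert_2,
\qquad k = 1, \dots, K, \qquad \mathbf{h}_v^{0} = \mathbf{x}_v
\end{aligned}
```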
Implementation Details
- Aggregator Functions: The authors of the GraphSAGE paper experimented with a variety of aggregator functions, including mean aggregation, max-pooling, and LSTM aggregation. Because an LSTM is sensitive to input order, the LSTM aggregator was applied to a random permutation of the neighbors, so that no neighbor is systematically favored by its position in the sequence.
- Update Function: In the paper, f_update concatenates the node's previous representation with the aggregated neighborhood representation. If both have dimension F, the concatenated vector has shape (2F, 1). It is then multiplied by a weight matrix W^k, which here reduces it back to shape (F, 1), and finally passed through a non-linearity.
- Weight Matrices: Each iteration k has its own weight matrix W^k. This lets the model learn how important information from each neighborhood depth is to the target node.
- Normalization: Each node embedding is divided by its L2 norm, giving unit-length vectors; this keeps their scale bounded and helps prevent exploding gradients.
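Putting these steps together, here is a minimal, untrained forward-pass sketch in plain Python, assuming a mean aggregator, concatenation followed by W^k, a ReLU non-linearity, and L2 normalization. The random weights, feature values, sample sizes, and graph are hypothetical, and training (supervised or unsupervised) is not shown.

```python
import math
import random

def mat_vec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def l2_normalize(x):
    norm = math.sqrt(sum(xi * xi for xi in x))
    return [xi / norm for xi in x] if norm > 0 else x

def mean_aggregate(vectors, dim):
    if not vectors:                    # isolated node: use a zero vector
        return [0.0] * dim
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def graphsage_forward(adj, features, weights, num_samples, rng):
    """weights[k] is a (F_out x 2*F_in) matrix W^k; num_samples[k] is the sample size."""
    h = {v: features[v] for v in adj}                                # h^0_v = x_v
    for W, s in zip(weights, num_samples):
        new_h = {}
        for v in adj:
            nbrs = adj[v]
            sampled = rng.sample(nbrs, s) if len(nbrs) > s else nbrs  # uniform neighbor sample
            a_v = mean_aggregate([h[u] for u in sampled], len(h[v]))  # aggregate step
            z = mat_vec(W, h[v] + a_v)      # list '+' is CONCAT(h_v, a_v); then apply W^k
            new_h[v] = l2_normalize([max(0.0, zi) for zi in z])      # ReLU, then L2 norm
        h = new_h
    return h

rng = random.Random(0)
adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
features = {v: [rng.random() for _ in range(3)] for v in adj}   # hypothetical 3-d attributes
dims = [3, 4, 4]                                                # K = 2 iterations
weights = [[[rng.gauss(0, 0.5) for _ in range(2 * d_in)] for _ in range(d_out)]
           for d_in, d_out in zip(dims, dims[1:])]
emb = graphsage_forward(adj, features, weights, num_samples=[2, 2], rng=rng)

# Inductive use: node 4 was absent above, yet the same weights embed it.
adj_new = {**adj, 4: [3], 3: [2, 4]}
features[4] = [0.5, 0.1, 0.9]
emb_new = graphsage_forward(adj_new, features, weights, num_samples=[2, 2], rng=rng)
```

Because the weights are shared across nodes rather than stored per node, the same function produces an embedding for node 4 without retraining, which is the inductive property described above.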
Unsupervised Loss Function
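The authors train both unsupervised and supervised GraphSAGE models. The supervised setting uses a standard cross-entropy loss for a node classification task. The unsupervised setting instead tries to preserve graph structure by applying the following loss to the output representation z_u of each node u:
$$J_{\mathcal{G}}(\mathbf{z}_u) = -\log\left(\sigma\left(\mathbf{z}_u^{\top}\mathbf{z}_v\right)\right) - Q \cdot \mathbb{E}_{v_n \sim P_n(v)} \log\left(\sigma\left(-\mathbf{z}_u^{\top}\mathbf{z}_{v_n}\right)\right)$$
where v is a node that co-occurs near u on a fixed-length random walk, σ is the sigmoid function, P_n is a negative sampling distribution, and Q is the number of negative samples.
The loss enforces that if nodes u and v are close in the graph, their embeddings are similar (a large inner product). Conversely, the embeddings of randomly sampled negative nodes v_n, which are typically far from u, are pushed apart.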
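A minimal sketch of this unsupervised loss for one positive pair, assuming the output embeddings are plain Python lists. It computes -log σ(z_u · z_v) minus the sum of log σ(-z_u · z_n) over Q sampled negatives z_n, where summing over the Q samples estimates the Q-weighted expectation over the negative sampling distribution. How positives and negatives are sampled is left out, and the example vectors are hypothetical.

```python
import math

def log_sigmoid(x):
    """Numerically stable log(1 / (1 + exp(-x)))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def unsupervised_loss(z_u, z_v, z_negatives):
    """Pull the positive pair together, push each sampled negative away."""
    loss = -log_sigmoid(dot(z_u, z_v))
    loss -= sum(log_sigmoid(-dot(z_u, z_n)) for z_n in z_negatives)
    return loss

negatives = [[-1.0, 0.0], [0.0, 1.0]]                            # Q = 2 hypothetical negatives
near = unsupervised_loss([1.0, 0.0], [0.9, 0.1], negatives)   # u and v embedded close together
far = unsupervised_loss([1.0, 0.0], [-1.0, 0.0], negatives)   # u and v embedded far apart
print(near < far)  # True
```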
Aggregation Functions in Detail
The neighbors of a given node have no natural order, so the aggregation function should be permutation invariant: it should produce the same result regardless of the order of the neighborhood nodes.
Mean Aggregator
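The mean aggregator takes the element-wise mean of the neighbors' vectors and is nearly equivalent to the convolutional propagation rule used in the GCN framework. One difference is that GraphSAGE applies it to a sampled set of local neighbors, whereas a GCN operation considers the entire neighbor set normalized by degree. Another is that GraphSAGE concatenates the node's own previous representation with the aggregated vector; this concatenation acts as a simple skip connection and improves performance.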
LSTM Aggregator
LSTMs (Long Short-Term Memory networks) are a recurrent neural network (RNN) architecture designed to process sequential data by maintaining a memory of previous inputs (time steps). LSTMs are sensitive to the order of their inputs, so they are not permutation invariant. The authors work around this simply by feeding the LSTM a random permutation of the neighbors. In return, LSTMs offer greater expressive capacity, since they can capture long-range dependencies and context.
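To illustrate why order matters, the sketch below uses a toy scalar tanh recurrence as a stand-in for an LSTM (it is not an LSTM). A mean aggregator gives the same result for two orderings of the same neighbor values, while the recurrent aggregator does not. The last lines mimic the random-permutation workaround; the weights and values are arbitrary.

```python
import math
import random

def recurrent_aggregate(values, w_h=0.5, w_x=1.0):
    """Toy order-sensitive aggregator: h_t = tanh(w_h * h_{t-1} + w_x * x_t).
    A simplified stand-in for an LSTM, used only to show order sensitivity."""
    h = 0.0
    for x in values:
        h = math.tanh(w_h * h + w_x * x)
    return h

def mean_aggregate(values):
    return sum(values) / len(values)

neighbors = [1.0, 2.0, 3.0]
reordered = [3.0, 1.0, 2.0]
print(mean_aggregate(neighbors) == mean_aggregate(reordered))            # True
print(recurrent_aggregate(neighbors) == recurrent_aggregate(reordered))  # False

# Workaround in the spirit of GraphSAGE's LSTM aggregator: feed a random permutation.
rng = random.Random(0)
shuffled = rng.sample(neighbors, len(neighbors))
h = recurrent_aggregate(shuffled)
```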
Pooling Aggregator
The pooling aggregator works differently: it does not aggregate the neighbors' node features directly. Instead, each neighbor's feature vector is first passed independently through a fully connected layer with trainable weights. An element-wise max-pooling operation is then applied to the resulting vectors, letting different output dimensions capture different aspects of the neighborhood set.
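A minimal sketch of the pooling aggregator, assuming a single fully connected layer with ReLU as the per-neighbor MLP. The layer sizes and the random, untrained weights are hypothetical. Reversing the neighbor order leaves the output unchanged, since the element-wise max ignores ordering.

```python
import random

def pooling_aggregate(neighbor_vectors, W_pool, b_pool):
    """Apply a shared one-layer MLP (ReLU) to each neighbor, then take the element-wise max."""
    transformed = [
        [max(0.0, sum(w * x for w, x in zip(row, h_u)) + b) for row, b in zip(W_pool, b_pool)]
        for h_u in neighbor_vectors
    ]
    return [max(column) for column in zip(*transformed)]

rng = random.Random(0)
F_in, F_hidden = 3, 5
W_pool = [[rng.gauss(0, 1) for _ in range(F_in)] for _ in range(F_hidden)]  # hypothetical weights
b_pool = [0.0] * F_hidden
neighbors = [[rng.random() for _ in range(F_in)] for _ in range(4)]

a1 = pooling_aggregate(neighbors, W_pool, b_pool)
a2 = pooling_aggregate(list(reversed(neighbors)), W_pool, b_pool)
print(a1 == a2)  # True
```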
Conclusion
GraphSAGE is a powerful technique for inductive representation learning on large graphs. By aggregating the embeddings of a target node's neighbors, GraphSAGE can compute representations for every node in the graph, including nodes not seen during training. Stacked GraphSAGE layers can create complex structural and semantic-level features for any downstream task.

