图神经网络GNN入门:从原理到实战

如果你已经熟悉了CNN和RNN,第一次听到”图神经网络”这个词可能会觉得有点懵——数据不是早就被组织成图片或者序列了吗?为什么还要专门搞一个”图”的概念?

这篇文章就来帮你彻底搞清楚GNN到底是什么、为什么重要、以及怎么用它来解决实际问题。我们会从最基础的概念讲起,一直讲到代码实战,保证你看完之后能自己动手跑一个GNN模型。

一、为什么需要GNN:传统深度学习的局限

在聊GNN之前,我们先回顾一下传统深度学习模型擅长处理什么数据。

卷积神经网络CNN,大家最熟悉的应用场景是图像处理。图片本质上是一个规则的网格结构——每个像素都有固定的位置,相邻像素之间的空间关系是明确的。CNN通过卷积核在网格上滑动,捕捉局部特征,然后把局部特征组合成全局表示。这套方法在图像分类、目标检测、语义分割等任务上取得了巨大成功。

循环神经网络RNN及其变体LSTM、GRU,最擅长处理序列数据。文本、语音、时间序列,这些数据天然带有前后顺序,RNN通过隐藏状态的传递来捕捉序列中的时序依赖关系。

但是,现实世界中有很多数据并不具备规则的空间结构或时序顺序,它们本质上是一种图结构

比如社交网络——你是节点,你的好友关系是边,扎克伯格和你是好友,你的好友和扎克伯格也是好友,但这种关系完全不规则,没有人会说”第5个好友在第8个好友的左上角”。

比如分子结构——原子是节点,化学键是边,分子的性质很大程度上由其拓扑结构决定,但分子可不会乖乖排成28×28的网格让你做卷积。

比如推荐系统——用户和商品可以构成二分图,用户之间的行为相似性、商品之间的类别关系,都可以用图来表达。

比如交通网络——路口是节点,道路是边,车流量在图上传播,但道路的连接方式完全取决于城市的真实布局,不可能每条路都和其他路等距相邻。

这些场景的共同特点是:数据之间的关系本身是信息的重要组成部分,而且这种关系是任意、非规则的。用CNN处理吧,数据不是网格;用RNN处理吧,数据没有固定顺序。你很难把一个社交网络”铺平”成矩阵然后塞进卷积核——强行这么做会丢失大量结构信息。

GNN的出现,就是为了解决这个根本性的问题:如何让神经网络学会处理任意结构的图数据

从2017年GCN论文发表开始,GNN迅速成为深度学习领域最热门的研究方向之一。如今,GNN已经被广泛应用在社交网络分析、推荐系统、分子性质预测、药物发现、交通流量预测、知识图谱推理等各个领域。

二、图的基本概念:节点、边、矩阵表示

在深入GNN之前,我们需要先搞清楚一些图论的基础知识。这些概念看起来是数学,但它们直接决定了GNN的工作方式。

2.1 节点与边

一个图G由两部分组成:顶点集V(Vertex)和边集E(Edge)。我们通常用G = (V, E)来表示。

**节点(Vertex/Node)**是图中的基本单位。在不同的应用场景中,节点代表不同的实体:

  • 社交网络中,节点是一个人
  • 分子图中,节点是一个原子
  • 引用网络中,节点是一篇论文
  • 知识图谱中,节点是一个实体

**边(Edge)**连接两个节点,表示它们之间存在某种关系。边可以是有向的,也可以是无向的:

  • 无向图:边没有方向,表示对称关系。比如你和你的朋友,关系是双向的。
  • 有向图:边有方向,表示非对称关系。比如Twitter的关注关系,我可以关注你但你不一定关注我。

边还可以有权重,用来表示关系的强弱:

  • 社交网络中,边的权重可以是两个人之间的互动频率
  • 交通网络中,边的权重可以是道路长度或通行时间

2.2 邻接矩阵

邻接矩阵(Adjacency Matrix)是最常用的图表示方法之一。假设图中有N个节点,我们用一个N×N的矩阵A来表示它们之间的连接关系。

对于无向图:如果节点i和节点j之间有边,则A[i,j] = 1,否则A[i,j] = 0。因为关系是对称的,所以邻接矩阵是对称矩阵。

对于带权图:如果节点i和节点j之间有边,权重为w,则A[i,j] = w,否则A[i,j] = 0。

对于有向图:通常把第i行第j列理解为”从节点i指向节点j的边”。如果”用户i关注了用户j”,那么A[i,j] = 1。

举一个具体的例子。假设有四个人构成的社交网络:A和B是朋友,B和C是朋友,C和D是朋友,D和A是朋友。这个图的邻接矩阵大约是:

     A  B  C  D
A [ 0  1  0  1 ]
B [ 1  0  1  0 ]
C [ 0  1  0  1 ]
D [ 1  0  1  0 ]

这是一个对称矩阵,因为”是朋友”这个关系是双向的。

2.3 度矩阵

度(Degree)是一个节点拥有的边的数量。对于无向图,节点的度就是与它相连的边的总数。对于有向图,我们通常区分入度(指向该节点的边数)和出度(从该节点指出的边数)。

度矩阵(Degree Matrix)是一个对角矩阵D,对角线上的元素D[i,i]等于节点i的度,其他位置都是0。对于上面的例子,四个节点的度都是2,所以度矩阵是:

     A  B  C  D
A [ 2  0  0  0 ]
B [ 0  2  0  0 ]
C [ 0  0  2  0 ]
D [ 0  0  0  2 ]

度矩阵和邻接矩阵的关系很密切。在后续讲到的很多算法中,你会看到D和A被放在一起使用。

2.4 拉普拉斯矩阵

拉普拉斯矩阵(Laplacian Matrix)是图论中最重要的概念之一,也是很多GNN模型的数学基础。它定义为:

L = D - A

也就是说,拉普拉斯矩阵等于度矩阵减去邻接矩阵。

还是上面的例子,拉普拉斯矩阵是:

     A  B  C  D
A [ 2 -1  0 -1 ]
B [-1  2 -1  0 ]
C [ 0 -1  2 -1 ]
D [-1  0 -1  2 ]

拉普拉斯矩阵有一些非常好的数学性质:

性质一:对称半正定。L是一个对称矩阵,而且它的所有特征值都大于等于0。这意味着L可以进行特征分解,可以做谱分析。

性质二:行和为零。每一行的和都是0,所以L至少有一个特征值为0,对应的特征向量是全1向量。

性质三:二次型表达图的平滑性。对于一个定义在节点上的信号x,x^T L x = 1/2 × Σ_{(i,j)∈E} (x_i - x_j)^2。这个公式的几何意义非常重要:它衡量的是信号在相邻节点之间的变化程度。如果相邻节点的信号值相近,这个值就小;如果变化剧烈,这个值就大。

这个”平滑性”的含义直接启发了GNN的设计:我们在聚合邻居信息的时候,自然希望聚合结果是平滑的,也就是说相邻节点的表示应该趋于一致。这就是为什么拉普拉斯矩阵在谱域GNN中扮演核心角色。

在实际计算中,我们还经常使用标准化的拉普拉斯矩阵

  • 对称归一化:L_sym = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2}
  • 随机游走归一化:L_rw = D^{-1} L = I - D^{-1} A

这两种归一化方式在后续的GCN等模型中都会用到。

三、图卷积网络GCN:从谱域到空域

现在我们进入GNN最经典、最基础的部分:图卷积网络(Graph Convolutional Network)。

GCN的核心思想其实和普通CNN很像:每个节点的表示应该是它自己和它邻居的某种加权组合。但关键的区别在于,在规则网格上,我们可以定义固定的卷积核大小和滑动窗口;而在图上,邻居的数量和连接方式都是任意的。

为了解决这个问题,研究者们走了两条路:谱域方法空域方法。我们先讲谱域,因为它更早出现,数学上更优美。

3.1 谱域方法:借助图信号处理的视角

谱域方法的核心思想是把图的拓扑结构通过拉普拉斯矩阵的特征分解映射到”谱域”——也就是特征值构成的空间。在谱域中,我们可以像处理普通信号一样,对图信号进行滤波操作。

具体来说,对于图信号x(一个N维向量,x_i表示节点i的信号值),我们定义图上的傅里叶变换为:

x̂ = U^T x

其中U是拉普拉斯矩阵特征向量构成的矩阵(满足U U^T = I)。对应的逆变换是:

x = U x̂

在谱域中进行滤波,就是把信号乘以一个对角滤波器g:

x_filtered = U g_θ U^T x

这里的g_θ是一个对角矩阵,对角线上的值是滤波器参数,通常写成关于特征值的函数g_θ(Λ),其中Λ是特征值组成的对角矩阵。

这就是谱域图卷积的基本形式。它的工作流程是:先把信号变换到谱域,在谱域中做滤波(相乘),再变回来。

但是,这个计算方式有一个严重的问题:特征分解的计算复杂度是O(N^3),对于大规模图来说完全不可接受。更实际的做法是对滤波器g_θ做参数化,使得计算可以在O(E)的复杂度内完成。

ChebNet做了一个关键的改进:它使用切比雪夫多项式来近似滤波器。具体地:

g_θ(Λ) ≈ Σ_{k=0}^{K-1} θ_k T_k(Λ̃)

其中Λ̃ = 2Λ/λ_max - I是对特征值做了归一化,T_k是k阶切比雪夫多项式,θ_k是学习参数。这样一来,滤波操作变成了:

x_filtered = U g_θ(Λ) U^T x = Σ_{k=0}^{K-1} θ_k T_k(L̃) x

而T_k(L̃) x的计算可以通过递归的切比雪夫多项式递推关系来实现,不需要显式的特征分解,计算复杂度降到了O(K E)。

3.2 GCN的简化与空域直觉

2017年的经典GCN论文在ChebNet的基础上做了进一步的简化,使用了K=1的一阶近似:

假设K=1(只看直接邻居),并且进一步约束参数θ_0 = -θ_1 = θ,那么经过推导,GCN的层间传播公式变成了:

H^{(l+1)} = σ(D^{-1/2} A D^{-1/2} H^{(l)} W^{(l)})

这就是经典的GCN前向传播公式。让我逐项解释:

  • A是邻接矩阵,加上自环(I + A)意味着每个节点在聚合邻居信息之前,先把自己加进来。
  • D^{-1/2} A D^{-1/2}是对称归一化后的邻接矩阵,它保证了对每个节点,所有邻居的聚合权重之和为1。
  • H^{(l)}是第l层的节点特征矩阵,每一行是一个节点的表示向量。
  • W^{(l)}是第l层的可学习参数矩阵。
  • σ是激活函数,通常是ReLU。

这个公式的物理含义非常直观:每个节点的下一层表示,等于它自己和所有邻居的当前表示的加权平均,再经过一个线性变换和非线性激活

这就是GCN从谱域走向空域的关键一步。虽然它的数学推导来自谱域,但最终的公式在空域上有非常清晰的解释:消息传递 + 聚合

3.3 GCN的局限性

GCN虽然简洁有效,但它有几个明显的局限:

局限一:感受野固定且有限。GCN每一层只能聚合直接邻居(一阶邻域)的信息。如果你想聚合二阶邻居的信息,需要叠加多层GCN。但叠加层数过多会导致过平滑问题——所有节点的表示趋于相同。

局限二:无法处理有向图和异构图。GCN的对称归一化假设边是对称的,这对于无向图成立,但对于有向图(比如知识图谱中的关系)就不适用了。

局限三:参数效率不高。每层的参数矩阵W是全局共享的,无法捕捉不同类型邻居的不同重要性。

这些局限推动了后续模型的发展——GraphSAGE用采样和多样化的聚合函数解决感受野问题,GAT用注意力机制解决邻居权重问题,R-GCN等模型处理异构图。

四、GraphSAGE:采样+聚合机制

GraphSAGE是2017年由斯坦福大学团队提出的,它的全称是”Graph Sample and Aggregate”,即图采样与聚合。这个名字直接点出了它的核心思想:不一次性聚合所有邻居,而是采样固定数量的邻居,然后通过可学习的聚合函数来组合信息

4.1 邻居采样

GraphSAGE的第一个创新是邻居采样。在真实世界的图中,节点的度分布往往是幂律分布——少数节点有大量邻居(比如微博大V),大多数节点只有少量邻居。

如果一个节点有一百万个粉丝,GCN在第一层就要聚合一百万条边的信息,这在计算上是完全不可行的。而且,一百万个邻居的信息可能存在大量冗余,聚合这么多信息反而会带来噪声。

GraphSAGE的解决方案是对每个节点的邻居进行随机采样。假设我们设置采样数量为S,那么无论一个节点实际有多少邻居,聚合时只使用其中的S个。

采样策略可以是:

  • 均匀随机采样:最简单直接
  • 重要性采样:根据邻居的重要性(比如度数的倒数)加权采样
  • 多阶采样:第一层采样S_1个,第二层在前一层采样的基础上再采样S_2个,以此类推

这样,GraphSAGE的计算复杂度就从O(节点度数)降低到了O(S),变成了常数级别。

4.2 聚合函数

GraphSAGE的第二个创新是多样化的聚合函数。GCN实际上使用的是均值聚合(因为对称归一化本质上是对邻居表示求平均),但GraphSAGE允许使用更丰富的聚合方式。

Mean聚合(均值聚合)

这是最接近GCN的方式,计算自己和邻居表示的均值:

h_v^{(l+1)} = σ(W · MEAN({h_v^{(l)}} ∪ {h_u^{(l)}, ∀u ∈ N(v)}))

LSTM聚合

利用LSTM来处理邻居的顺序信息。由于邻居本身没有顺序,LSTM聚合会对邻居进行随机排列,把打乱顺序的邻居序列输入LSTM:

h_v^{(l+1)} = LSTM([h_u^{(l)}, ∀u ∈ π(v)])

其中π(v)是节点v邻居的随机排列。LSTM聚合能够捕捉更复杂的邻居交互模式,但计算开销也更大。

Pool聚合(池化聚合)

先对每个邻居的表示做一次非线性变换(MLP),然后对所有变换后的邻居表示做池化(通常是最大池化):

h_v^{(l+1)} = σ(W · pool({MLP(h_u^{(l)}), ∀u ∈ N(v)}))

Pool聚合的好处是能够捕捉邻居表示的”亮点”——即使大多数邻居的某个维度值都不高,但只要有一个邻居在这个维度上特别突出,最大池化就会捕捉到这个信号。

GraphSAGE的完整前向传播算法如下:

输入:图、节点特征、邻居采样函数
1. 对每个节点v,采样固定数量的邻居N(v)
2. 对每个邻居u ∈ N(v),递归地获取其邻居表示
3. 聚合邻居表示:h_{N(v)} = AGGREGATE({h_u, ∀u ∈ N(v)})
4. 更新节点表示:h_v = σ(W · CONCAT(h_v, h_{N(v)}))
5. 可选:归一化h_v

五、GAT:注意力机制在图上的应用

如果说GraphSAGE解决了”采样哪些邻居”的问题,那么GAT(Graph Attention Network)解决的是另一个问题:如何为不同邻居分配不同的重要性

在GCN和GraphSAGE中,邻居的聚合权重是由图的结构决定的——要么是固定的归一化系数,要么是简单的平均。GAT的创新在于:让模型自己学习不同邻居的重要性权重

5.1 注意力机制

GAT使用了一种叫做缩放点积注意力的机制。对于节点v的邻居u,注意力系数e_{vu}计算为:

e_{vu} = a(W h_v, W h_u)

其中a是一个前馈神经网络,W是将节点表示映射到相同空间的线性变换。a的输出是一个标量,表示节点u对节点v的重要程度。

为了让不同节点的注意力系数可以比较,我们对所有邻居的注意力系数做softmax归一化:

α_{vu} = softmax_u(e_{vu}) = exp(e_{vu}) / Σ_{k∈N(v)} exp(e_{vk})

最终的聚合表示是所有邻居表示的加权平均:

h_v’ = σ(Σ_{u∈N(v)} α_{vu} W h_u)

这就是GAT的单层前向传播公式。它和GCN的核心区别在于:权重α_{vu}不是由图结构决定的,而是由模型学习得到的

5.2 多头注意力

单层注意力可能不够稳定和表达力不足,GAT引入了多头注意力机制:并行运行多个独立的注意力机制,然后把结果拼接起来。

h_v^{(final)} = CONCAT_{k=1}^{K} (σ(Σ_{u∈N(v)} α_{vu}^k W^k h_u))

其中K是注意力头的数量,α_{vu}^k和W^k是第k个头的参数。

在最后一层,多头注意力的输出也可以换成求平均:

h_v^{(final)} = σ(1/K · Σ_{k=1}^{K} Σ_{u∈N(v)} α_{vu}^k W^k h_u)

多头注意力有几个好处:

好处一:稳定性。类似于Dropout,不同的注意力头学习到的权重模式可能不同,拼接或平均之后能抵消一些噪声。

好处二:捕捉多方面的关系。不同的头可能关注邻居的不同方面,比如在社交网络中,一个头可能关注”互动频率”,另一个头可能关注”共同好友数量”。

好处三:表达能力。多个头增加了模型的参数量和表达能力。

5.3 GAT vs GCN

GAT相比GCN有几个显著优势:

特性GCNGAT
邻居权重由图结构决定(度数的归一化)由模型学习得到
处理异构图困难较容易(不同关系可用不同注意力)
对噪声边的鲁棒性较低(所有邻居等权聚合)较高(可以给噪声邻居低权重)
计算复杂度O(VE)O(VE)(多了注意力计算)

GAT的灵活性使它在很多任务上取得了比GCN更好的效果,尤其是当边带有不同语义(比如知识图谱中的不同关系类型)时。

六、图同构网络GIN:Weisfeiler-Lehman测试

在讲GIN之前,我们需要先理解一个图论中的经典问题:如何判断两个图是否同构

图同构问题是一个计算复杂度很高的难题(目前没有已知的多项式时间算法)。但是,有一个非常巧妙的方法可以在大多数情况下高效地判断图是否同构:Weisfeiler-Lehman(WL)测试

6.1 一维Weisfeiler-Lehman测试

WL测试的思路很简单:

  1. 给每个节点分配一个初始颜色(可以是基于度的哈希值)
  2. 对于每个节点,把它的颜色和所有邻居的颜色放在一起,组成一个多重集
  3. 对这个多重集做一个哈希,得到节点的新颜色
  4. 重复上述过程,直到颜色分布不再变化

如果两个图在WL测试后颜色分布不同,我们就可以确定它们不是同构的。如果颜色分布相同,则两个图可能是同构的(WL测试不是完美的图同构判定算法,但在很多场景下足够有效)。

6.2 GIN的设计原则

GIN(Graph Isomorphism Network)的作者提出了一个核心问题:什么样的GNN能够达到WL测试的表达能力?

他们证明了一个非常重要的结论:一个GNN如果能够区分不同节点的多重集(multiset),并且聚合函数满足单射性(injective),那么它就具有和WL测试一样的表达能力。

为了实现这个目标,GIN的设计者选择了多层感知机(MLP)+ Sum聚合的组合。

GIN的更新公式是:

h_v^{(l+1)} = MLP^{(l)}( (1 + ε^{(l)}) · h_v^{(l)} + Σ_{u∈N(v)} h_u^{(l)} )

其中ε是一个可学习的标量参数或者固定的常数。

这里有两个关键设计:

关键一:使用Sum而不是Mean或Max。Mean和Max聚合都不是单射的——它们无法区分某些不同的多重集。比如{1, 1}和{1, 2}的平均值都是1,但它们的Sum是不同的。Sum聚合能够保留更多信息。

关键二:(1 + ε) · h_v的设计。这确保了节点的自身表示被显式地纳入聚合过程,而不是被邻居的聚合淹没。如果把所有表示直接相加,自身信息可能会被大量邻居信息稀释。

GIN的理论分析是GNN发展史上一个重要的里程碑。它让我们对GNN的表达能力有了更清晰的认识,也为后续模型的设计提供了指导原则。

七、消息传递神经网络框架MPNN

前面我们讲了GCN、GraphSAGE、GAT、GIN,这些模型虽然各有特点,但它们其实都可以被纳入一个统一的框架:消息传递神经网络(Message Passing Neural Network, MPNN)

MPNN是2017年由Google的研究者提出的一个抽象框架,它把GNN的操作分成两个阶段:消息传递节点更新

7.1 消息传递阶段

对于每个节点v,我们在每一层迭代地:

  1. 从每个邻居u那里接收一个消息m_{uv}
  2. 聚合所有邻居的消息

形式化地,消息函数通常定义为:

m_{uv}^{(l)} = MSG^{(l)}(h_u^{(l)}, h_v^{(l)}, e_{uv})

其中h_u和h_v是节点u和v的当前表示,e_{uv}是边(u,v)的特征(如果有的话)。

7.2 节点更新阶段

聚合完邻居消息后,我们更新节点自身的表示:

h_v^{(l+1)} = UPDATE^{(l)}(h_v^{(l)}, AGGREGATE({m_{uv}^{(l)}, ∀u ∈ N(v)}))

7.3 不同模型的MPNN视角

从这个统一的框架来看:

  • GCN的消息函数是隐式的(没有显式的消息函数定义),更新函数是均值聚合加上线性变换和激活。
  • GraphSAGE的Mean聚合对应于消息就是邻居的表示本身,聚合就是简单的平均。
  • GAT的消息函数包含了注意力权重的计算:m_{uv} = α_{vu} W h_u。
  • GIN的消息就是邻居表示本身,聚合是求和。

MPNN框架的价值在于:它给了我们一个思考和设计GNN的统一视角。当你需要为一个新的图任务设计模型时,你只需要思考:什么样的消息函数和更新函数适合我的问题?

比如,在分子性质预测中,一个分子图由原子和化学键构成。如果我们希望消息函数考虑化学键的类型,就可以设计成:

m_{uv} = f(h_u, h_v, bond_type(e_{uv}))

这样,不同类型的化学键会传递不同的消息,模型就能学习到化学键类型对分子性质的影响。

八、代码实战:用PyTorch Geometric实现GCN和GAT

理论讲完了,现在我们来动手实现一个GCN和一个GAT模型,然后在一个真实数据集上训练它们。

8.1 环境准备

首先,确保你安装了PyTorch和PyTorch Geometric:

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Cora
from torch_geometric.nn import GCNConv, GATConv

PyTorch Geometric(简称PyG)是目前最流行的GNN开源库,它封装了大量经典的GNN层和常用的图数据集,可以让我们几行代码就完成模型搭建和训练。

8.2 数据加载

Cora数据集是一个经典的引用网络数据集,包含2708篇机器学习论文,5429条引用关系,7个论文类别。节点特征是1433维的词袋向量,边代表论文之间的引用关系。

dataset = Cora(root='./data/Cora')
data = dataset[0]
 
print(f"节点数: {data.num_nodes}")
print(f"边数: {data.num_edges}")
print(f"节点特征维度: {data.num_node_features}")
print(f"类别数: {dataset.num_classes}")
print(f"训练/验证/测试集划分: {data.train_mask.sum()}/{data.val_mask.sum()}/{data.test_mask.sum()}")

8.3 GCN模型实现

class GCN(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_channels=16):
        super(GCN, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)
 
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return x
 
    def fit(self, data, epochs=100):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01, weight_decay=5e-4)
        criterion = torch.nn.CrossEntropyLoss()
 
        self.train()
        for epoch in range(epochs):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()
 
            if epoch % 20 == 0:
                self.eval()
                _, pred = self(data.x, data.edge_index).max(dim=1)
                correct = int(pred[data.val_mask].eq(data.y[data.val_mask]).sum())
                acc = correct / int(data.val_mask.sum())
                print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Val Acc: {acc:.4f}')
                self.train()
 
    def test(self, data):
        self.eval()
        _, pred = self(data.x, data.edge_index).max(dim=1)
        correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum())
        acc = correct / int(data.test_mask.sum())
        return acc

8.4 GAT模型实现

class GAT(torch.nn.Module):
    def __init__(self, num_features, num_classes, hidden_channels=8, heads=8):
        super(GAT, self).__init__()
        torch.manual_seed(42)
        self.conv1 = GATConv(num_features, hidden_channels, heads=heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * heads, num_classes, heads=1, concat=False, dropout=0.6)
 
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x
 
    def fit(self, data, epochs=100):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.005, weight_decay=5e-4)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
 
        self.train()
        for epoch in range(epochs):
            optimizer.zero_grad()
            out = self(data.x, data.edge_index)
            loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()
 
            if epoch % 20 == 0:
                self.eval()
                _, pred = self(data.x, data.edge_index).max(dim=1)
                correct = int(pred[data.val_mask].eq(data.y[data.val_mask]).sum())
                acc = correct / int(data.val_mask.sum())
                print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Val Acc: {acc:.4f}')
                self.train()
            scheduler.step()
 
    def test(self, data):
        self.eval()
        _, pred = self(data.x, data.edge_index).max(dim=1)
        correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum())
        acc = correct / int(data.test_mask.sum())
        return acc

8.5 训练与评估

num_features = data.num_features
num_classes = dataset.num_classes
 
print("训练GCN...")
gcn_model = GCN(num_features, num_classes)
gcn_model.fit(data, epochs=100)
gcn_acc = gcn_model.test(data)
print(f"GCN Test Accuracy: {gcn_acc:.4f}")
 
print("\n训练GAT...")
gat_model = GAT(num_features, num_classes)
gat_model.fit(data, epochs=100)
gat_acc = gat_model.test(data)
print(f"GAT Test Accuracy: {gat_acc:.4f}")

典型输出结果(取决于随机种子):

GCN Test Accuracy: 0.8150
GAT Test Accuracy: 0.8300

可以看到,GAT通常能取得比GCN稍好的效果,这是因为注意力机制让它能够更智能地聚合邻居信息。但这不是绝对的——在某些场景下,GCN的简单粗暴反而更有效。

九、实战案例

9.1 引用网络节点分类

节点分类是GNN最经典的应用场景之一。Cora数据集的任务就是:根据论文的内容特征(词袋向量)和论文之间的引用关系,预测每篇论文属于哪个类别。

这个任务的难点在于:我们只有部分节点有标签,其他节点的类别是未知的。GNN通过消息传递机制,能够让有标签节点的信息”流”向无标签节点,从而实现半监督学习。

具体流程是:

  1. 特征编码:首先用GCNConv或GATConv对节点特征进行编码,每一层的消息传递让节点能够”看到”越来越远的邻居
  2. 标签传播:有标签节点的监督信号通过反向传播影响所有节点的特征表示
  3. 分类预测:最终层的节点表示通过一个线性分类器输出类别预测

这种方法比传统的半监督方法(如标签传播)更强大,因为它能同时学习节点特征图结构,而不是只依赖图结构。

9.2 知识图谱补全

知识图谱是一种特殊的图,它的节点是实体,边是关系,而且每条边都标注了关系类型。比如,“北京”——(位于)——“中国”,“北京”——(人口)——“2154万”。

知识图谱补全的任务是:根据已有的三元组(头实体,关系,尾实体),预测缺失的三元组。比如,给定”北京”和”位于”,预测”中国”。

这类任务通常使用关系图卷积网络(R-GCN)或者TransE/TransR等距离模型来解决。

R-GCN的核心改进是在消息函数中加入了关系类型:

h_v^{(l+1)} = σ(Σ_{r∈R} Σ_{u∈N(v)^r} W_r^{(l)} h_u^{(l)} / |N(v)^r| + W_0^{(l)} h_v^{(l)})

其中R是所有关系类型的集合,N(v)^r是节点v在关系r下的邻居。

最近,大型语言模型(LLM)的兴起也为知识图谱补全带来了新的可能性。你可以把知识图谱补全理解为一个链接预测任务,LLM通过阅读图谱中的文本描述,能够推断出缺失的链接。

结语

GNN是一个快速发展的领域,从2017年GCN的开山之作,到GraphSAGE、GAT、GIN等模型的百花齐放,再到MPNN的统一框架,GNN在短短几年内就走完了CNN几十年的发展历程。

如果你想进一步深入GNN的学习,我有几个建议:

建议一:从论文原文开始。GCN、GraphSAGE、GAT、GIN这些经典论文都不长,公式推导也很清晰。读懂原文比看十篇解读文章更有收获。

建议二:多跑实验。PyTorch Geometric让GNN实验变得非常容易。试着在不同的数据集上运行不同的模型,改变超参数,观察结果的变化。你会发现很多insight是实验出来的,不是看出来的。

建议三:理解问题的图结构。GNN的效果很大程度上取决于图的结构是否与你的任务相关。如果你的图结构是任务的自然表达(比如分子图、交通图),GNN往往能发挥巨大威力;如果图结构是人为强加的,效果可能就不如直接在节点特征上做MLP。

建议四:关注最新进展。GNN领域的论文更新非常快。如果你想了解最新的研究成果,可以关注NeurIPS、ICML、ICLR等顶级会议的相关论文。

希望这篇文章能帮你建立起GNN的完整知识框架。动手实践吧,GNN的世界比你想象的更有趣。