本文是为了解释 OGB LSC 中使用 DGL 实现 GCNconv 的代码。
GCN
邻接矩阵, , 是对角度矩阵, ,
要在 GCN 中加入边信息, 对于单个节点的更新
假设传入的图 g
是无向图并且没有加入自环(例如,ogb smile2graph 中将分子从SMILES转化为分子图时,没有加入自环)。如下的代码表示 ,为了节省内存,实际上就是度向量,而且我们没有向 g
中加入自环(当然也可以这样做)。这样每个节点的度至少为1,不会出现 1 / degs
为 inf
的情况。
degs = (g.out_degrees().float() + 1).to(x.device)
接下来对度矩阵取-1/2幂,
deg_inv_sqrt = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1)
g.ndata["norm"] = deg_inv_sqrt
使用 apply_edges
为边增加特征 norm
,
g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))
我们不更新边的特征,在每层对原始的边特征做嵌入。节点 传递到节点 ()传递的消息为,
边的嵌入,根据原始边特征的表示方式可以有两种方案:
- one-hot + nn.Linear
- index + nn.Embedding
x = self.linear(x)
g.ndata["x"] = x
g.apply_edges(fn.copy_u("x", "m"))
g.edata["m"] = g.edata["norm"] * F.relu(
g.edata["m"] + edge_embedding)
加下来只需要聚合函数更新节点特征,
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
接下来我们还需要两个操作,
- 第一个是由于我们没有加入自环,所以上述操作不会聚合自己上一层的信息.
- 在消息传递过程中,我们加入了边的信息,若想等价于先加入自环再作用GCN的效果,我们同时需要传递自环的信息(自环本身就是边),因此为自环设置一个单独的 embedding
root_emb = nn.Embedding(1, emb_dim)
.
out = g.ndata["new_x"] + F.relu(
x + self.root_emb.weight
) / degs.view(-1, 1)
Self-loop feature
设置 root_emb
可以看成是自环的替代方案,否则需要为边的特征加入一维表示该边是否为自环,这种方法,从实现层面可以有两种。
one-hot
如果边的特征使用 one-hot 向量表示的,DGLlife中的 BondFeaturizer 是这样表示的。例如若第一个特征有3个取值,第2个特征有2个取值,对于两条边表示如下
[1, 0, 0, | 0, 1]
[0, 0, 1, | 1, 0]
则可以这样加入自环,
[1, 0, 0, | 0, 1, | 0]
[0, 0, 1, | 1, 0, | 0]
[0, 0, 0, | 0, 0, | 1]
之后在 GCNconv 的每层对边特征接一个 nn.Linear
,可达到相同的效果。
index
边的特征还可以是由每一个特征对应的 index 表示, 这样可以减少内存的消耗,例如对于上面的例子,边的特征可以表示为
[0, 1]
[2, 0]
通常后面接 nn.Embedding
得到边的嵌入。OGB 中 smiles2graph 是这样得到 边特征 的。
若是加入自环,则可以这样表示加入的自环: 将原始的 index + 1,即每个维度的取值个数都+1,并加入是否自环这一特征,
[1, 2, 0]
[3, 1, 0]
[0, 0, 1]
这样每一维中 index=0 就表示 padding 的特征,设置 nn.Embedding(padding_idx=0)
,对应的 index=0 的向量为0,这样原来图中有的边既不会增加额外信息,又为自环这一特征做了嵌入,达到相同的效果。
Comment
第二种方案需要把 ogb 的涉及到上面的代码 copy 过来做修改,并且又要加入自环,具有额外的空间开销,我想它们是想避免这个问题从而使用相同效果的 root_emb
的,可以参考 root_emb issue 中大佬对这个问题的的回复。
Code
代码来自 OGB LSC GCNconv DGL,并稍作修改。BondEncoder
是直接得到分子图中键的嵌入。
class GCNConv(nn.Module):
def __init__(self, emb_dim):
"""
emb_dim (int): node embedding dimensionality
"""
super(GCNConv, self).__init__()
self.linear = nn.Linear(emb_dim, emb_dim)
self.root_emb = nn.Embedding(1, emb_dim)
self.bond_encoder = BondEncoder(emb_dim=emb_dim)
def forward(self, g, x, edge_attr):
with g.local_scope():
x = self.linear(x)
edge_embedding = self.bond_encoder(edge_attr)
# Molecular graphs are undirected
# g.out_degrees() is the same as g.in_degrees()
degs = (g.out_degrees().float() + 1).to(x.device)
deg_inv_sqrt = torch.pow(degs, -0.5).unsqueeze(-1) # (N, 1)
g.ndata["norm"] = deg_inv_sqrt
g.apply_edges(fn.u_mul_v("norm", "norm", "norm"))
g.ndata["x"] = x
g.apply_edges(fn.copy_u("x", "m"))
g.edata["m"] = g.edata["norm"] * F.relu(
g.edata["m"] + edge_embedding
)
g.update_all(fn.copy_e("m", "m"), fn.sum("m", "new_x"))
out = g.ndata["new_x"] + F.relu(
x + self.root_emb.weight
) / degs.view(-1, 1)
return out