from functools import partial
import torch
import dgl
class ConcatenationAttentionHeads(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_heads: int = 4,
        layer: type = dgl.nn.GATConv,
    ):
        super().__init__()
        # Split ``out_features`` evenly across the attention heads so that
        # concatenating the per-head outputs restores the requested size.
        self.layer = layer(
            in_features, out_features // num_heads, num_heads,
            allow_zero_in_degree=True,
        )
        # Forward the wrapped layer's docstring to this instance.
        self.__doc__ = self.layer.__doc__

    def forward(self, graph, feat):
        # The wrapped layer returns per-head features of shape
        # (..., num_heads, out_features // num_heads); flatten the last two
        # dimensions to concatenate the heads.
        feat = self.layer(graph, feat)
        feat = feat.flatten(-2, -1)
        return feat


GAT = partial(ConcatenationAttentionHeads, layer=dgl.nn.GATConv)
GATDot = partial(ConcatenationAttentionHeads, layer=dgl.nn.DotGatConv)
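

# A minimal usage sketch (assumption: the small random graph and the 16-dim
# node features below are chosen purely for illustration; ``GAT`` is the
# partial defined above). Each of the 4 heads produces 32 // 4 = 8 features,
# and concatenation restores the requested 32-dimensional output per node.
if __name__ == "__main__":
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]), num_nodes=5)
    x = torch.randn(5, 16)
    gat = GAT(in_features=16, out_features=32, num_heads=4)
    out = gat(g, x)
    print(out.shape)  # torch.Size([5, 32])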