Source code for malt.models.zoo.gat

from functools import partial
import torch
import dgl

class ConcatenationAttentionHeads(torch.nn.Module):
    """Multi-head attention layer whose per-head outputs are concatenated.

    Wraps a DGL multi-head attention convolution (``dgl.nn.GATConv`` by
    default) so that ``out_features`` is the *total* output width across all
    heads: each head produces ``out_features // num_heads`` features and the
    head dimension is flattened into the feature dimension in ``forward``.

    Parameters
    ----------
    in_features : int
        Input feature size.
    out_features : int
        Total output feature size after concatenating all heads. Must be a
        positive multiple of ``num_heads`` so the output has exactly this
        width.
    num_heads : int, optional
        Number of attention heads (default 4). Must be positive.
    layer : type, optional
        Layer class, instantiated as
        ``layer(in_features, out_features // num_heads, num_heads, **kwargs)``.
    **kwargs
        Extra keyword arguments forwarded to ``layer``.
        ``allow_zero_in_degree`` defaults to ``True`` unless given here.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive, or ``out_features`` is not a
        positive multiple of ``num_heads``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        num_heads: int = 4,
        layer: type = dgl.nn.GATConv,
        **kwargs,
    ):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        # Without this check, integer division silently shrinks the
        # concatenated width to num_heads * (out_features // num_heads),
        # which no longer equals the requested out_features.
        if out_features <= 0 or out_features % num_heads != 0:
            raise ValueError(
                f"out_features ({out_features}) must be a positive multiple "
                f"of num_heads ({num_heads})"
            )
        # Preserve the original default; callers may now override it.
        kwargs.setdefault("allow_zero_in_degree", True)
        self.layer = layer(
            in_features,
            out_features // num_heads,
            num_heads,
            **kwargs,
        )
        # Expose the wrapped layer's documentation on this instance.
        self.__doc__ = self.layer.__doc__

    def forward(self, graph, feat):
        """Apply the wrapped layer and concatenate its attention heads.

        Parameters
        ----------
        graph
            Graph passed unchanged to the wrapped layer.
        feat : torch.Tensor
            Node features passed unchanged to the wrapped layer.

        Returns
        -------
        torch.Tensor
            The wrapped layer's output with its last two dimensions
            (heads, per-head features) flattened into one.
        """
        feat = self.layer(graph, feat)
        # (..., num_heads, out_features // num_heads) -> (..., out_features)
        return feat.flatten(-2, -1)
# Ready-made concatenated multi-head attention layers, specialized by the
# DGL convolution class each one wraps.

# Heads built from dgl.nn.GATConv (the same layer the class uses by default).
GAT = partial(
    ConcatenationAttentionHeads,
    layer=dgl.nn.GATConv,
)

# Heads built from dgl.nn.DotGatConv.
GATDot = partial(
    ConcatenationAttentionHeads,
    layer=dgl.nn.DotGatConv,
)