Source code for pydgc.modules.decoder

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from typing import List

from torch import Tensor

from . import layer_registry


[docs]class BaseDecoder(nn.Module): """Base Decoder class. Args: dims (List[int]): A list of dimensions from input to output. layer (str): Type of layers, e.g., 'linear', 'gcn', 'gat', 'sage', 'sg'. act (str): Activation function, e.g., 'relu', '' act_last (bool): Whether to apply activation function to the last layer. add_self_loops (bool): Whether to add self-loops to the graph. """ def __init__(self, dims: List[int] = None, layer: str = 'linear', act: str = 'relu', act_last: bool = False, add_self_loops: bool = True): super(BaseDecoder, self).__init__() self.act = act self.act_last = act_last self.add_self_loops = add_self_loops self.decoder = nn.Sequential() if not dims: raise ValueError("dims cannot be None and should be a list of dimensions from input to output") if len(dims) < 2: raise ValueError("dims must contain at least input and output dimensions") registered = layer_registry.list_layers() if layer not in registered: raise ValueError(f"Unsupported layer type: {layer}. Registered types: {registered}") LayerClass = layer_registry.get_layer(layer) self.LayerClass = LayerClass # reverse dims dims = dims[::-1] for i in range(len(dims) - 1): if layer == 'gcn': layer_instance = LayerClass(dims[i], dims[i + 1], add_self_loops=self.add_self_loops) else: layer_instance = LayerClass(dims[i], dims[i + 1]) self.decoder.add_module(f'{layer}_{i}', layer_instance) if self.act_last or i < len(dims) - 2: self.decoder.add_module(f'{self.act}_{i}', self.act_func) self.reset_parameters()
[docs] def reset_parameters(self): for layer in self.decoder: if isinstance(layer, self.LayerClass): layer.reset_parameters()
@property def act_func(self): if self.act == 'relu': return nn.ReLU() elif self.act == 'tanh': return nn.Tanh() elif self.act == 'sigmoid': return nn.Sigmoid() elif self.act == 'leaky_relu': return nn.LeakyReLU() elif self.act == 'elu': return nn.ELU() else: return nn.Identity()
[docs] def forward(self, *args, **kwargs): raise NotImplementedError()
[docs]class MLPDecoder(BaseDecoder): """MLP Decoder class. Args: dims (List[int]): A list of dimensions from input to output. act (str): Activation function, e.g., 'relu', '' act_last (bool): Whether to apply activation function to the last layer. """ def __init__(self, dims, act='relu', act_last=False): super(MLPDecoder, self).__init__(dims=dims, layer='linear', act=act, act_last=act_last)
[docs] def forward(self, x) -> List[Tensor]: decodes = [] for i in range(len(self.decoder)): x = self.decoder[i](x) if isinstance(self.decoder[i], nn.Linear): decodes.append(x) return decodes
[docs]class GNNAttributeDecoder(BaseDecoder): """GNN Attribute Decoder class. Args: dims (List[int]): A list of dimensions from input to output. layer (str): Type of layers, e.g., 'linear', 'gcn', 'gat', 'sage', 'sg'. act (str): Activation function, e.g., 'relu', '' act_last (bool): Whether to apply activation function to the last layer. add_self_loops (bool): Whether to add self-loops to the graph. """ def __init__(self, dims, layer='gcn', act='relu', act_last=False, add_self_loops=True): super(GNNAttributeDecoder, self).__init__(dims=dims, act=act, layer=layer, act_last=act_last, add_self_loops=add_self_loops) self.reset_parameters()
[docs] def reset_parameters(self): for layer in self.decoder: if isinstance(layer, self.LayerClass): layer.reset_parameters()
[docs] def forward(self, x, edge_index) -> Tensor: for layer in self.decoder: if isinstance(layer, self.LayerClass): x = layer(x, edge_index) else: x = layer(x) return x
[docs]class InnerProductDecoder(nn.Module): """Inner Product Decoder class. $\hat{A} = sigmoid(ZZ^T)$ Args: None """ def __init__(self): super(InnerProductDecoder, self).__init__() self.reset_parameters()
[docs] def reset_parameters(self): pass
[docs] @staticmethod def forward(embedding) -> Tensor: hat_adj = torch.sigmoid(torch.matmul(embedding, embedding.t())) return hat_adj