Source code for pydgc.modules.utils

# -*- coding: utf-8 -*-
from typing import Optional, Type

import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, SGConv


# Mapping from layer name to layer class. LayerRegistry aliases this exact
# dict object and register_layer mutates it in place, so entries added at
# runtime are visible through LayerRegistry as well.
_layer_registry = dict(
    linear=nn.Linear,
    gcn=GCNConv,
    gat=GATConv,
    sage=SAGEConv,
    sg=SGConv,
)


class LayerRegistry:
    """Class-level lookup helpers over the module's layer registry.

    ``_registry`` is bound to the same dict object as ``_layer_registry``
    (an alias, not a copy), so layers added later via ``register_layer``
    are returned by these methods too.
    """

    _registry = _layer_registry

    @classmethod
    def list_layers(cls):
        """Return the registered layer names as a new list, in insertion order."""
        return [*cls._registry]

    @classmethod
    def get_layer(cls, name):
        """Return the layer class registered under ``name``.

        Returns ``None`` (rather than raising) when ``name`` is not registered.
        """
        try:
            return cls._registry[name]
        except KeyError:
            return None
def register_layer(name: str, layer_class: Optional[Type[nn.Module]] = None):
    """Register a custom layer class under ``name`` in the layer registry.

    Usable either as a plain function or as a class decorator::

        register_layer("my_layer", MyLayer)      # registers MyLayer immediately

        @register_layer("my_layer")
        class MyLayer(nn.Module):                 # registers the decorated class
            ...

    Registering a name that is already present (including the built-in
    ``linear``, ``gcn``, ``gat``, ``sage`` and ``sg``) overwrites the
    existing entry.

    Args:
        name (str): Key under which the layer is stored; this is the name
            later passed to ``LayerRegistry.get_layer``.
        layer_class (Type[nn.Module], optional): Layer class to register.
            When given, it is registered right away, and the returned
            decorator registers it again when applied (kept for backward
            compatibility). When omitted, the object the returned decorator
            is applied to is registered instead.

    Returns:
        Callable: A decorator that registers as described above and returns
        the decorated object unchanged.
    """
    if layer_class is not None:
        # Function-style call: previously nothing was registered unless the
        # returned decorator was subsequently applied to something.
        _layer_registry[name] = layer_class

    def decorator(cls):
        # Keep the historical behaviour of registering the explicitly passed
        # ``layer_class``; otherwise register the decorated class itself.
        _layer_registry[name] = cls if layer_class is None else layer_class
        return cls

    return decorator