# -*- coding: utf-8 -*-
import torch.nn as nn
from torch_geometric.nn import GCNConv, GATConv, SAGEConv, SGConv
# Maps the short layer name used for lookups to the layer class it stands for.
# This dict is the single shared store that registrations are written into.
_layer_registry = {
    "linear": nn.Linear,
    "gcn": GCNConv,
    "gat": GATConv,
    "sage": SAGEConv,
    "sg": SGConv,
}
class LayerRegistry:
    """Class-level lookup facade over the module's layer registry."""

    # Alias of the module-level dict (not a copy), so entries added to that
    # dict after this class is created are visible through these methods.
    _registry = _layer_registry

    @classmethod
    def list_layers(cls):
        """Return the names of all registered layers as a new list."""
        return list(cls._registry)

    @classmethod
    def get_layer(cls, name):
        """Return the layer class registered under ``name``.

        Args:
            name: Key to look up in the registry.

        Returns:
            The registered layer class, or ``None`` when ``name`` is unknown.
        """
        layer_cls = cls._registry.get(name)
        return layer_cls
def register_layer(name: str, layer_class: nn.Module = None):
    """Register a custom layer type under ``name``.

    Usable either as a plain function call or as a class decorator::

        register_layer('my_layer', MyLayer)

        @register_layer('my_layer')
        class MyLayer(nn.Module):
            ...

    Args:
        name (str): Name of the layer. Built-in names are linear, gcn, gat,
            sage and sg; registering an existing name overwrites its entry.
        layer_class (nn.Module, optional): Class of the layer. When given, it
            is registered immediately. When omitted, the class the returned
            decorator is applied to is registered instead.

    Returns:
        callable: A decorator that returns the decorated class unchanged.
    """
    # Register eagerly so that a plain function call takes effect even when
    # the returned decorator is never applied.
    if layer_class is not None:
        _layer_registry[name] = layer_class

    def decorator(cls):
        # With an explicit layer_class, keep the legacy behaviour of
        # registering that class; otherwise register the decorated class.
        _layer_registry[name] = cls if layer_class is None else layer_class
        return cls

    return decorator