Source code for vgslify.torch.layers

# Imports

# > Standard library
import inspect
from typing import Callable, Tuple

# > Third-party dependencies
import torch
from torch import nn

from vgslify.core.config import (
    Conv2DConfig,
    DenseConfig,
    DropoutConfig,
    InputConfig,
    Pooling2DConfig,
    ReshapeConfig,
    RNNConfig,
)

# > Internal dependencies
from vgslify.core.factory import LayerFactory
from vgslify.torch.reshape import Reshape


class TorchLayerFactory(LayerFactory):
    """
    Factory that creates PyTorch layers from parsed VGSL specifications.

    VGSL stands for Variable-size Graph Specification Language. The factory
    builds convolutional, pooling, recurrent, dense, normalisation, dropout,
    activation, reshape and flatten layers, and assembles them into a
    ``torch.nn.Sequential`` model. The base class is initialised with
    ``data_format="channels_first"``.

    Attributes
    ----------
    layers : list
        The PyTorch layers that have been added to the factory.
    shape : tuple of int
        The current shape of the tensor, excluding the batch size.
    _input_shape : tuple of int or None
        The original input shape provided during initialization.
    """

    # A class-level dictionary that holds {prefix -> callable} for custom layers
    _custom_layer_registry = {}

    @classmethod
    def register(cls, prefix: str, builder_fn):
        """
        Register a custom layer builder function under a given spec prefix.

        Parameters
        ----------
        prefix : str
            The VGSL spec prefix that triggers this custom layer (e.g. "Xsw").
        builder_fn : callable
            A function with signature `builder_fn(self, spec: str) -> layer`
            that, given the VGSL spec string, returns the framework-specific
            layer.

        Raises
        ------
        ValueError
            If `prefix` is already registered, or if `builder_fn` does not
            declare exactly two parameters.
        """
        if prefix in cls._custom_layer_registry:
            raise ValueError(f"Prefix '{prefix}' is already registered.")

        # Only the parameter count is validated; names and kinds are free.
        params = list(inspect.signature(builder_fn).parameters.values())
        if len(params) != 2:
            raise ValueError(
                "Custom layer builder_fn must define exactly two parameters: "
                "(factory_self, spec)."
            )

        cls._custom_layer_registry[prefix] = builder_fn

    @classmethod
    def get_custom_layer_registry(cls):
        """Return the dict of all registered custom layers for this factory class."""
        return cls._custom_layer_registry

    def __init__(self, input_shape: Tuple[int, ...] = None):
        """
        Initialize the TorchLayerFactory.

        Parameters
        ----------
        input_shape : tuple of int, optional
            The input shape for the model, excluding batch size.
        """
        super().__init__(input_shape, data_format="channels_first")

    def build(self, name: str = "VGSL_Model") -> nn.Module:
        """
        Build the final model using the accumulated layers.

        Parameters
        ----------
        name : str, optional
            The name of the model, by default "VGSL_Model"

        Returns
        -------
        torch.nn.Module
            The constructed PyTorch model (an ``nn.Sequential``).

        Raises
        ------
        ValueError
            If no layers have been added to the model.
        ValueError
            If no input shape has been specified for the model.
        """
        if not self.layers:
            raise ValueError("No layers added to the model.")
        if not self._input_shape:
            raise ValueError("No input shape specified for the model.")

        # model = VGSLModel(self.layers)  # TODO: Implement VGSLModel class
        model = nn.Sequential(*self.layers)
        # NOTE(review): this renames the shared ``nn.Sequential`` class itself,
        # so every ``nn.Sequential`` in the process reports ``name`` afterwards.
        # Kept as-is because a dynamically created subclass would break
        # whole-model pickling (``torch.save(model)``); the planned VGSLModel
        # class is the proper fix.
        model.__class__.__name__ = name
        return model

    # Layer creation methods
    def _input(self, config: InputConfig, input_shape: Tuple[int, ...]):
        """
        Create a PyTorch input layer (placeholder method).

        Parameters
        ----------
        config : InputConfig
            Configuration object (unused in PyTorch).
        input_shape : tuple of int
            The input shape for the layer.

        Returns
        -------
        None
            PyTorch doesn't require a separate input layer.
        """
        return None

    @staticmethod
    def _supports_same_padding() -> bool:
        """
        Tell whether the installed PyTorch accepts ``padding="same"``.

        The version is compared numerically on (major, minor). Comparing the
        raw version strings would sort "1.10" before "1.7".

        Returns
        -------
        bool
            True when the installed PyTorch is at least version 1.9.
        """
        numbers = []
        # Drop any local build tag ("+cu118") and keep major/minor only.
        for part in torch.__version__.split("+")[0].split(".")[:2]:
            digits = ""
            for char in part:
                if not char.isdigit():
                    break
                digits += char
            numbers.append(int(digits) if digits else 0)
        return tuple(numbers) >= (1, 9)

    def _conv2d(self, config: Conv2DConfig):
        """
        Create a PyTorch Conv2d layer.

        ``padding="same"`` is used only for unit-stride convolutions on a
        PyTorch version that supports it, because PyTorch rejects
        ``padding="same"`` combined with a stride other than 1. In every
        other case an explicit symmetric padding is computed instead.

        Parameters
        ----------
        config : Conv2DConfig
            Configuration object for the Conv2D layer.

        Returns
        -------
        torch.nn.Conv2d
            The created Conv2d layer.
        """
        strides = config.strides
        stride_values = (strides,) if isinstance(strides, int) else tuple(strides)
        unit_stride = all(step == 1 for step in stride_values)

        if unit_stride and self._supports_same_padding():
            padding = "same"
        else:
            padding = self._compute_same_padding(config.kernel_size, strides)

        return nn.Conv2d(
            in_channels=self.shape[0],
            out_channels=config.filters,
            kernel_size=config.kernel_size,
            stride=strides,
            padding=padding,
        )

    def _pooling2d(self, config: Pooling2DConfig):
        """
        Create a PyTorch Pooling2d layer.

        Parameters
        ----------
        config : Pooling2DConfig
            Configuration object for the Pooling2D layer.

        Returns
        -------
        torch.nn.Module
            The created Pooling2d layer: MaxPool2d when ``config.pool_type``
            is "max", AvgPool2d for any other value.
        """
        padding = self._compute_same_padding(config.pool_size, config.strides)
        pool_layer = nn.MaxPool2d if config.pool_type == "max" else nn.AvgPool2d
        return pool_layer(
            kernel_size=config.pool_size, stride=config.strides, padding=padding
        )

    def _dense(self, config: DenseConfig):
        """
        Create a PyTorch Linear (Dense) layer.

        Parameters
        ----------
        config : DenseConfig
            Configuration object for the Dense layer.

        Returns
        -------
        torch.nn.Linear
            The created Linear layer, taking ``self.shape[-1]`` input features.
        """
        return nn.Linear(self.shape[-1], config.units)

    def _rnn(self, config: RNNConfig):
        """
        Create a PyTorch RNN layer (LSTM or GRU), either unidirectional or
        bidirectional.

        Parameters
        ----------
        config : RNNConfig
            Configuration object for the RNN layer.

        Returns
        -------
        torch.nn.Module
            The created RNN layer (either LSTM or GRU, unidirectional or
            bidirectional).

        Raises
        ------
        ValueError
            If an unsupported RNN type is specified.
        """
        rnn_type = config.rnn_type.upper()
        if rnn_type == "L":
            rnn_class = nn.LSTM
        elif rnn_type == "G":
            rnn_class = nn.GRU
        else:
            raise ValueError(f"Unsupported RNN type: {config.rnn_type}")

        # NOTE(review): PyTorch applies ``dropout`` only between stacked
        # recurrent layers, so with ``num_layers=1`` a non-zero value has no
        # effect beyond a warning - confirm this is intended.
        return rnn_class(
            input_size=self.shape[-1],
            hidden_size=config.units,
            num_layers=1,
            batch_first=True,
            dropout=config.dropout,
            bidirectional=config.bidirectional,
        )

    def _batchnorm(self):
        """
        Create a PyTorch BatchNorm layer.

        Returns
        -------
        torch.nn.Module
            BatchNorm2d when the current shape has three dimensions,
            BatchNorm1d when it has two; both use ``self.shape[0]`` features.

        Raises
        ------
        ValueError
            If the input shape is not supported for BatchNorm.
        """
        rank = len(self.shape)
        if rank == 3:
            return nn.BatchNorm2d(self.shape[0])
        if rank == 2:
            return nn.BatchNorm1d(self.shape[0])
        raise ValueError("Unsupported input shape for BatchNorm layer.")

    def _dropout(self, config: DropoutConfig):
        """
        Create a PyTorch Dropout layer.

        Parameters
        ----------
        config : DropoutConfig
            Configuration object for the Dropout layer.

        Returns
        -------
        nn.Dropout
            The created Dropout layer.
        """
        return nn.Dropout(p=config.rate)

    def _activation(self, activation_function: str):
        """
        Create a PyTorch activation layer.

        Parameters
        ----------
        activation_function : str
            Name of the activation function. Supported values are 'softmax',
            'tanh', 'relu', 'linear', 'sigmoid'.

        Returns
        -------
        nn.Module
            The created activation layer.

        Raises
        ------
        ValueError
            If the activation function is not supported.
        """
        # Map names to constructors so only the requested module is built.
        # NOTE(review): softmax normalises over dim=1; confirm that this is
        # the class axis for every tensor layout that reaches this layer.
        constructors = {
            "softmax": lambda: nn.Softmax(dim=1),
            "tanh": nn.Tanh,
            "relu": nn.ReLU,
            "linear": nn.Identity,
            "sigmoid": nn.Sigmoid,
        }
        if activation_function not in constructors:
            raise ValueError(f"Unsupported activation: {activation_function}")
        return constructors[activation_function]()

    def _reshape(self, config: ReshapeConfig):
        """
        Create a PyTorch Reshape layer.

        Parameters
        ----------
        config : ReshapeConfig
            Configuration object for the Reshape layer.

        Returns
        -------
        nn.Module
            The created Reshape layer.
        """
        return Reshape(*config.target_shape)

    def _flatten(self):
        """
        Create a PyTorch Flatten layer.

        Returns
        -------
        nn.Flatten
            The created Flatten layer.
        """
        return nn.Flatten()

    # Helper methods
    def _compute_same_padding(self, kernel_size, stride):
        """
        Compute the padding size to achieve 'same' padding.

        The padding per dimension is ``(kernel - 1) // 2``. The stride does
        not enter the formula; it only bounds how many dimensions are
        returned.

        Parameters
        ----------
        kernel_size : int or tuple
            Size of the kernel.
        stride : int or tuple
            Stride of the convolution.

        Returns
        -------
        tuple
            Padding size for height and width dimensions.
        """
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        if isinstance(stride, int):
            stride = (stride, stride)

        return tuple((k - 1) // 2 for k, _ in zip(kernel_size, stride))

    def _get_activation_layer(self, activation_name: str):
        """
        Return a PyTorch activation layer based on the activation name.

        Delegates to `_activation` so both entry points share one table of
        supported activations.

        Parameters
        ----------
        activation_name : str
            Name of the activation function.

        Returns
        -------
        torch.nn.Module
            The activation layer.

        Raises
        ------
        ValueError
            If the activation_name is not recognized.
        """
        return self._activation(activation_name)
def register_custom_layer(prefix: str) -> Callable:
    """
    Build a decorator that registers a custom layer builder on TorchLayerFactory.

    The decorated function is passed to ``TorchLayerFactory.register`` under
    `prefix` and then returned unchanged, so it remains usable as a normal
    function.

    Parameters
    ----------
    prefix : str
        The VGSL spec prefix that triggers this custom layer (e.g. "Xsw").

    Returns
    -------
    Callable
        A decorator that registers the provided function as a builder for the
        given prefix.

    Raises
    ------
    ValueError
        Raised when the decorator is applied, if a builder for the prefix is
        already registered or if the function signature is invalid.

    Examples
    --------
    >>> from vgslify.torch.layers import register_custom_layer
    >>> from torch import nn
    >>> @register_custom_layer("Xsw")
    ... def build_custom_layer(factory, spec):
    ...     # Custom layer building logic
    ...     return nn.Linear(factory.shape[-1], 10)
    """

    def _register(builder_fn: Callable) -> Callable:
        # Registration validates the prefix and the builder's parameter count.
        TorchLayerFactory.register(prefix, builder_fn)
        return builder_fn

    return _register