# Imports
# > Standard Library
import inspect
from typing import Callable, Dict, Type, Union
# > Third-Party Dependencies
import tensorflow as tf
# > Internal
from vgslify.core.config import (
    ActivationConfig,
    Conv2DConfig,
    DenseConfig,
    DropoutConfig,
    InputConfig,
    Pooling2DConfig,
    ReshapeConfig,
    RNNConfig,
)
from vgslify.model_parsers.base import BaseModelParser
class TensorFlowModelParser(BaseModelParser):
    """
    Parser for converting TensorFlow Keras models into VGSL (Variable-size Graph Specification Language) spec strings.

    This class extends the BaseModelParser to provide specific functionality for TensorFlow Keras models.
    It uses configuration dataclasses to represent different layer types and converts them into
    VGSL spec strings.

    Attributes
    ----------
    layer_parsers : Dict[Type[tf.keras.layers.Layer], Callable]
        A per-instance dictionary mapping TensorFlow Keras layer types to their
        corresponding parsing methods. Built-in parsers are installed first and
        the custom parsers registered at construction time are merged on top,
        so a custom parser for a built-in layer type overrides the built-in one.

    Notes
    -----
    This parser supports a wide range of TensorFlow Keras layers and can be extended to support
    additional layer types through ``TensorFlowModelParser.register`` (or the
    ``register_custom_parser`` decorator). Lookup uses the exact layer class
    (``type(layer)``), so subclasses of supported layers are not matched
    automatically and must be registered explicitly.
    """

    # Class-level registry shared by all instances: {layer_class -> parser_function}.
    _custom_layer_parsers: Dict[Type[tf.keras.layers.Layer], Callable] = {}

    @classmethod
    def register(cls, layer_cls: Type[tf.keras.layers.Layer], parser_fn: Callable):
        """
        Registers a custom parser function for a given TF layer class.

        Parameters
        ----------
        layer_cls : Type[tf.keras.layers.Layer]
            The TF layer class this parser function can handle.
        parser_fn : Callable
            A function with signature parser_fn(layer) -> str
            that returns a VGSL spec string for the given layer.

        Raises
        ------
        ValueError
            If a custom parser is already registered for ``layer_cls``, or if
            ``parser_fn`` does not declare exactly one parameter.

        Notes
        -----
        Instances copy the registry when they are constructed, so a parser
        registered after an instance exists is not visible to that instance.
        """
        if layer_cls in cls._custom_layer_parsers:
            raise ValueError(
                f"A parser is already registered for {layer_cls.__name__}."
            )
        # Check signature to ensure `parser_fn` is (layer) -> str
        sig = inspect.signature(parser_fn)
        params = list(sig.parameters.values())
        if len(params) != 1:
            raise ValueError(
                "Custom parser function must define exactly one parameter: (layer)."
            )
        cls._custom_layer_parsers[layer_cls] = parser_fn

    @classmethod
    def get_custom_parsers(cls) -> Dict[Type[tf.keras.layers.Layer], Callable]:
        """
        Return the dict of all registered custom parser functions.

        Returns
        -------
        Dict[Type[tf.keras.layers.Layer], Callable]
            The live class-level registry (not a copy); mutating it affects
            every parser instance created afterwards.
        """
        return cls._custom_layer_parsers

    def __init__(self):
        # Built-in parsers, keyed by the exact Keras layer class.
        self.layer_parsers: Dict[Type[tf.keras.layers.Layer], Callable] = {
            tf.keras.layers.InputLayer: self.parse_input,
            tf.keras.layers.Conv2D: self.parse_conv2d,
            tf.keras.layers.Dense: self.parse_dense,
            tf.keras.layers.LSTM: self.parse_rnn,
            tf.keras.layers.GRU: self.parse_rnn,
            tf.keras.layers.Bidirectional: self.parse_rnn,
            tf.keras.layers.MaxPooling2D: self.parse_pooling,
            tf.keras.layers.AveragePooling2D: self.parse_pooling,
            tf.keras.layers.BatchNormalization: self.parse_batchnorm,
            tf.keras.layers.Dropout: self.parse_dropout,
            tf.keras.layers.Reshape: self.parse_reshape,
            tf.keras.layers.Flatten: self.parse_flatten,
            tf.keras.layers.Activation: self.parse_activation,
        }
        # Merge custom parsers last so user registrations override built-ins.
        self.layer_parsers.update(self.get_custom_parsers())

    def parse_model(self, model: tf.keras.models.Model) -> str:
        """
        Parse a TensorFlow Keras model into a VGSL spec string.

        Parameters
        ----------
        model : tf.keras.models.Model
            Keras model to be converted.

        Returns
        -------
        str
            VGSL spec string.

        Raises
        ------
        ValueError
            If the model contains unsupported layers or if the input shape is invalid.
        """
        configs = []
        layers = model.layers

        # Ensure the spec starts with an input definition. When `model.layers`
        # is empty or does not begin with an InputLayer, synthesise one from
        # `model.input_shape` (index 0 is the batch dimension).
        if not layers or not isinstance(layers[0], tf.keras.layers.InputLayer):
            input_layer = tf.keras.layers.InputLayer(
                input_shape=model.input_shape[1:], batch_size=model.input_shape[0]
            )
            configs.append(self.parse_input(input_layer))

        for idx, layer in enumerate(layers):
            layer_type = type(layer)
            parser_func = self.layer_parsers.get(layer_type)
            if parser_func is None:
                raise ValueError(
                    f"Unsupported layer type {layer_type.__name__} at position {idx}."
                )
            config = parser_func(layer)
            # A parser may return a falsy value to contribute nothing to the spec.
            if config:
                configs.append(config)

        return self.generate_vgsl(configs)

    # Parser methods for different layer types

    def parse_conv2d(self, layer: tf.keras.layers.Conv2D) -> Conv2DConfig:
        """
        Parse a Conv2D layer into a Conv2DConfig dataclass.

        Parameters
        ----------
        layer : tf.keras.layers.Conv2D
            The Conv2D layer to parse.

        Returns
        -------
        Conv2DConfig
            The configuration for the Conv2D layer (activation, kernel size,
            strides and number of filters).
        """
        activation = self._extract_activation(layer)
        return Conv2DConfig(
            activation=activation,
            kernel_size=layer.kernel_size,
            strides=layer.strides,
            filters=layer.filters,
        )

    def parse_dense(self, layer: tf.keras.layers.Dense) -> DenseConfig:
        """
        Parse a Dense layer into a DenseConfig dataclass.

        Parameters
        ----------
        layer : tf.keras.layers.Dense
            The Dense layer to parse.

        Returns
        -------
        DenseConfig
            The configuration for the Dense layer (activation and units).
        """
        activation = self._extract_activation(layer)
        return DenseConfig(activation=activation, units=layer.units)

    def parse_rnn(
        self,
        layer: Union[
            tf.keras.layers.LSTM, tf.keras.layers.GRU, tf.keras.layers.Bidirectional
        ],
    ) -> RNNConfig:
        """
        Parse an RNN layer (LSTM, GRU, or Bidirectional) into an RNNConfig dataclass.

        Parameters
        ----------
        layer : Union[tf.keras.layers.LSTM, tf.keras.layers.GRU, tf.keras.layers.Bidirectional]
            The RNN layer to parse. For a Bidirectional wrapper, the settings
            are read from its ``forward_layer``.

        Returns
        -------
        RNNConfig
            The configuration for the RNN layer.

        Raises
        ------
        ValueError
            If the (wrapped) layer is neither an LSTM nor a GRU.
        """
        if isinstance(layer, tf.keras.layers.Bidirectional):
            wrapped_layer = layer.forward_layer
            bidirectional = True
        else:
            wrapped_layer = layer
            bidirectional = False

        if isinstance(wrapped_layer, tf.keras.layers.LSTM):
            rnn_type = "lstm"
        elif isinstance(wrapped_layer, tf.keras.layers.GRU):
            rnn_type = "gru"
        else:
            raise ValueError(
                f"Unsupported RNN layer type {type(wrapped_layer).__name__}."
            )

        return RNNConfig(
            units=wrapped_layer.units,
            return_sequences=wrapped_layer.return_sequences,
            # Direction is implied by the bidirectional flag for wrapped layers.
            go_backwards=wrapped_layer.go_backwards if not bidirectional else False,
            dropout=wrapped_layer.dropout,
            recurrent_dropout=wrapped_layer.recurrent_dropout,
            rnn_type=rnn_type,
            bidirectional=bidirectional,
        )

    def parse_pooling(
        self,
        layer: Union[tf.keras.layers.MaxPooling2D, tf.keras.layers.AveragePooling2D],
    ) -> Pooling2DConfig:
        """
        Parse a Pooling layer into a Pooling2DConfig dataclass.

        Parameters
        ----------
        layer : tf.keras.layers.MaxPooling2D or tf.keras.layers.AveragePooling2D
            The Pooling layer to parse.

        Returns
        -------
        Pooling2DConfig
            The configuration for the Pooling layer.

        Raises
        ------
        ValueError
            If the layer is neither a MaxPooling2D nor an AveragePooling2D.
        """
        if isinstance(layer, tf.keras.layers.MaxPooling2D):
            pool_type = "max"
        elif isinstance(layer, tf.keras.layers.AveragePooling2D):
            pool_type = "average"
        else:
            # Fail with a clear message rather than an UnboundLocalError on pool_type.
            raise ValueError(
                f"Unsupported pooling layer type {type(layer).__name__}."
            )
        return Pooling2DConfig(
            pool_type=pool_type,
            pool_size=layer.pool_size,
            # Fall back to the pool size when the layer reports no strides.
            strides=layer.strides if layer.strides else layer.pool_size,
        )

    def parse_batchnorm(self, layer: tf.keras.layers.BatchNormalization) -> str:
        """
        Parse a BatchNormalization layer.

        BatchNormalization carries no parameters in the VGSL spec, so the
        literal token ``"Bn"`` is returned directly.

        Parameters
        ----------
        layer : tf.keras.layers.BatchNormalization
            The BatchNormalization layer to parse (unused).

        Returns
        -------
        str
            The VGSL token ``"Bn"``.
        """
        return "Bn"

    def parse_dropout(self, layer: tf.keras.layers.Dropout) -> DropoutConfig:
        """
        Parse a Dropout layer into a DropoutConfig dataclass.

        Parameters
        ----------
        layer : tf.keras.layers.Dropout
            The Dropout layer to parse.

        Returns
        -------
        DropoutConfig
            The configuration for the Dropout layer.
        """
        return DropoutConfig(rate=layer.rate)

    def parse_flatten(self, layer: tf.keras.layers.Flatten) -> str:
        """
        Parse a Flatten layer.

        Flatten carries no parameters in the VGSL spec, so the literal token
        ``"Flt"`` is returned directly.

        Parameters
        ----------
        layer : tf.keras.layers.Flatten
            The Flatten layer to parse (unused).

        Returns
        -------
        str
            The VGSL token ``"Flt"``.
        """
        return "Flt"

    def parse_reshape(self, layer: tf.keras.layers.Reshape) -> ReshapeConfig:
        """
        Parse a Reshape layer into a ReshapeConfig dataclass.

        Parameters
        ----------
        layer : tf.keras.layers.Reshape
            The Reshape layer to parse.

        Returns
        -------
        ReshapeConfig
            The configuration for the Reshape layer.
        """
        return ReshapeConfig(target_shape=layer.target_shape)

    def parse_activation(self, layer: tf.keras.layers.Activation) -> ActivationConfig:
        """
        Parse an Activation layer.

        Parameters
        ----------
        layer : tf.keras.layers.Activation
            The Activation layer to parse.

        Returns
        -------
        ActivationConfig
            The configuration for the Activation layer.
        """
        activation = self._extract_activation(layer)
        return ActivationConfig(activation=activation)

    # Helper methods

    def _extract_activation(self, layer: tf.keras.layers.Layer) -> str:
        """
        Extract the activation function name from a TensorFlow Keras layer.

        Parameters
        ----------
        layer : tf.keras.layers.Layer
            The layer from which to extract the activation.

        Returns
        -------
        str
            The ``__name__`` of the layer's callable ``activation`` attribute
            (this also covers Activation layers), or ``"linear"`` when the layer
            has no callable ``activation`` attribute.
        """
        activation = getattr(layer, "activation", None)
        if callable(activation):
            return activation.__name__
        return "linear"
def register_custom_parser(layer_cls: Type[tf.keras.layers.Layer]) -> Callable:
    """
    Build a decorator that registers a VGSL parser for a Keras layer class.

    The decorated function is handed to ``TensorFlowModelParser.register`` and
    then returned unchanged, so it stays directly callable. Parsers registered
    this way are picked up by every ``TensorFlowModelParser`` constructed
    afterwards.

    Parameters
    ----------
    layer_cls : Type[tf.keras.layers.Layer]
        The TensorFlow Keras layer class the decorated function should handle.

    Returns
    -------
    Callable
        A decorator that takes a ``parser(layer) -> str`` function, registers
        it for ``layer_cls`` and returns it.

    Raises
    ------
    ValueError
        Raised when the decorator is applied, if a parser for ``layer_cls`` is
        already registered or if the function does not take exactly one
        parameter (the layer instance).

    Examples
    --------
    Registering a custom parser for a `MyCustomLayer`:

    >>> from vgslify.model_parsers.tensorflow import register_custom_parser
    >>> import tensorflow as tf
    >>> class MyCustomLayer(tf.keras.layers.Layer):
    ...     def __init__(self, units: int):
    ...         super().__init__()
    ...         self.units = units
    ...
    >>> @register_custom_parser(MyCustomLayer)
    ... def parse_my_custom_layer(layer: MyCustomLayer):
    ...     return f"MyCustomSpec({layer.units})"
    ...
    >>> # Now the parser is automatically registered inside TensorFlowModelParser
    """

    def _register(fn: Callable) -> Callable:
        # Validation and storage live on the class; the decorator only forwards.
        TensorFlowModelParser.register(layer_cls, fn)
        return fn

    return _register