from abc import ABC, abstractmethod
from typing import List, Union
from vgslify.core.config import (
ActivationConfig,
Conv2DConfig,
DenseConfig,
DropoutConfig,
InputConfig,
Pooling2DConfig,
ReshapeConfig,
RNNConfig,
)
[docs]
class BaseModelParser(ABC):
"""
Abstract base class for model parsers.
Provides common utility methods for parsing different frameworks and generating VGSL spec strings.
"""
[docs]
def generate_vgsl(
self,
configs: List[
Union[
Conv2DConfig,
Pooling2DConfig,
DenseConfig,
RNNConfig,
DropoutConfig,
ReshapeConfig,
InputConfig,
ActivationConfig,
]
],
) -> str:
"""
Convert a list of layer configuration dataclasses into a VGSL specification string.
Parameters
----------
configs : List[Union[Conv2DConfig, Pooling2DConfig, DenseConfig, RNNConfig,
DropoutConfig, ReshapeConfig, InputConfig, ActivationConfig]]
List of layer configurations.
Returns
-------
str
VGSL specification string.
"""
vgsl_parts = []
i = len(configs) - 1 # Start from the end of the list to merge activations
while i >= 0:
config = configs[i]
if isinstance(config, ActivationConfig):
# Check if there is a preceding layer to merge with
if i > 0:
preceding_config = configs[i - 1]
if (
isinstance(
preceding_config, (Conv2DConfig, DenseConfig, RNNConfig)
)
and preceding_config.activation == "linear"
):
# Merge the activation into the preceding layer
preceding_config.activation = config.activation
# Skip adding this ActivationConfig
i -= 1
continue
# If cannot merge, add the activation spec
vgsl_parts.append(self._vgsl_activation(config))
# Handle non-activation layers and strings
elif isinstance(config, InputConfig):
vgsl_parts.append(self._vgsl_input(config))
elif isinstance(config, Conv2DConfig):
vgsl_parts.append(self._vgsl_conv2d(config))
elif isinstance(config, Pooling2DConfig):
vgsl_parts.append(self._vgsl_pooling2d(config))
elif isinstance(config, DenseConfig):
vgsl_parts.append(self._vgsl_dense(config))
elif isinstance(config, RNNConfig):
vgsl_parts.append(self._vgsl_rnn(config))
elif isinstance(config, DropoutConfig):
vgsl_parts.append(self._vgsl_dropout(config))
elif isinstance(config, ReshapeConfig):
vgsl_parts.append(self._vgsl_reshape(config))
elif isinstance(config, str):
vgsl_parts.append(config)
else:
raise ValueError(
f"Unsupported configuration type: {type(config).__name__}"
)
i -= 1 # Move to the previous config
# Reverse to restore the original order
return " ".join(vgsl_parts[::-1])
[docs]
@abstractmethod
def parse_model(self, model) -> str:
"""Parse the model into a VGSL spec string."""
pass
[docs]
@abstractmethod
def parse_conv2d(self, layer) -> Conv2DConfig:
"""Parse the Conv2D layer into a Conv2DConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_dense(self, layer) -> DenseConfig:
"""Parse the Dense layer into a DenseConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_rnn(self, layer) -> RNNConfig:
"""Parse the RNN layer into a RNNConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_pooling(self, layer) -> Pooling2DConfig:
"""Parse the Pooling layer into a Pooling2DConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_batchnorm(self, layer) -> str:
"""Parse the BatchNorm layer into a VGSL spec string."""
pass
[docs]
@abstractmethod
def parse_dropout(self, layer) -> DropoutConfig:
"""Parse the Dropout layer into a DropoutConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_flatten(self, layer) -> str:
"""Parse the Flatten layer into a VGSL spec string."""
pass
[docs]
@abstractmethod
def parse_reshape(self, layer) -> ReshapeConfig:
"""Parse the Reshape layer into a ReshapeConfig dataclass."""
pass
[docs]
@abstractmethod
def parse_activation(self, layer) -> ActivationConfig:
"""Parse the Activation layer into a ActivationConfig dataclass."""
pass
# VGSL Generation Methods
def _vgsl_input(self, config: InputConfig) -> str:
"""
Generate VGSL string for input layer.
Parameters
----------
config : InputConfig
Configuration for the input layer.
Returns
-------
str
VGSL string representation of the input layer.
"""
return ",".join(
map(
str,
filter(
lambda x: x != -1,
[
config.batch_size,
config.depth,
config.height,
config.width,
config.channels,
],
),
)
)
def _vgsl_conv2d(self, config: Conv2DConfig) -> str:
"""
Generate VGSL string for Conv2D layer.
Parameters
----------
config : Conv2DConfig
Configuration for the Conv2D layer.
Returns
-------
str
VGSL string representation of the Conv2D layer.
"""
act = self._get_activation_code(config.activation)
stride_spec = (
",".join(map(str, config.strides)) if config.strides != (1, 1) else ""
)
stride_str = f",{stride_spec}" if stride_spec else ""
return f"C{act}{config.kernel_size[0]},{config.kernel_size[1]}{stride_str},{config.filters}"
def _vgsl_pooling2d(self, config: Pooling2DConfig) -> str:
"""
Generate VGSL string for Pooling2D layer.
Parameters
----------
config : Pooling2DConfig
Configuration for the Pooling2D layer.
Returns
-------
str
VGSL string representation of the Pooling2D layer.
"""
pool_type_code = "Mp" if config.pool_type.lower() == "max" else "Ap"
pool_size_str = ",".join(map(str, config.pool_size))
strides_str = (
",".join(map(str, config.strides))
if config.strides != config.pool_size
else ""
)
return (
f"{pool_type_code}{pool_size_str}{',' + strides_str if strides_str else ''}"
)
def _vgsl_dense(self, config: DenseConfig) -> str:
"""
Generate VGSL string for Dense layer.
Parameters
----------
config : DenseConfig
Configuration for the Dense layer.
Returns
-------
str
VGSL string representation of the Dense layer.
"""
act = self._get_activation_code(config.activation)
return f"F{act}{config.units}"
def _vgsl_rnn(self, config: RNNConfig) -> str:
"""
Generate VGSL string for RNN layer.
Parameters
----------
config : RNNConfig
Configuration for the RNN layer.
Returns
-------
str
VGSL string representation of the RNN layer.
Raises
------
ValueError
If an unsupported RNN type is provided.
"""
if config.bidirectional:
layer_type = "B"
rnn_type = "l" if config.rnn_type.lower() == "lstm" else "g"
else:
if config.rnn_type.lower() == "lstm":
layer_type = "L"
elif config.rnn_type.lower() == "gru":
layer_type = "G"
else:
raise ValueError(f"Unsupported RNN type: {config.rnn_type}")
rnn_type = "r" if config.go_backwards else "f"
return_sequences = (
"s" if config.return_sequences and not config.bidirectional else ""
)
spec = f"{layer_type}{rnn_type}{return_sequences}{config.units}"
if config.dropout > 0:
spec += f",D{int(config.dropout * 100)}"
if config.recurrent_dropout > 0:
spec += f",Rd{int(config.recurrent_dropout * 100)}"
return spec
def _vgsl_dropout(self, config: DropoutConfig) -> str:
"""
Generate VGSL string for Dropout layer.
Parameters
----------
config : DropoutConfig
Configuration for the Dropout layer.
Returns
-------
str
VGSL string representation of the Dropout layer.
"""
return f"D{int(config.rate * 100)}"
def _vgsl_reshape(self, config: ReshapeConfig) -> str:
"""
Generate VGSL string for Reshape layer.
Parameters
----------
config : ReshapeConfig
Configuration for the Reshape layer.
Returns
-------
str
VGSL string representation of the Reshape layer.
"""
if len(config.target_shape) == 2 and (
None in config.target_shape or -1 in config.target_shape
):
return "Rc3"
else:
reshape_dims = ",".join(
map(lambda x: str(x) if x is not None else "-1", config.target_shape)
)
return f"R{reshape_dims}"
def _vgsl_activation(self, config: ActivationConfig) -> str:
"""
Generate VGSL string for Activation layer.
Parameters
----------
config : ActivationConfig
Configuration for the Activation layer.
Returns
-------
str
VGSL string representation of the Activation layer.
"""
act = self._get_activation_code(config.activation)
return f"A{act}"
def _get_activation_code(self, activation: str) -> str:
"""
Get the VGSL activation code for a given activation function.
Parameters
----------
activation : str
Name of the activation function.
Returns
-------
str
VGSL activation code.
Raises
------
ValueError
If an unsupported activation function is provided.
"""
ACTIVATION_MAP = {
"softmax": "s",
"tanh": "t",
"relu": "r",
"linear": "l",
"sigmoid": "m",
"identity": "l",
}
act_code = ACTIVATION_MAP.get(activation.lower(), None)
if act_code is None:
raise ValueError(f"Unsupported activation '{activation}'.")
return act_code