Fits a Generalized Additive Model where smooth terms are modeled by Keras neural networks.
In addition to point predictions, the model can optionally estimate uncertainty bands via Monte Carlo Dropout across forward passes.
neuralGAM(
formula,
data,
family = "gaussian",
num_units = 64,
learning_rate = 0.001,
activation = "relu",
kernel_initializer = "glorot_normal",
kernel_regularizer = NULL,
bias_regularizer = NULL,
bias_initializer = "zeros",
activity_regularizer = NULL,
loss = "mse",
uncertainty_method = c("none", "epistemic"),
alpha = 0.05,
forward_passes = 100,
dropout_rate = 0.1,
validation_split = NULL,
w_train = NULL,
bf_threshold = 0.001,
ls_threshold = 0.1,
max_iter_backfitting = 10,
max_iter_ls = 10,
seed = NULL,
verbose = 1,
...
)
Model formula. Smooth terms must be wrapped in s(...). You can specify per-term NN settings, e.g.:
y ~ s(x1, num_units = 1024) + s(x3, num_units = c(1024, 512))
Data frame containing the variables.
Response distribution: "gaussian", "binomial", or "poisson".
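The same call structure applies to non-Gaussian responses; a minimal sketch, assuming a hypothetical data frame df with a binary response y and covariates x1, x2:

# Hypothetical binary-response fit; df, y, x1 and x2 are placeholder names
fit_bin <- neuralGAM(
  y ~ s(x1) + s(x2),
  data = df,
  family = "binomial",
  num_units = 64
)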
Default hidden layer sizes for smooth terms (integer or vector). Mandatory unless every s(...) term specifies its own num_units.
Learning rate for Adam optimizer.
Activation function for hidden layers. Either a string understood by tf$keras$activations$get() or an R function.
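A sketch of passing an R function instead of a string, assuming the function is forwarded unchanged to the Keras layers; leaky is a hypothetical name, and train comes from sim_neuralGAM_data() as in the examples below:

library(tensorflow)
# Hypothetical custom activation built from tf.keras's relu (alpha adds a leaky slope)
leaky <- function(x) tf$keras$activations$relu(x, alpha = 0.1)
fit_act <- neuralGAM(y ~ s(x1), data = train, num_units = 64, activation = leaky)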
Initializers for weights and biases.
Optional Keras regularizers.
Loss function to use. Can be any Keras built-in (e.g., "mse", "mae", "huber", "logcosh") or a custom function, passed directly to keras::compile().
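For instance, a robust built-in loss can be selected by name; a minimal sketch, reusing the simulated training data from the examples below:

# "huber" is a Keras built-in loss name, forwarded as-is to keras::compile()
fit_huber <- neuralGAM(
  y ~ s(x1),
  data = train,
  num_units = 64,
  loss = "huber"
)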
Character string indicating the type of uncertainty to estimate. One of:
"none" (default): no uncertainty estimation.
"epistemic": MC Dropout for mean uncertainty (confidence intervals).
Significance level for prediction intervals, e.g. 0.05 for 95% coverage.
Integer. Number of MC Dropout forward passes used when uncertainty_method is "epistemic".
Dropout probability in smooth-term NNs, in (0, 1). During training it acts as a regularizer; during prediction (if uncertainty_method is "epistemic") it enables MC Dropout sampling.
Optional fraction of training data used for validation.
Optional training weights.
Convergence criterion of the backfitting algorithm. Defaults to 0.001.
Convergence criterion of the local scoring algorithm. Defaults to 0.1.
An integer with the maximum number of iterations of the backfitting algorithm. Defaults to 10.
An integer with the maximum number of iterations of the local scoring algorithm. Defaults to 10.
Random seed.
Verbosity: 0 = silent, 1 = progress messages.
Additional arguments passed to keras::optimizer_adam().
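As an illustration, Adam's moment-decay parameters can be tuned through ... (they are forwarded to the optimizer); a sketch using the simulated training data from the examples below:

# beta_1 / beta_2 are arguments of keras::optimizer_adam(), passed via `...`
fit_adam <- neuralGAM(
  y ~ s(x1),
  data = train,
  num_units = 64,
  learning_rate = 0.01,
  beta_1 = 0.95,
  beta_2 = 0.999
)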
An object of class "neuralGAM", a list with elements including:
Numeric vector of fitted mean predictions (training data).
Data frame of partial contributions \(g_j(x_j)\) per smooth term.
Observed response values.
Linear predictor \(\eta = \eta_0 + \sum_j g_j(x_j)\).
Lower/upper confidence interval bounds (response scale).
Training covariates (inputs).
List of fitted Keras models, one per smooth term (plus "linear" if present).
Intercept estimate \(\eta_0\).
Model family.
Data frame of training/validation losses per backfitting iteration.
Training mean squared error.
Parsed model formula (via get_formula_elements()).
List of Keras training histories per term.
Global hyperparameter defaults.
PI significance level (if trained with uncertainty).
Logical; whether the model was trained with uncertainty estimation enabled.
Type of predictive uncertainty used ("none", "epistemic").
Matrix of per-term epistemic variances (if computed).
# \dontrun{
library(neuralGAM)
dat <- sim_neuralGAM_data()
train <- dat$train
test <- dat$test
# Per-term architecture and confidence intervals
ngam <- neuralGAM(
y ~ s(x1, num_units = c(128,64), activation = "tanh") +
s(x2, num_units = 256),
data = train,
uncertainty_method = "epistemic",
forward_passes = 10,
alpha = 0.05
)
#> [1] "Initializing neuralGAM..."
ngam
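# A hedged follow-up sketch: assuming the package provides the usual S3 methods
# (summary() and predict() with a newdata argument), the fit can be inspected
# and applied to the held-out test set:
summary(ngam)
pred <- predict(ngam, newdata = test)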
# }