6. Predict environmental parameters (Random Forest)
Use Random Forest models to predict environmental parameters from community composition data.
Here we will use machine learning to predict environmental parameters from community composition data. If you find this interesting and want to learn more, check out the poster "The marine microbiome can accurately predict its chemical and biological environment" by Emma Bell, Karin Garefelt, Krzysztof Jurdzinski et al. here at the SBDI days, or read the preprint.
Choose target variable
Pick an environmental parameter to predict. For example, salinity:
y <- salinity

Remove samples with missing values:
ok <- which(!is.na(y))
y <- y[ok]

Create a feature matrix
We use relative abundances of ASVs as features (norm_counts, created in section 4).
X <- t(norm_counts) # rows = samples, cols = ASVs

Subset to the samples with environmental parameter values:
X <- X[ok, , drop = FALSE]

Optionally filter rare ASVs (present in <10% of samples) to reduce dimensionality:
keep <- which(colSums(X > 0) / nrow(X) >= 0.1)
X <- X[, keep, drop = FALSE]

Train/test split
Split the data into training and test sets to evaluate model performance on unseen samples.
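A single random split is the simplest option; a k-fold cross-validation scheme, where every sample is held out exactly once, gives a more stable performance estimate. A minimal base-R sketch of assigning samples to folds (self-contained, with a made-up sample count):

```r
set.seed(1)
n_demo <- 100                                     # made-up number of samples
k <- 5
# Randomly assign each sample to one of k folds
folds <- sample(rep(seq_len(k), length.out = n_demo))
# Each fold then serves once as the test set, e.g. for fold 1:
test_ix_demo  <- which(folds == 1)
train_ix_demo <- which(folds != 1)
table(folds)                                      # 20 samples per fold
```

The model would be refit k times, once per held-out fold, and the evaluation metrics averaged across folds.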
set.seed(1)
n <- nrow(X)
train_ix <- sample(seq_len(n), size = round(0.8 * n))
test_ix <- setdiff(seq_len(n), train_ix)
X_train <- X[train_ix, , drop = FALSE]
y_train <- y[train_ix]
X_test <- X[test_ix, , drop = FALSE]
y_test <- y[test_ix]

Train Random Forest
We train a random forest model using the R package ranger; load it with library(ranger) if it is not already attached.
rf <- ranger(
x = X_train,
y = y_train,
num.trees = 5000,
importance = "permutation"
)

Predict and evaluate
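Before scoring on the held-out test set, ranger's out-of-bag (OOB) estimates give a quick internal check: for regression forests, the fitted object exposes prediction.error (the OOB mean squared error) and r.squared. A self-contained sketch on toy data (rf_demo stands in for the rf object above):

```r
library(ranger)
set.seed(1)
# Toy regression data standing in for X_train / y_train
demo <- data.frame(y = rnorm(60), x1 = rnorm(60), x2 = rnorm(60))
rf_demo <- ranger(y ~ ., data = demo, num.trees = 200)
rf_demo$prediction.error  # out-of-bag mean squared error
rf_demo$r.squared         # out-of-bag R^2 (can be negative for a poor fit)
```

OOB estimates use only the trees that did not see a given sample during training, so they are computed without touching the test set.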
Assess model performance by predicting the environmental parameter for the test data.
pred <- predict(rf, data = X_test)$predictions
# Coefficient of determination
r2 <- 1 - sum((y_test - pred)^2) / sum((y_test - mean(y_test))^2)
r2
# Root Mean Square Error
rmse <- sqrt(mean((pred - y_test)^2, na.rm = TRUE))
rmse

Plot observed vs predicted
Visualize how well predicted values match observed values in the test set.
par(mfrow = c(1,1), mar = c(5,5,2,2), xpd = FALSE)
lims <- range(c(y_test, pred))
plot(y_test, pred,
xlab = "Observed",
ylab = "Predicted",
xlim = lims,
ylim = lims)
abline(0, 1, lty = 2, col = "grey50")

Variable importance (optional)
Identify which ASVs contribute most to the model’s predictions.
imp <- sort(rf$variable.importance, decreasing = TRUE)
plot(imp)

To get the taxonomy of the 10 ASVs with highest importance for the model:
merged_df$asvs[names(imp[1:10]), 7:10]
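To keep the importance scores next to the taxonomy, the two can be combined into a single data frame. A self-contained sketch with mock objects (imp_demo and anno_demo stand in for imp and the annotation table above; the taxonomy column names are made up):

```r
# Mock importance vector and annotation table, indexed by ASV name
imp_demo  <- c(ASV_1 = 0.9, ASV_2 = 0.7, ASV_3 = 0.4)
anno_demo <- data.frame(Family = c("f1", "f2", "f3"),
                        Genus  = c("g1", "g2", "g3"),
                        row.names = c("ASV_1", "ASV_2", "ASV_3"))
# Look up the annotations of the most important ASVs and attach the scores
top <- data.frame(importance = imp_demo,
                  anno_demo[names(imp_demo), , drop = FALSE])
top
```

Because the annotation table is indexed by row name, subsetting with names(imp_demo) keeps the rows aligned with the importance scores.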