6. Predict environmental parameters (Random Forest)

Use Random Forest models to predict environmental parameters from community composition data.

Here we use machine learning to predict environmental parameters from community composition data. If you want to learn more, check out the poster "The marine microbiome can accurately predict its chemical and biological environment" by Emma Bell, Karin Garefelt, Krzysztof Jurdzinski et al. here at the SBDI days, or read the preprint.

Choose target variable

Pick an environmental parameter to predict. For example, salinity:

y <- salinity  # one value per sample, in the same order as the columns of norm_counts

Remove samples with missing values:

ok <- which(!is.na(y))
y <- y[ok]

Create a feature matrix

We use the relative abundances of ASVs as features (norm_counts, created in section 4).

X <- t(norm_counts) # rows = samples, cols = ASVs

Subset to the samples with environmental parameter values:

X <- X[ok, , drop = FALSE]
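The indices in ok are reused to subset the feature matrix by position, so the subsetting only works if y and the columns of norm_counts list the samples in the same order. Below is a minimal sanity check on toy data. It assumes both objects carry sample names. The toy_* objects are hypothetical stand-ins, named so that running the sketch does not overwrite your own data.

```r
# Hypothetical toy data: 3 ASVs (rows) x 4 samples (columns), filled column-wise
toy_counts <- matrix(c(0.5, 0.3, 0.2,
                       0.1, 0.6, 0.3,
                       0.4, 0.4, 0.2,
                       0.7, 0.0, 0.3),
                     nrow = 3,
                     dimnames = list(paste0("ASV", 1:3), paste0("S", 1:4)))
# Hypothetical salinity values, one per sample, missing for S2
toy_salinity <- c(S1 = 7.1, S2 = NA, S3 = 5.4, S4 = 6.0)

# Target and samples must be in the same order before positional subsetting
stopifnot(identical(names(toy_salinity), colnames(toy_counts)))

toy_ok <- which(!is.na(toy_salinity))
toy_y <- toy_salinity[toy_ok]
toy_X <- t(toy_counts)[toy_ok, , drop = FALSE]

# After subsetting, the rows of the feature matrix still match the target
stopifnot(identical(rownames(toy_X), names(toy_y)))
toy_y
```

If your salinity vector has no names, you would need to check the order against the sample metadata some other way.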

Optionally, remove rare ASVs (present in fewer than 10% of samples) to reduce dimensionality:

keep <- which(colSums(X > 0) / nrow(X) >= 0.1)  # prevalence: fraction of samples where the ASV is present
X <- X[, keep, drop = FALSE]
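To see what the prevalence filter does, here is a toy example on a hypothetical 10-sample matrix. Note that the threshold is inclusive, so an ASV present in exactly 10% of samples is kept.

```r
# Hypothetical toy feature matrix: 10 samples (rows) x 3 ASVs (columns)
toy_X <- cbind(
  common = rep(0.2, 10),        # present in all 10 samples
  rare   = c(0.1, rep(0, 9)),   # present in 1 of 10 samples (10%)
  absent = rep(0, 10)           # present in no sample
)

# Fraction of samples in which each ASV has a non-zero abundance
toy_prevalence <- colSums(toy_X > 0) / nrow(toy_X)
toy_prevalence

toy_keep <- which(toy_prevalence >= 0.1)
toy_X_filt <- toy_X[, toy_keep, drop = FALSE]
colnames(toy_X_filt)  # "common" "rare"
```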

Train/test split

Split the data into training and test sets to evaluate model performance on unseen samples.

set.seed(1)
n <- nrow(X)
train_ix <- sample(seq_len(n), size = round(0.8 * n))
test_ix <- setdiff(seq_len(n), train_ix)

X_train <- X[train_ix, , drop = FALSE]
y_train <- y[train_ix]

X_test <- X[test_ix, , drop = FALSE]
y_test <- y[test_ix]
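With a small number of samples, the performance measured on a single 80/20 split can depend noticeably on the random seed. A common alternative is k-fold cross-validation, in which every sample is used for testing exactly once. The sketch below shows only the fold assignment. The sample size and k = 5 are illustrative assumptions, and in practice you would fit and evaluate one model per fold inside the loop.

```r
set.seed(1)
toy_n <- 23   # hypothetical number of samples
toy_k <- 5    # number of folds

# Assign each sample to one of k folds of (nearly) equal size, in random order
toy_fold <- sample(rep_len(seq_len(toy_k), toy_n))

for (f in seq_len(toy_k)) {
  test_ix_f  <- which(toy_fold == f)
  train_ix_f <- which(toy_fold != f)
  # fit the model on train_ix_f and predict on test_ix_f here
  stopifnot(length(intersect(train_ix_f, test_ix_f)) == 0)
}

table(toy_fold)  # fold sizes: 5 5 5 4 4
```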

Train Random Forest

We train a random forest model using the R package ranger.

library(ranger)

rf <- ranger(
  x = X_train,
  y = y_train,
  num.trees = 5000,            # number of trees in the forest
  importance = "permutation"   # required for the variable importance step
)
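Before looking at the test set, it can be useful to check the out-of-bag (OOB) estimates that ranger computes during training. For a regression forest, rf$prediction.error is the OOB mean squared error and rf$r.squared is the OOB R squared. Below is a minimal sketch on simulated data. The toy_* objects and the simulated relationship are hypothetical. Only feature f1 drives the response, so it should come out as the most important variable.

```r
library(ranger)

set.seed(1)
# Hypothetical toy data: 60 samples x 5 features; only f1 affects the response
toy_X <- matrix(runif(60 * 5), ncol = 5,
                dimnames = list(NULL, paste0("f", 1:5)))
toy_y <- 10 * toy_X[, "f1"] + rnorm(60, sd = 0.5)

toy_rf <- ranger(x = toy_X, y = toy_y, num.trees = 500,
                 importance = "permutation", seed = 1)

toy_rf$prediction.error   # OOB mean squared error
toy_rf$r.squared          # OOB R squared
names(which.max(toy_rf$variable.importance))  # expected to be "f1"
```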

Predict and evaluate

Assess model performance by predicting the environmental parameter for the held-out test samples and comparing the predictions with the observed values.

pred <- predict(rf, data = X_test)$predictions

# Coefficient of determination (R²) on the test set
r2 <- 1 - sum((y_test - pred)^2) / sum((y_test - mean(y_test))^2)
r2

# Root mean square error (RMSE), in the units of the target variable
rmse <- sqrt(mean((y_test - pred)^2))
rmse
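The two metrics can be wrapped in a small helper and checked on vectors where the answer is easy to verify by hand. The function name eval_regression is hypothetical. In the toy case, the squared residuals sum to 1 and the total sum of squares is 5, so R² = 1 - 1/5 = 0.8 and RMSE = sqrt(1/4) = 0.5.

```r
# Hypothetical helper: R² relative to the mean of the observed values, and RMSE
eval_regression <- function(obs, pred) {
  stopifnot(length(obs) == length(pred))
  ss_res <- sum((obs - pred)^2)           # residual sum of squares
  ss_tot <- sum((obs - mean(obs))^2)      # total sum of squares
  c(r2 = 1 - ss_res / ss_tot, rmse = sqrt(mean((obs - pred)^2)))
}

eval_regression(obs = c(1, 2, 3, 4), pred = c(1, 2, 3, 5))
# returns c(r2 = 0.8, rmse = 0.5)
```

R² computed this way is not the squared correlation between observed and predicted values. It becomes negative when the predictions are worse than simply predicting the mean of y_test.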

Plot observed vs predicted

Visualize how well predicted values match observed values in the test set.

par(mfrow = c(1,1), mar = c(5,5,2,2), xpd = FALSE)
lims <- range(c(y_test, pred))
plot(y_test, pred,
     xlab = "Observed",
     ylab = "Predicted",
     xlim = lims,
     ylim = lims)
abline(0, 1, lty = 2, col = "grey50")

Variable importance (optional)

Identify which ASVs contribute most to the model’s predictions.

imp <- sort(rf$variable.importance, decreasing = TRUE)
plot(imp, xlab = "ASV rank", ylab = "Permutation importance")
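Permutation importance measures how much the prediction error increases when the values of one feature are randomly shuffled, which breaks that feature's link to the target. Below is a simplified, test-set version of the idea. It is not how ranger computes it: ranger permutes each variable in every tree's out-of-bag samples. To keep the sketch self-contained, a linear model stands in for the forest, and all toy_* names and the simulated data are hypothetical.

```r
set.seed(1)
# Hypothetical toy data: only f1 is related to y
toy_df <- data.frame(f1 = runif(100), f2 = runif(100))
toy_df$y <- 5 * toy_df$f1 + rnorm(100, sd = 0.1)
toy_train <- toy_df[1:80, ]
toy_test  <- toy_df[81:100, ]

toy_fit <- lm(y ~ f1 + f2, data = toy_train)

# Mean squared error of a fitted model on a data set
toy_mse <- function(fit, newdata) mean((newdata$y - predict(fit, newdata))^2)
toy_base_mse <- toy_mse(toy_fit, toy_test)

toy_perm_imp <- sapply(c("f1", "f2"), function(v) {
  shuffled <- toy_test
  shuffled[[v]] <- sample(shuffled[[v]])   # break the link between v and y
  toy_mse(toy_fit, shuffled) - toy_base_mse  # increase in error
})
toy_perm_imp
```

Uninformative features get values close to zero and can be slightly negative by chance.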

To get the taxonomy of the 10 ASVs with the highest importance:

merged_df$asvs[names(imp[1:10]),7:10]
