## ----setup, message = FALSE, warning = FALSE, comment = NA--------------------
# Global knitr chunk options for this document: suppress messages and
# warnings, print results without a comment prefix (comment = NA), and set
# the default figure width/height.
knitr::opts_chunk$set(
  message = FALSE,
  warning = FALSE,
  comment = NA,
  fig.width = 6.25,
  fig.height = 5
)
library(ANCOMBC)
library(tidyverse)

## ----helper-------------------------------------------------------------------
# Keep only the strict upper triangle of a matrix.
#
# Every entry on or below the main diagonal is replaced with NA, so that when
# the matrix is reshaped to long format (and NA rows dropped) each unordered
# pair of variables appears exactly once and self-pairs are omitted.
#
# @param cormat A matrix; in this script, a taxa-by-taxa correlation matrix.
# @return `cormat` (same type and dimnames) with `cormat[i, j]` set to NA
#   wherever i >= j.
get_upper_tri <- function(cormat) {
  cormat[!upper.tri(cormat)] <- NA
  cormat
}

## ----getPackage, eval=FALSE---------------------------------------------------
# if (!requireNamespace("BiocManager", quietly = TRUE))
#     install.packages("BiocManager")
# BiocManager::install("ANCOMBC")

## ----load, eval=FALSE---------------------------------------------------------
# library(ANCOMBC)

## -----------------------------------------------------------------------------
data(atlas1006, package = "microbiome")

# Keep only the baseline samples (time == 0)
pseq <- phyloseq::subset_samples(atlas1006, time == 0)

# Collapse the obese, severeobese and morbidobese categories into "obese"
meta_data <- microbiome::meta(pseq)
meta_data$bmi <- recode(meta_data$bmi_group,
                        obese = "obese",
                        severeobese = "obese",
                        morbidobese = "obese")

# For character data, factor() sorts levels alphabetically by default, which
# would make `lean` the reference level among lean/overweight/obese. Listing
# the levels explicitly makes `obese` the reference level instead; any other
# category becomes NA and is removed by the BMI subset further below.
# Check the result with: levels(meta_data$bmi)
meta_data$bmi <- factor(meta_data$bmi,
                        levels = c("obese", "overweight", "lean"))

# Map nationality onto regions; a missing nationality becomes "unknown"
meta_data$region <- recode(as.character(meta_data$nationality),
                           Scandinavia = "NE",
                           UKIE = "NE",
                           SouthEurope = "SE",
                           CentralEurope = "CE",
                           EasternEurope = "EE",
                           .missing = "unknown")

phyloseq::sample_data(pseq) <- meta_data

# Retain lean, overweight and obese subjects only
pseq <- phyloseq::subset_samples(pseq,
                                 bmi %in% c("lean", "overweight", "obese"))

# Drop region "EE" (it contains a single subject) and subjects whose region
# is unknown
pseq <- phyloseq::subset_samples(pseq, !(region %in% c("EE", "unknown")))

print(pseq)

## -----------------------------------------------------------------------------
set.seed(123)

# Linear relationships: results provide `corr_th`, `corr_fl` and
# `mat_cooccur`, which are plotted below
res_linear <- secom_linear(
  data = list(pseq),
  taxa_are_rows = TRUE,
  tax_level = "Phylum",
  aggregate_data = NULL,
  meta_data = NULL,
  pseudo = 0,
  prv_cut = 0.5,
  lib_cut = 1000,
  corr_cut = 0.5,
  wins_quant = c(0.05, 0.95),
  method = "pearson",
  soft = FALSE,
  thresh_len = 20,
  n_cv = 10,
  thresh_hard = 0.3,
  max_p = 0.005,
  n_cl = 2
)

# Nonlinear relationships: results provide `dcorr_fl` and `mat_cooccur`
res_dist <- secom_dist(
  data = list(pseq),
  taxa_are_rows = TRUE,
  tax_level = "Phylum",
  aggregate_data = NULL,
  meta_data = NULL,
  pseudo = 0,
  prv_cut = 0.5,
  lib_cut = 1000,
  corr_cut = 0.5,
  wins_quant = c(0.05, 0.95),
  R = 1000,
  thresh_hard = 0.3,
  max_p = 0.005,
  n_cl = 2
)

## ----eval=FALSE---------------------------------------------------------------
# set.seed(123)
# # Linear relationships
# res_linear2 = secom_linear(data = list(pseq), taxa_are_rows = TRUE,
#                            tax_level = "Phylum",
#                            aggregate_data = NULL, meta_data = NULL, pseudo = 0,
#                            prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                            wins_quant = c(0.05, 0.95), method = "pearson",
#                            soft = FALSE, alpha_grid = 0.1,
#                            thresh_len = 20, n_cv = 10,
#                            thresh_hard = 0.3, max_p = 0.005, n_cl = 2)

## -----------------------------------------------------------------------------
# Heatmap of the thresholded Pearson correlations (`corr_th`) estimated by
# secom_linear() on the single data set.
corr_linear <- res_linear$corr_th
cooccur_linear <- res_linear$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_linear[cooccur_linear < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) identical to the row
# names (var1); by default data.frame() rewrites non-syntactic names (e.g.
# containing spaces or "/"), which would list one taxon under two labels.
df_linear <- data.frame(get_upper_tri(corr_linear), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_linear$var1, df_linear$var2))
df_linear$var1 <- factor(df_linear$var1, levels = tax_name)
df_linear$var2 <- factor(df_linear$var2, levels = tax_name)

heat_linear_th <- df_linear %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Pearson (Thresholding)") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic"),
        axis.text.y = element_text(size = 12, face = "italic"),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_linear_th

## -----------------------------------------------------------------------------
# Heatmap of the filtering-based Pearson correlations (`corr_fl`) estimated
# by secom_linear() on the single data set.
corr_linear <- res_linear$corr_fl
cooccur_linear <- res_linear$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_linear[cooccur_linear < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) identical to the row
# names (var1); by default data.frame() rewrites non-syntactic names (e.g.
# containing spaces or "/"), which would list one taxon under two labels.
df_linear <- data.frame(get_upper_tri(corr_linear), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_linear$var1, df_linear$var2))
df_linear$var1 <- factor(df_linear$var1, levels = tax_name)
df_linear$var2 <- factor(df_linear$var2, levels = tax_name)

heat_linear_fl <- df_linear %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Pearson (Filtering)") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic"),
        axis.text.y = element_text(size = 12, face = "italic"),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_linear_fl

## -----------------------------------------------------------------------------
# Heatmap of the filtering-based distance correlations (`dcorr_fl`)
# estimated by secom_dist() on the single data set.
corr_dist <- res_dist$dcorr_fl
cooccur_dist <- res_dist$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_dist[cooccur_dist < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) identical to the row
# names (var1); by default data.frame() rewrites non-syntactic names (e.g.
# containing spaces or "/"), which would list one taxon under two labels.
df_dist <- data.frame(get_upper_tri(corr_dist), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_dist$var1, df_dist$var2))
df_dist$var1 <- factor(df_dist$var1, levels = tax_name)
df_dist$var2 <- factor(df_dist$var2, levels = tax_name)

heat_dist_fl <- df_dist %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Distance (Filtering)") +
  theme_bw() +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic"),
        axis.text.y = element_text(size = 12, face = "italic"),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_dist_fl

## ----eval=FALSE---------------------------------------------------------------
# tse = mia::makeTreeSummarizedExperimentFromPhyloseq(atlas1006)
# tse = tse[, tse$time == 0]
# tse$bmi = recode(tse$bmi_group,
#                  obese = "obese",
#                  severeobese = "obese",
#                  morbidobese = "obese")
# tse = tse[, tse$bmi %in% c("lean", "overweight", "obese")]
# tse$bmi = factor(tse$bmi, levels = c("obese", "overweight", "lean"))
# tse$region = recode(as.character(tse$nationality),
#                     Scandinavia = "NE", UKIE = "NE", SouthEurope = "SE",
#                     CentralEurope = "CE", EasternEurope = "EE",
#                     .missing = "unknown")
# tse = tse[, ! tse$region %in% c("EE", "unknown")]
# 
# set.seed(123)
# # Linear relationships
# res_linear = secom_linear(data = list(tse), taxa_are_rows = TRUE,
#                           assay_name = "counts", tax_level = "Phylum",
#                           aggregate_data = NULL, meta_data = NULL, pseudo = 0,
#                           prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                           wins_quant = c(0.05, 0.95), method = "pearson",
#                           soft = FALSE, thresh_len = 20, n_cv = 10,
#                           thresh_hard = 0.3, max_p = 0.005, n_cl = 2)
# 
# # Nonlinear relationships
# res_dist = secom_dist(data = list(tse), taxa_are_rows = TRUE,
#                       assay_name = "counts", tax_level = "Phylum",
#                       aggregate_data = NULL, meta_data = NULL, pseudo = 0,
#                       prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                       wins_quant = c(0.05, 0.95), R = 1000,
#                       thresh_hard = 0.3, max_p = 0.005, n_cl = 2)

## ----eval=FALSE---------------------------------------------------------------
# abundance_data = microbiome::abundances(pseq)
# aggregate_data = microbiome::abundances(microbiome::aggregate_taxa(pseq, "Phylum"))
# meta_data = microbiome::meta(pseq)
# 
# set.seed(123)
# # Linear relationships
# res_linear = secom_linear(data = list(abundance_data),
#                           taxa_are_rows = TRUE,
#                           aggregate_data = list(aggregate_data),
#                           meta_data = list(meta_data),
#                           pseudo = 0,
#                           prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                           wins_quant = c(0.05, 0.95), method = "pearson",
#                           soft = FALSE, thresh_len = 20, n_cv = 10,
#                           thresh_hard = 0.3, max_p = 0.005, n_cl = 2)
# 
# # Nonlinear relationships
# res_dist = secom_dist(data = list(abundance_data),
#                       taxa_are_rows = TRUE,
#                       aggregate_data = list(aggregate_data),
#                       meta_data = list(meta_data),
#                       pseudo = 0,
#                       prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                       wins_quant = c(0.05, 0.95), R = 1000,
#                       thresh_hard = 0.3, max_p = 0.005, n_cl = 2)

## -----------------------------------------------------------------------------
# Select subjects from "CE" and "NE"
# Relabel samples as "Sample-1", "Sample-2", ... in both data sets. This
# ensures there is an overlap of sample names between CE and NE for the
# joint analysis below.

# Data set 1: subjects from region "CE"
pseq1 <- phyloseq::subset_samples(pseq, region == "CE")
phyloseq::sample_names(pseq1) <- paste0(
  "Sample-", seq_len(phyloseq::nsamples(pseq1))
)

# Data set 2: subjects from region "NE"
pseq2 <- phyloseq::subset_samples(pseq, region == "NE")
phyloseq::sample_names(pseq2) <- paste0(
  "Sample-", seq_len(phyloseq::nsamples(pseq2))
)

print(pseq1)
print(pseq2)

## -----------------------------------------------------------------------------
set.seed(123)

# Linear relationships, estimated jointly on the CE and NE data sets. The
# list names ("CE", "NE") presumably become the data-set prefix of the taxon
# names in the results (the heatmaps below rely on that prefix).
res_linear <- secom_linear(
  data = list(CE = pseq1, NE = pseq2),
  taxa_are_rows = TRUE,
  tax_level = c("Phylum", "Phylum"),
  aggregate_data = NULL,
  meta_data = NULL,
  pseudo = 0,
  prv_cut = 0.5,
  lib_cut = 1000,
  corr_cut = 0.5,
  wins_quant = c(0.05, 0.95),
  method = "pearson",
  soft = FALSE,
  thresh_len = 20,
  n_cv = 10,
  thresh_hard = 0.3,
  max_p = 0.005,
  n_cl = 2
)

# Nonlinear relationships on the same two data sets
res_dist <- secom_dist(
  data = list(CE = pseq1, NE = pseq2),
  taxa_are_rows = TRUE,
  tax_level = c("Phylum", "Phylum"),
  aggregate_data = NULL,
  meta_data = NULL,
  pseudo = 0,
  prv_cut = 0.5,
  lib_cut = 1000,
  corr_cut = 0.5,
  wins_quant = c(0.05, 0.95),
  R = 1000,
  thresh_hard = 0.3,
  max_p = 0.005,
  n_cl = 2
)

## ----fig.width=8, fig.height=8------------------------------------------------
# Heatmap of the thresholded Pearson correlations (`corr_th`) from the joint
# CE + NE analysis.
corr_linear <- res_linear$corr_th
cooccur_linear <- res_linear$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_linear[cooccur_linear < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) exactly equal to the
# row names (var1), so names such as "CE - <taxon>" need no regex repair;
# a pattern like "\\..." would also corrupt taxon names containing dots.
df_linear <- data.frame(get_upper_tri(corr_linear), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_linear$var1, df_linear$var2))
df_linear$var1 <- factor(df_linear$var1, levels = tax_name)
df_linear$var2 <- factor(df_linear$var2, levels = tax_name)

# Assumes every taxon name starts with its data-set name ("CE" or "NE"), so
# after sorting all CE taxa come first. Colour the labels by data set and put
# the CE/NE divider after the last CE taxon instead of at a fixed position.
# The per-label colours line up with `tax_name` because drop = FALSE shows
# every level in order. Vectorised element_text() colour is not officially
# supported by ggplot2.
is_ce <- startsWith(tax_name, "CE")
txt_color <- ifelse(is_ce, "#1B9E77", "#D95F02")
n_ce <- sum(is_ce)

heat_linear_th <- df_linear %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Pearson (Thresholding)") +
  theme_bw() +
  geom_vline(xintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  geom_hline(yintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic",
                                   color = txt_color),
        axis.text.y = element_text(size = 12, face = "italic",
                                   color = txt_color),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_linear_th

## ----fig.width=8, fig.height=8------------------------------------------------
# Heatmap of the filtering-based Pearson correlations (`corr_fl`) from the
# joint CE + NE analysis. The plot is titled "Pearson (Filtering)", so it
# must read `corr_fl`, not the thresholded `corr_th`.
corr_linear <- res_linear$corr_fl
cooccur_linear <- res_linear$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_linear[cooccur_linear < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) exactly equal to the
# row names (var1), so names such as "CE - <taxon>" need no regex repair;
# a pattern like "\\..." would also corrupt taxon names containing dots.
df_linear <- data.frame(get_upper_tri(corr_linear), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_linear$var1, df_linear$var2))
df_linear$var1 <- factor(df_linear$var1, levels = tax_name)
df_linear$var2 <- factor(df_linear$var2, levels = tax_name)

# Assumes every taxon name starts with its data-set name ("CE" or "NE"), so
# after sorting all CE taxa come first. Colour the labels by data set and put
# the CE/NE divider after the last CE taxon instead of at a fixed position.
# The per-label colours line up with `tax_name` because drop = FALSE shows
# every level in order. Vectorised element_text() colour is not officially
# supported by ggplot2.
is_ce <- startsWith(tax_name, "CE")
txt_color <- ifelse(is_ce, "#1B9E77", "#D95F02")
n_ce <- sum(is_ce)

heat_linear_fl <- df_linear %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Pearson (Filtering)") +
  theme_bw() +
  geom_vline(xintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  geom_hline(yintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic",
                                   color = txt_color),
        axis.text.y = element_text(size = 12, face = "italic",
                                   color = txt_color),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_linear_fl

## ----fig.width=8, fig.height=8------------------------------------------------
# Heatmap of the filtering-based distance correlations (`dcorr_fl`) from the
# joint CE + NE analysis.
corr_dist <- res_dist$dcorr_fl
cooccur_dist <- res_dist$mat_cooccur

# Filter by co-occurrence: pairs whose `mat_cooccur` value is below
# `overlap` are displayed as 0
overlap <- 10
corr_dist[cooccur_dist < overlap] <- 0

# Long format, one row per pair of distinct taxa (upper triangle only).
# `check.names = FALSE` keeps the column names (var2) exactly equal to the
# row names (var1), so names such as "CE - <taxon>" need no regex repair;
# a pattern like "\\..." would also corrupt taxon names containing dots.
df_dist <- data.frame(get_upper_tri(corr_dist), check.names = FALSE) %>%
  rownames_to_column("var1") %>%
  pivot_longer(cols = -var1, names_to = "var2", values_to = "value") %>%
  filter(!is.na(value)) %>%
  mutate(value = round(value, 2))

# Use the same alphabetical taxon order on both axes
tax_name <- sort(union(df_dist$var1, df_dist$var2))
df_dist$var1 <- factor(df_dist$var1, levels = tax_name)
df_dist$var2 <- factor(df_dist$var2, levels = tax_name)

# Assumes every taxon name starts with its data-set name ("CE" or "NE"), so
# after sorting all CE taxa come first. Colour the labels by data set and put
# the CE/NE divider after the last CE taxon instead of at a fixed position.
# The per-label colours line up with `tax_name` because drop = FALSE shows
# every level in order. Vectorised element_text() colour is not officially
# supported by ggplot2.
is_ce <- startsWith(tax_name, "CE")
txt_color <- ifelse(is_ce, "#1B9E77", "#D95F02")
n_ce <- sum(is_ce)

heat_dist_fl <- df_dist %>%
  ggplot(aes(var2, var1, fill = value)) +
  geom_tile(color = "black") +
  # `limits` is spelled out rather than relying on partial argument matching
  scale_fill_gradient2(low = "blue", high = "red", mid = "white",
                       na.value = "grey", midpoint = 0, limits = c(-1, 1),
                       space = "Lab", name = NULL) +
  scale_x_discrete(drop = FALSE) +
  scale_y_discrete(drop = FALSE) +
  geom_text(aes(var2, var1, label = value), color = "black", size = 4) +
  labs(x = NULL, y = NULL, title = "Distance (Filtering)") +
  theme_bw() +
  geom_vline(xintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  geom_hline(yintercept = n_ce + 0.5, color = "blue", linetype = "dashed") +
  theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12,
                                   hjust = 1, face = "italic",
                                   color = txt_color),
        axis.text.y = element_text(size = 12, face = "italic",
                                   color = txt_color),
        strip.text.x = element_text(size = 14),
        strip.text.y = element_text(size = 14),
        legend.text = element_text(size = 12),
        plot.title = element_text(hjust = 0.5, size = 15),
        panel.grid.major = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "none") +
  coord_fixed()

heat_dist_fl

## ----eval=FALSE---------------------------------------------------------------
# # Select subjects from "CE" and "NE"
# tse1 = tse[, tse$region == "CE"]
# tse2 = tse[, tse$region == "NE"]
# 
# # Rename samples to ensure there is an overlap of samples between CE and NE
# colnames(tse1) = paste0("Sample-", seq_len(ncol(tse1)))
# colnames(tse2) = paste0("Sample-", seq_len(ncol(tse2)))
# 
# set.seed(123)
# # Linear relationships
# res_linear = secom_linear(data = list(CE = tse1, NE = tse2),
#                           taxa_are_rows = TRUE,
#                           assay_name = c("counts", "counts"),
#                           tax_level = c("Phylum", "Phylum"),
#                           aggregate_data = NULL, meta_data = NULL, pseudo = 0,
#                           prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                           wins_quant = c(0.05, 0.95), method = "pearson",
#                           soft = FALSE, thresh_len = 20, n_cv = 10,
#                           thresh_hard = 0.3, max_p = 0.005, n_cl = 2)
# 
# # Nonlinear relationships
# res_dist = secom_dist(data = list(CE = tse1, NE = tse2),
#                       taxa_are_rows = TRUE,
#                       assay_name = c("counts", "counts"),
#                       tax_level = c("Phylum", "Phylum"),
#                       aggregate_data = NULL, meta_data = NULL, pseudo = 0,
#                       prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                       wins_quant = c(0.05, 0.95), R = 1000,
#                       thresh_hard = 0.3, max_p = 0.005, n_cl = 2)

## ----eval=FALSE---------------------------------------------------------------
# ce_idx = which(meta_data$region == "CE")
# ne_idx = which(meta_data$region == "NE")
# 
# abundance_data1 = abundance_data[, ce_idx]
# abundance_data2 = abundance_data[, ne_idx]
# aggregate_data1 = aggregate_data[, ce_idx]
# aggregate_data2 = aggregate_data[, ne_idx]
# meta_data1 = meta_data[ce_idx, ]
# meta_data2 = meta_data[ne_idx, ]
# 
# sample_size1 = ncol(abundance_data1)
# sample_size2 = ncol(abundance_data2)
# colnames(abundance_data1) = paste0("Sample-", seq_len(sample_size1))
# colnames(abundance_data2) = paste0("Sample-", seq_len(sample_size2))
# rownames(meta_data1) = paste0("Sample-", seq_len(sample_size1))
# rownames(meta_data2) = paste0("Sample-", seq_len(sample_size2))
# 
# set.seed(123)
# # Linear relationships
# res_linear = secom_linear(data = list(CE = abundance_data1, NE = abundance_data2),
#                           taxa_are_rows = TRUE,
#                           aggregate_data = list(aggregate_data1, aggregate_data2),
#                           meta_data = list(meta_data1, meta_data2),
#                           pseudo = 0,
#                           prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                           wins_quant = c(0.05, 0.95), method = "pearson",
#                           soft = FALSE, thresh_len = 20, n_cv = 10,
#                           thresh_hard = 0.3, max_p = 0.005, n_cl = 2)
# 
# # Nonlinear relationships
# res_dist = secom_dist(data = list(CE = abundance_data1, NE = abundance_data2),
#                       taxa_are_rows = TRUE,
#                       aggregate_data = list(aggregate_data1, aggregate_data2),
#                       meta_data = list(meta_data1, meta_data2),
#                       pseudo = 0,
#                       prv_cut = 0.5, lib_cut = 1000, corr_cut = 0.5,
#                       wins_quant = c(0.05, 0.95), R = 1000,
#                       thresh_hard = 0.3, max_p = 0.005, n_cl = 2)

## ----sessionInfo, message = FALSE, warning = FALSE, comment = NA--------------
# Record the R version and attached packages used to produce this document
utils::sessionInfo()