Cross-validation for NMF rank determination
Four methods for cross-validation of non-negative matrix factorizations
Cross-Validation for NMF
Rank is the most important hyperparameter in NMF. Finding that “sweet spot” rank can make the difference between learning a useful model that captures meaningful signal (but not noise) and learning a garbage model that misses good signal or focuses too much on useless noise.
Alex Williams has posted a great introduction to cross-validation for NMF on his blog. His review of the first two methods is particularly intuitive. However, the third method he describes is theoretically questionable and performs poorly in practice.
There are three “unsupervised” cross-validation methods for NMF which I have found to be useful:
- Bi-cross-validation, proposed by Perry and explained simply by Williams. The “Bi-” in “Bi-cross-validation” means that the model is trained on a block of randomly selected samples and features and evaluated on a non-intersecting block of samples and features. Thus, no samples or features in the test set are included in the training set. If the test and training sets contain samples in common, or features in common, NMF gets to “cheat” in training and directly infer patterns of regulation, and thus basic subsample-cross-validation with NMF does not work.
- Imputation, described nicely by Lin and also reviewed in this StackExchange post by amoeba. Here, a small fraction of values (e.g. 5%) are “masked” and treated as missing during factorization, and the mean squared error of the imputed values is calculated after model training.
- Robustness is simply the cosine similarity of matched factors in independent models trained on non-overlapping sample sets. The premise is that noise capture will result in low similarity, while efficient signal capture will result in high similarity. Furthermore, approximations which are too low-rank will not classify signals in the same manner, leading to poor factor matching.
Takeaways
- The predict method (bi-cross-validation) is useful for well-conditioned signal.
- The robust method (similarity of independent factorizations) is generally the most informative for noisy data possibly suffering from signal dropout.
- The impute method is the slowest of the three, but generally the most sensitive.
Install RcppML
Install the development version of RcppML:
devtools::install_github("zdebruine/RcppML")
library(RcppML)
library(ggplot2)
library(cowplot)
library(umap)
library(irlba)
Simulated data
Simulated data is useful for demonstrating how these methods behave under adversarial perturbations such as noise or dropout.
We will first explore cross-validation using two simulated datasets generated with simulateNMF:
- data_clean will have no noise or signal dropout
- data_dirty contains the same signal as data_clean, but with a good amount of noise and dropout
data_clean <- simulateNMF(nrow = 200, ncol = 200, k = 5, noise = 0, dropout = 0, seed = 123)
data_dirty <- simulateNMF(nrow = 200, ncol = 200, k = 5, noise = 0.5, dropout = 0.5, seed = 123)
Notice how data_clean contains only 5 non-zero singular values, while data_dirty does not.
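A quick way to check this (a minimal sketch, assuming data_clean and data_dirty are, or can be coerced to, dense numeric matrices) is to compare their leading singular values directly:
# sketch: leading singular values of each dataset (coerced to dense for base svd)
s_clean <- svd(as.matrix(data_clean), nu = 0, nv = 0)$d
s_dirty <- svd(as.matrix(data_dirty), nu = 0, nv = 0)$d
round(rbind(clean = s_clean[1:10], dirty = s_dirty[1:10]), 2)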
We can use RcppML::crossValidate to determine the rank of each dataset. The default method uses “bi-cross-validation”. See ?crossValidate for details.
cv_clean <- crossValidate(data_clean, k = 1:10, method = "predict", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "predict", reps = 3, seed = 123)
plot_grid(
plot(cv_clean) + ggtitle("bi-cross-validation on\nclean dataset"),
plot(cv_dirty) + ggtitle("bi-cross-validation on\ndirty dataset"), nrow = 1)
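To make the “non-intersecting blocks” idea concrete, here is a minimal illustration of the kind of split bi-cross-validation relies on. This is illustrative only; crossValidate handles the actual splitting and model fitting internally.
# sketch: rows and columns are each split in two; the model never sees the
# test block during training, so it cannot "cheat" off shared samples or features
set.seed(123)
test_rows <- sample(nrow(data_clean), nrow(data_clean) / 2)
test_cols <- sample(ncol(data_clean), ncol(data_clean) / 2)
train_block <- data_clean[-test_rows, -test_cols]  # model is fit here
test_block  <- data_clean[ test_rows,  test_cols]  # error is measured here
dim(train_block); dim(test_block)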
crossValidate also supports another method, which compares the robustness of two factorizations trained on independent sample subsets.
cv_clean <- crossValidate(data_clean, k = 1:10, method = "robust", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "robust", reps = 3, seed = 123)
plot_grid(
plot(cv_clean) + ggtitle("robust cross-validation on\nclean dataset"),
plot(cv_dirty) + ggtitle("robust cross-validation on\ndirty dataset"), nrow = 1)
This second method does better on ill-conditioned data because it measures agreement between independent factorizations rather than reconstruction error on held-out values.
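As a rough illustration of what is being compared, the sketch below defines a hypothetical helper (not the crossValidate implementation) that scores two factorizations by the cosine similarity of greedily matched factors, where w1 and w2 are the feature-by-factor matrices from models trained on non-overlapping sample subsets:
# hypothetical helper: mean cosine similarity of matched factors
# w1, w2: feature x factor ("w") matrices from two independent models
robustness_score <- function(w1, w2) {
  # cosine similarity between every pair of factors
  sim <- crossprod(w1, w2) / outer(sqrt(colSums(w1^2)), sqrt(colSums(w2^2)))
  matched <- numeric(nrow(sim))
  for (i in seq_len(nrow(sim))) {
    j <- which.max(sim[i, ])  # best remaining match for factor i in model 2
    matched[i] <- sim[i, j]
    sim[, j] <- -Inf          # each factor in model 2 can be matched only once
  }
  mean(matched)               # high = consistent signal capture, low = noise capture
}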
Finally, we can use the impute method:
cv_clean <- crossValidate(data_clean, k = 1:10, method = "impute", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "impute", reps = 3, seed = 123)
plot_grid(
plot(cv_clean) + ggtitle("impute cross-validation on\nclean dataset") + scale_y_continuous(trans = "log10"),
plot(cv_dirty) + ggtitle("impute cross-validation on\ndirty dataset") + scale_y_continuous(trans = "log10"), nrow = 1)
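For intuition about what the impute method measures, here is a minimal self-contained sketch (not RcppML's implementation, and the function name is hypothetical). It masks a random fraction of entries, fits a small weighted NMF by multiplicative updates so the masked entries do not influence training, and then scores mean squared error on the masked entries only. It assumes the input can be coerced to a dense matrix.
# sketch of the imputation objective (illustrative; RcppML uses its own solver)
impute_mse_sketch <- function(A, k, holdout = 0.05, iters = 200, seed = 1) {
  set.seed(seed)
  A <- as.matrix(A)
  M <- (matrix(runif(length(A)), nrow(A)) >= holdout) * 1  # 1 = train, 0 = held out
  w <- matrix(runif(nrow(A) * k), ncol = k)
  h <- matrix(runif(k * ncol(A)), nrow = k)
  for (i in seq_len(iters)) {  # weighted multiplicative updates (Frobenius loss)
    h <- h * (t(w) %*% (M * A)) / (t(w) %*% (M * (w %*% h)) + 1e-9)
    w <- w * ((M * A) %*% t(h)) / ((M * (w %*% h)) %*% t(h) + 1e-9)
  }
  held_out <- M == 0
  mean((A[held_out] - (w %*% h)[held_out])^2)  # imputation error at held-out entries
}
# e.g. sapply(1:10, function(k) impute_mse_sketch(data_clean, k))
# the error should stop improving near the true rank (k = 5)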
For real datasets, it is important to experiment with several of these cross-validation methods and to explore multi-resolution analysis or other objectives where appropriate.
Let’s take a look at a real dataset:
Finding the rank of the hawaiibirds dataset
data(hawaiibirds)
A <- hawaiibirds$counts
cv_predict <- crossValidate(A, k = 1:20, method = "predict", reps = 3, seed = 123)
cv_robust <- crossValidate(A, k = 1:20, method = "robust", reps = 3, seed = 123)
cv_impute <- crossValidate(A, k = 1:20, method = "impute", reps = 3, seed = 123)
plot_grid(
plot(cv_predict) + ggtitle("method = 'predict'") + theme(legend.position = "none"),
plot(cv_robust) + ggtitle("method = 'robust'") + theme(legend.position = "none"),
plot(cv_impute) + ggtitle("method = 'impute'") + scale_y_continuous(trans = "log10") + theme(legend.position = "none"),
get_legend(plot(cv_predict)), rel_widths = c(1, 1, 1, 0.4), nrow = 1, labels = "auto")
Finding the rank of the aml dataset
data(aml)
cv_impute <- crossValidate(aml, k = 2:14, method = "impute", reps = 3, seed = 123)
plot(cv_impute) + scale_y_continuous(trans = "log10")
Technical considerations
Runtime is a major consideration for large datasets. Unfortunately, missing value imputation can be very slow.
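If runtime matters, a quick comparison like the sketch below can help decide which method is practical. This is illustrative only; timings depend on hardware, data size, and the number of reps.
# rough timing comparison of the three methods on a small dataset
data(hawaiibirds)
A <- hawaiibirds$counts
system.time(crossValidate(A, k = 1:10, method = "predict", reps = 3, seed = 123))
system.time(crossValidate(A, k = 1:10, method = "robust", reps = 3, seed = 123))
system.time(crossValidate(A, k = 1:10, method = "impute", reps = 3, seed = 123))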
Perturb
Compare missing-value imputation when the held-out values are perturbed to zeros versus perturbed to random values:
data(hawaiibirds)
data(aml)
data(movielens)
library(Seurat)
library(SeuratData)
pbmc3k
## An object of class Seurat
## 13714 features across 2700 samples within 1 assay
## Active assay: RNA (13714 features, 0 variable features)
A <- pbmc3k@assays$RNA@counts
n <- 0.2
method <- "impute"
cv1 <- crossValidate(A, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
cv2 <- crossValidate(aml, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
cv3 <- crossValidate(movielens$ratings, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
cv4 <- crossValidate(hawaiibirds$counts, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
plot_grid(
plot(cv1) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv2) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv3) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv4) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
nrow = 2)
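To complete the comparison set up above, the same calls can be repeated with the held-out entries perturbed to zeros instead of random values. The perturb_to = "zeros" value is an assumption here (only "random" appears above); check ?crossValidate for the accepted options.
# assumption: perturb_to = "zeros" is accepted; see ?crossValidate to confirm
cv1z <- crossValidate(A, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv2z <- crossValidate(aml, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv3z <- crossValidate(movielens$ratings, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv4z <- crossValidate(hawaiibirds$counts, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
plot_grid(
plot(cv1z) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv2z) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv3z) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
plot(cv4z) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
nrow = 2)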