Module VITAE.train
Functions
def clear_session()
Clear the TensorFlow session.
def warp_dataset(X_normalized, c_score, batch_size: int, X=None, scale_factor=None, conditions=None, pi_cov=None, seed=0)
Get TensorFlow datasets.
Parameters
X_normalized : np.array - [N, G] The preprocessed data.
c_score : float, optional - The normalizing constant.
batch_size : int - The batch size.
X : np.array, optional - [N, G] The raw count data.
scale_factor : np.array, optional - [N, ] The scale factor of each cell.
seed : int, optional - The random seed for data shuffling.
conditions : str or list, optional - The conditions of different cells.
Returns
dataset : tf.Dataset - The TensorFlow Dataset object.
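The core of what `warp_dataset` does is reproducible shuffling and batching. A minimal NumPy stand-in for that pipeline (illustrative only; `make_batches` is not a VITAE function, and the real implementation uses `tf.data`):

```python
import numpy as np

def make_batches(X_normalized, batch_size, seed=0):
    """Shuffle rows with a fixed seed and yield mini-batches.

    A NumPy sketch of the shuffle-then-batch behavior that
    warp_dataset's tf.data pipeline provides.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X_normalized))
    for start in range(0, len(idx), batch_size):
        yield X_normalized[idx[start:start + batch_size]]

# Toy data: 10 cells, 4 genes.
X = np.arange(40, dtype=float).reshape(10, 4)
batches = list(make_batches(X, batch_size=4, seed=0))
print([b.shape[0] for b in batches])  # [4, 4, 2]
```

The fixed `seed` makes the epoch ordering reproducible, mirroring the `seed` parameter above.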
def pre_train(train_dataset, test_dataset, vae, learning_rate: float, L: int, alpha: float, gamma: float, phi: float, num_epoch: int, num_step_per_epoch: int, es_patience: int, es_tolerance: int, es_relative: bool, verbose: bool = True)
Pretraining.
Parameters
train_dataset : tf.Dataset - The TensorFlow Dataset object for training.
test_dataset : tf.Dataset - The TensorFlow Dataset object for testing.
vae : VariationalAutoEncoder - The model.
learning_rate : float - The initial learning rate for the Adam optimizer.
L : int - The number of MC samples.
alpha : float, optional - The value of alpha in [0, 1] to encourage covariate adjustment. Not used if there are no covariates.
gamma : float, optional - The weight of the MMD loss.
phi : float, optional - The weight of the Jacobian norm of the encoder.
num_epoch : int - The maximum number of epochs.
num_step_per_epoch : int - The number of steps per epoch; inferred from the number of cells and the batch size if None.
es_patience : int - The maximum number of epochs without improvement before stopping early.
es_tolerance : float - The minimum change in loss to be considered an improvement.
es_relative : bool, optional - Whether to monitor the relative change of the loss.
es_warmup : int, optional - The number of warmup epochs.
conditions : str or list - The conditions of different cells.
Returns
vae : VariationalAutoEncoder - The pretrained model.
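The `es_*` parameters together define an early-stopping rule. A sketch of how such a rule typically combines them (an illustrative reimplementation, not VITAE's actual code):

```python
def should_stop(losses, patience, tolerance, relative, warmup=0):
    """Return True once the loss fails to improve by more than
    `tolerance` for `patience` consecutive epochs, after `warmup`
    epochs are skipped. With `relative`, the improvement is measured
    as a fraction of the best loss so far."""
    best = float("inf")
    since_improve = 0
    for epoch, loss in enumerate(losses):
        if epoch < warmup:
            best = min(best, loss)
            continue
        delta = best - loss
        if relative and best != float("inf"):
            delta /= abs(best)
        if delta > tolerance:
            best = loss
            since_improve = 0
        else:
            since_improve += 1
            if since_improve >= patience:
                return True
    return False

# The loss plateaus after epoch 2, so patience=3 triggers a stop.
losses = [1.0, 0.5, 0.4, 0.399, 0.399, 0.399, 0.399]
print(should_stop(losses, patience=3, tolerance=0.01, relative=False))  # True
```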
def train(train_dataset, test_dataset, vae, learning_rate: float, L: int, alpha: float, beta: float, gamma: float, phi: float, num_epoch: int, num_step_per_epoch: int, es_patience: int, es_tolerance: float, es_relative: bool, es_warmup: int, verbose: bool = False, pi_cov=None, **kwargs)
Training.
Parameters
train_dataset : tf.Dataset - The TensorFlow Dataset object for training.
test_dataset : tf.Dataset - The TensorFlow Dataset object for testing.
vae : VariationalAutoEncoder - The model.
learning_rate : float - The initial learning rate for the Adam optimizer.
L : int - The number of MC samples.
alpha : float - The value of alpha in [0, 1] to encourage covariate adjustment. Not used if there are no covariates.
beta : float - The value of beta in beta-VAE.
gamma : float - The weight of the MMD loss.
phi : float - The weight of the Jacobian norm of the encoder.
num_epoch : int - The maximum number of epochs.
num_step_per_epoch : int - The number of steps per epoch; inferred from the number of cells and the batch size if None.
es_patience : int - The maximum number of epochs without improvement before stopping early.
es_tolerance : float, optional - The minimum change in loss to be considered an improvement.
es_relative : bool, optional - Whether to monitor the relative change of the loss.
es_warmup : int - The number of warmup epochs.
**kwargs - Extra keyword arguments for dimension reduction algorithms.
Returns
vae : VariationalAutoEncoder - The trained model.
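The weights `beta`, `gamma`, and `phi` scale separate terms of the training objective. A sketch of how such a weighted objective is composed (the term names are assumptions for illustration, not VITAE's internal variables, and the covariate term controlled by `alpha` is omitted):

```python
def total_loss(recon, kl, mmd, jacobian, beta, gamma, phi):
    """Weighted sum of loss terms: reconstruction + beta * KL
    (the beta-VAE weight) + gamma * MMD + phi * Jacobian norm
    of the encoder."""
    return recon + beta * kl + gamma * mmd + phi * jacobian

loss = total_loss(recon=2.0, kl=0.5, mmd=0.1, jacobian=0.2,
                  beta=4.0, gamma=1.0, phi=0.5)
print(loss)  # 2.0 + 2.0 + 0.1 + 0.1 = 4.2
```

Raising `beta` pushes the posterior toward the prior (disentanglement), while `gamma` and `phi` trade reconstruction quality for batch mixing and encoder smoothness, respectively.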