Module VITAE.train
Functions
def clear_session()
-
Clear the TensorFlow session.
def warp_dataset(X_normalized, c_score, batch_size: int, X=None, scale_factor=None, conditions=None, pi_cov=None, seed=0)
-
Get TensorFlow datasets.
Parameters
X_normalized : np.array
    [N, G] The preprocessed data.
c_score : float, optional
    The normalizing constant.
batch_size : int
    The batch size.
X : np.array, optional
    [N, G] The raw count data.
scale_factor : np.array, optional
    [N, ] The scale factor of each cell.
seed : int, optional
    The random seed for data shuffling.
conditions : str or list, optional
    The conditions of different cells.
Returns
dataset : tf.data.Dataset
    The TensorFlow Dataset object.
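The seed-controlled shuffle-and-batch behaviour that `warp_dataset` delegates to `tf.data` can be illustrated with a minimal NumPy sketch (this is an analogy, not VITAE's actual implementation; `make_batches` is a hypothetical helper):

```python
import numpy as np

def make_batches(X_normalized, batch_size, seed=0):
    # Shuffle the rows of a [N, G] matrix with a fixed seed,
    # then yield fixed-size batches (the last batch may be smaller).
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(X_normalized))
    shuffled = X_normalized[idx]
    return [shuffled[i:i + batch_size]
            for i in range(0, len(shuffled), batch_size)]

X = np.arange(12, dtype=np.float32).reshape(6, 2)  # toy [N=6, G=2] matrix
batches = make_batches(X, batch_size=4)
print([b.shape for b in batches])  # [(4, 2), (2, 2)]
```

Passing the same `seed` reproduces the same shuffling order, which is what the `seed` parameter above controls.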
def pre_train(train_dataset, test_dataset, vae, learning_rate: float, L: int, alpha: float, gamma: float, phi: float, num_epoch: int, num_step_per_epoch: int, es_patience: int, es_tolerance: int, es_relative: bool, verbose: bool = True)
-
Pretraining.
Parameters
train_dataset : tf.data.Dataset
    The training Dataset object.
test_dataset : tf.data.Dataset
    The test Dataset object.
vae : VariationalAutoEncoder
    The model.
learning_rate : float
    The initial learning rate for the Adam optimizer.
L : int
    The number of MC samples.
alpha : float, optional
    The value of alpha in [0, 1] to encourage covariate adjustment. Not used if there are no covariates.
gamma : float, optional
    The weight of the MMD loss.
phi : float, optional
    The weight of the Jacobian norm of the encoder.
num_epoch : int
    The maximum number of epochs.
num_step_per_epoch : int
    The number of steps per epoch; inferred from the number of cells and the batch size if None.
es_patience : int
    The maximum number of epochs without improvement before stopping.
es_tolerance : float
    The minimum change in loss to be considered an improvement.
es_relative : bool, optional
    Whether to monitor the relative change of the loss.
Returns
vae : VariationalAutoEncoder
    The pretrained model.
def train(train_dataset, test_dataset, vae, learning_rate: float, L: int, alpha: float, beta: float, gamma: float, phi: float, num_epoch: int, num_step_per_epoch: int, es_patience: int, es_tolerance: float, es_relative: bool, es_warmup: int, verbose: bool = False, pi_cov=None, **kwargs)
-
Training.
Parameters
train_dataset : tf.data.Dataset
    The training Dataset object.
test_dataset : tf.data.Dataset
    The test Dataset object.
vae : VariationalAutoEncoder
    The model.
learning_rate : float
    The initial learning rate for the Adam optimizer.
L : int
    The number of MC samples.
alpha : float
    The value of alpha in [0, 1] to encourage covariate adjustment. Not used if there are no covariates.
beta : float
    The value of beta in beta-VAE.
gamma : float
    The weight of the MMD loss.
phi : float
    The weight of the Jacobian norm of the encoder.
num_epoch : int
    The maximum number of epochs.
num_step_per_epoch : int
    The number of steps per epoch; inferred from the number of cells and the batch size if None.
es_patience : int
    The maximum number of epochs without improvement before stopping.
es_tolerance : float, optional
    The minimum change in loss to be considered an improvement.
es_relative : bool, optional
    Whether to monitor the relative change of the loss.
es_warmup : int
    The number of warm-up epochs.
**kwargs
    Extra keyword arguments for dimension reduction algorithms.
Returns
vae : VariationalAutoEncoder
    The trained model.