Package VITAE

# Names picked up by `from VITAE import *`.
__all__ = ['VITAE']

# Make `load_data` and `reset_random_seeds` importable directly from the package.
from VITAE.utils import load_data, reset_random_seeds

# Package metadata.
__author__ = "Jin-Hong Du, Tianyu Chen, Ming Gao, and Jingshu Wang"
__copyright__ = "Copyright 2020, The Trajectory Inference Project"

__license__ = "MIT"
__version__ = "2.0.8"
__maintainer__ = "Jin-Hong Du, Tianyu Chen, and Jingshu Wang"
__email__ = ""
__status__ = "Development"




class VITAE (adata: anndata._core.anndata.AnnData, covariates=None, pi_covariates=None, model_type: str = 'Gaussian', npc: int = 64, adata_layer_counts=None, copy_adata: bool = False, hidden_layers=[32], latent_space_dim: int = 16, conditions=None)

Variational Inference for Trajectory by AutoEncoder.

Get input data for model. Data need to be first processed using scanpy and stored as an AnnData object. The 'UMI' or 'non-UMI' model needs the original count matrix, so the count matrix needs to be saved in adata.layers in order to use these models.


adata : sc.AnnData
The scanpy AnnData object. adata should already contain adata.var.highly_variable
covariates : list, optional
A list of names of covariate vectors that are stored in adata.obs
pi_covariates : list, optional
A list of names of covariate vectors used as input for pilayer
model_type : str, optional
'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'.
npc : int, optional
The number of PCs to use when model_type is 'Gaussian'. The default is 64.
adata_layer_counts : str, optional
the key name of adata.layers that stores the count data if model_type is 'UMI' or 'non-UMI'
copy_adata : bool, optional. Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata.
hidden_layers : list, optional
The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes
latent_space_dim : int, optional
The dimension of latent space.
gamma : float, optional
The weight of the MMD loss
conditions : str or list, optional
The conditions of different cells



class VITAE():
    Variational Inference for Trajectory by AutoEncoder.
    def __init__(self, adata: sc.AnnData,
               covariates = None, pi_covariates = None,
               model_type: str = 'Gaussian',
               npc: int = 64,
               adata_layer_counts = None,
               copy_adata: bool = False,
               hidden_layers = [32],
               latent_space_dim: int = 16,
               conditions = None):
        Get input data for model. Data need to be first processed using scanpy and stored as an AnnData object.
         The 'UMI' or 'non-UMI' model needs the original count matrix, so the count matrix needs to be saved in
         adata.layers in order to use these models.

        adata : sc.AnnData
            The scanpy AnnData object. adata should already contain adata.var.highly_variable
        covariates : list, optional
            A list of names of covariate vectors that are stored in adata.obs
        pi_covariates: list, optional
            A list of names of covariate vectors used as input for pilayer
        model_type : str, optional
            'UMI', 'non-UMI' and 'Gaussian', default is 'Gaussian'.
        npc : int, optional
            The number of PCs to use when model_type is 'Gaussian'. The default is 64.
        adata_layer_counts: str, optional
            the key name of adata.layers that stores the count data if model_type is
            'UMI' or 'non-UMI'
        copy_adata: bool, optional. Set to True if we don't want VITAE to modify the original adata. If set to True, self.adata will be an independent copy of the original adata. 
        hidden_layers : list, optional
            The list of dimensions of layers of autoencoder between latent space and original space. Default is to have only one hidden layer with 32 nodes
        latent_space_dim : int, optional
            The dimension of latent space.
        gamma : float, optional
            The weight of the MMD loss
        conditions : str or list, optional
            The conditions of different cells


        self.dict_method_scname = {
            'PCA' : 'X_pca',
            'UMAP' : 'X_umap',
            'TSNE' : 'X_tsne',
            'diffmap' : 'X_diffmap',
            'draw_graph' : 'X_draw_graph_fa'

        if model_type != 'Gaussian':
            if adata_layer_counts is None:
                raise ValueError("need to provide the name in adata.layers that stores the raw count data")
            if 'highly_variable' not in adata.var:
                raise ValueError("need to first select highly variable genes using scanpy")

        self.model_type = model_type

        if copy_adata:
            self.adata = adata.copy()
            self.adata = adata

        if covariates is not None:
            if isinstance(covariates, str):
                covariates = [covariates]
            covariates = np.array(covariates)
            id_cat = (adata.obs[covariates].dtypes == 'category')
            # add OneHotEncoder & StandardScaler as class variable if needed
            if np.sum(id_cat)>0:
                covariates_cat = OneHotEncoder(drop='if_binary', handle_unknown='ignore'
                covariates_cat = np.array([]).reshape(adata.shape[0],0)

            # temporarily disable StandardScaler
            if np.sum(~id_cat)>0:
                #covariates_con = StandardScaler().fit_transform(adata.obs[covariates[~id_cat]])
                covariates_con = adata.obs[covariates[~id_cat]]
                covariates_con = np.array([]).reshape(adata.shape[0],0)

            self.covariates = np.c_[covariates_cat, covariates_con].astype(tf.keras.backend.floatx())
            self.covariates = None

        if conditions is not None:
            ## observations with np.nan will not participate in calculating mmd_loss
            if isinstance(conditions, str):
                conditions = [conditions]
            conditions = np.array(conditions)
            if np.any(adata.obs[conditions].dtypes != 'category'):
                raise ValueError("Conditions should all be categorical.")

            self.conditions = OrdinalEncoder(dtype=int, encoded_missing_value=-1).fit_transform(adata.obs[conditions]) + int(1)
            self.conditions = None

        if pi_covariates is not None:
            self.pi_cov = adata.obs[pi_covariates].to_numpy()
            if self.pi_cov.ndim == 1:
                self.pi_cov = self.pi_cov.reshape(-1, 1)
                self.pi_cov = self.pi_cov.astype(tf.keras.backend.floatx())
            self.pi_cov = np.zeros((adata.shape[0],1), dtype=tf.keras.backend.floatx())
        self.model_type = model_type
        self._adata = sc.AnnData(X = self.adata.X, var = self.adata.var)
        self._adata.obs = self.adata.obs
        self._adata.uns = self.adata.uns

        if model_type == 'Gaussian':
  , n_comps = npc)
            self.X_input = self.X_output = adata.obsm['X_pca']
            self.scale_factor = np.ones(self.X_output.shape[0])
            print(f"{adata.var.highly_variable.sum()} highly variable genes selected as input") 
            self.X_input = adata.X[:, adata.var.highly_variable]
            self.X_output = adata.layers[adata_layer_counts][ :, adata.var.highly_variable]
            self.scale_factor = np.sum(self.X_output, axis=1, keepdims=True)/1e4

        self.dimensions = hidden_layers
        self.dim_latent = latent_space_dim

        self.vae = model.VariationalAutoEncoder(
            self.X_output.shape[1], self.dimensions,
            self.dim_latent, self.model_type,
            False if self.covariates is None else True,

        if hasattr(self, 'inferer'):
            delattr(self, 'inferer')

    def pre_train(self, test_size = 0.1, random_state: int = 0,
            learning_rate: float = 1e-3, batch_size: int = 256, L: int = 1, alpha: float = 0.10, gamma: float = 0,
            phi : float = 1,num_epoch: int = 200, num_step_per_epoch: Optional[int] = None,
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
            early_stopping_relative: bool = True, verbose: bool = False,path_to_weights: Optional[str] = None):
        '''Pretrain the model with specified learning rate.

        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional 
            The batch size for pre-training.  Default is 256. Set to 32 if number of cells is small (less than 1000)
        L : int, optional 
            The number of MC samples.
        alpha : float, optional
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.
        gamma : float, optional
            The weight of the mmd loss if used.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional 
            The maximum number of epochs.
        num_step_per_epoch : int, optional 
            The number of step per epoch, it will be inferred from number of cells and batch size if it is None.            
        early_stopping_patience : int, optional 
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional 
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether monitor the relative change of loss as stopping criteria or not.
        path_to_weights : str, optional 
            The path of weight file to be saved; not saving weight if None.
        conditions : str or list, optional
            The conditions of different cells

        id_train, id_test = train_test_split(
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()), 
                                                None if self.covariates is None else self.covariates[id_train].astype(tf.keras.backend.floatx()),
                                                conditions = None if self.conditions is None else self.conditions[id_train].astype(tf.keras.backend.floatx()))
        self.test_dataset = train.warp_dataset(self.X_input[id_test], 
                                                None if self.covariates is None else self.covariates[id_test].astype(tf.keras.backend.floatx()),
                                                conditions = None if self.conditions is None else self.conditions[id_test].astype(tf.keras.backend.floatx()))

        self.vae = train.pre_train(
            L, alpha, gamma, phi,

        if path_to_weights is not None:

    def update_z(self):
        '''Refresh the cached latent representation.

        Stores the result of `self.get_latent_z()` in `self.z` and wraps it in a
        new `sc.AnnData` object kept as `self._adata_z`.
        '''
        self.z = self.get_latent_z()        
        self._adata_z = sc.AnnData(self.z)

    def get_latent_z(self):
        ''' get the posterior mean of current latent space z (encoder output)

        z : np.array
            \([N,d]\) The latent means.
        c = None if self.covariates is None else self.covariates
        return self.vae.get_z(self.X_input, c)
    def visualize_latent(self, method: str = "UMAP", 
                         color = None, **kwargs):
        visualize the current latent space z using the scanpy visualization tools

        method : str, optional
            Visualization method to use. The default is "UMAP". Possible choices include "PCA", "UMAP",
            "diffmap", "TSNE" and "draw_graph" (the FA plot).
        color : TYPE, optional
            Keys for annotations of observations/cells or variables/genes, e.g., 'ann1' or ['ann1', 'ann2'].
            The default is None. Same as scanpy.
        **kwargs :  
            Extra key-value arguments that can be passed to scanpy plotting functions.


        if method not in ['PCA', 'UMAP', 'TSNE', 'diffmap', 'draw_graph']:
            raise ValueError("visualization method should be one of 'PCA', 'UMAP', 'TSNE', 'diffmap' and 'draw_graph'")
        temp = list(self._adata_z.obsm.keys())
        if method == 'PCA' and not 'X_pca' in temp:
            print("Calculate PCs ...")
        elif method == 'UMAP' and not 'X_umap' in temp:  
            print("Calculate UMAP ...")
        elif method == 'TSNE' and not 'X_tsne' in temp:
            print("Calculate TSNE ...")
        elif method == 'diffmap' and not 'X_diffmap' in temp:
            print("Calculate diffusion map ...")
        elif method == 'draw_graph' and not 'X_draw_graph_fa' in temp:
            print("Calculate FA ...")

        self._adata.obs = self.adata.obs.copy()
        self._adata.obsp = self._adata_z.obsp
#        self._adata.uns = self._adata_z.uns
        self._adata.obsm = self._adata_z.obsm
        if method == 'PCA':
            axes =, color = color, **kwargs)
        elif method == 'UMAP':            
            axes =, color = color, **kwargs)
        elif method == 'TSNE':
            axes =, color = color, **kwargs)
        elif method == 'diffmap':
            axes =, color = color, **kwargs)
        elif method == 'draw_graph':
            axes =, color = color, **kwargs)
        return axes

    def init_latent_space(self, cluster_label = None, log_pi = None, res: float = 1.0, 
                          ratio_prune= None, dist = None, dist_thres = 0.5, pilayer = False):
        '''Initialize the latent space.

        cluster_label : str, optional
            The name of vector of labels that can be found in self.adata.obs. 
            Default is None, which will perform leiden clustering on the pretrained z to get clusters
        mu : np.array, optional
            \([d,k]\) The value of initial \(\\mu\).
        log_pi : np.array, optional
            \([1,K]\) The value of initial \(\\log(\\pi)\).
        res : float, optional
            The resolution of leiden clustering, which is a parameter value controlling the coarseness of the clustering.
            Higher values lead to more clusters. Default is 1.
        ratio_prune : float, optional
            The ratio of edges to be removed before estimating.
        if cluster_label is None:
            print("Perform leiden clustering on the latent space z ...")
            g = get_igraph(self.z)
            cluster_labels = leidenalg_igraph(g, res = res)
            cluster_labels = cluster_labels.astype(str) 
            uni_cluster_labels = np.unique(cluster_labels)
            if isinstance(cluster_label,str):
                cluster_labels = self.adata.obs[cluster_label].to_numpy()
                uni_cluster_labels = np.array(self.adata.obs[cluster_label].cat.categories)
                ## if cluster_label is a list
                cluster_labels = cluster_label
                uni_cluster_labels = np.unique(cluster_labels)

        n_clusters = len(uni_cluster_labels)

        if not hasattr(self, 'z'):
        z = self.z
        mu = np.zeros((z.shape[1], n_clusters))
        for i,l in enumerate(uni_cluster_labels):
            mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
        if dist is None:
            ### update cluster centers if some cluster centers are too close
            clustering = AgglomerativeClustering(
            n_clusters_new = clustering.n_clusters_
            if n_clusters_new < n_clusters:
                print("Merge clusters for cluster centers that are too close ...")
                n_clusters = n_clusters_new
                for i in range(n_clusters):    
                    temp = uni_cluster_labels[clustering.labels_ == i]
                    idx = np.isin(cluster_labels, temp)
                    cluster_labels[idx] = ','.join(temp)
                    if np.sum(clustering.labels_==i)>1:
                        print('Merge %s'% ','.join(temp))
                uni_cluster_labels = np.unique(cluster_labels)
                mu = np.zeros((z.shape[1], n_clusters))
                for i,l in enumerate(uni_cluster_labels):
                    mu[:,i] = np.mean(z[cluster_labels==l], axis=0)
        self.adata.obs['vitae_init_clustering'] = cluster_labels
        self.adata.obs['vitae_init_clustering'] = self.adata.obs['vitae_init_clustering'].astype('category')
        print("Initial clustering labels saved as 'vitae_init_clustering' in self.adata.obs.")
        if (log_pi is None) and (cluster_labels is not None) and (n_clusters>3):                         
            n_states = int((n_clusters+1)*n_clusters/2)
            if dist is None:
                dist = _comp_dist(z, cluster_labels, mu.T)

            C = np.triu(np.ones(n_clusters))
            C[C>0] = np.arange(n_states)
            C = C.astype(int)

            log_pi = np.zeros((1,n_states))
            ## pruning to throw away edges for far-away clusters if there are too many clusters
            if ratio_prune is not None:
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 1-ratio_prune)]] = - np.inf
                log_pi[0, C[np.triu(dist)>np.quantile(dist[np.triu_indices(n_clusters, 1)], 5/n_clusters) * 3]] = - np.inf

        self.n_states = n_clusters
        self.labels = cluster_labels

        # Not sure if storing the this will be useful
        # self.init_labels_name = cluster_label
        labels_map = pd.DataFrame.from_dict(
            {i:label for i,label in enumerate(uni_cluster_labels)}, 
            orient='index', columns=['label_names'], dtype=str
        self.labels_map = labels_map
        self.vae.init_latent_space(self.n_states, mu, log_pi)
        self.inferer = Inferer(self.n_states) =
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

        if pilayer:

    def update_latent_space(self, dist_thres: float=0.5):
        pi = self.pi[np.triu_indices(self.n_states)]
        mu =    
        clustering = AgglomerativeClustering(
        n_clusters = clustering.n_clusters_   

        if n_clusters<self.n_states:      
            print("Merge clusters for cluster centers that are too close ...")
            mu_new = np.empty((self.dim_latent, n_clusters))
            C = np.zeros((self.n_states, self.n_states))
            C[np.triu_indices(self.n_states, 0)] = pi
            C = np.triu(C, 1) + C.T
            C_new = np.zeros((n_clusters, n_clusters))
            uni_cluster_labels = self.labels_map['label_names'].to_numpy()
            returned_order = {}
            cluster_labels = self.labels
            for i in range(n_clusters):
                temp = uni_cluster_labels[clustering.labels_ == i]
                idx = np.isin(cluster_labels, temp)
                cluster_labels[idx] = ','.join(temp)
                returned_order[i] = ','.join(temp)
                if np.sum(clustering.labels_==i)>1:
                    print('Merge %s'% ','.join(temp))
            uni_cluster_labels = np.unique(cluster_labels) 
            for i,l in enumerate(uni_cluster_labels):  ## reorder the merged clusters based on the cluster names
                k = np.where(returned_order == l)
                mu_new[:, i] = np.mean(mu[:,clustering.labels_==k], axis=-1)
                # sum of the aggregated pi's
                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==k,:][:,clustering.labels_==k]))
                for j in range(i+1, n_clusters):
                    k1 = np.where(returned_order == uni_cluster_labels[j])
                    C_new[i, j] = np.sum(C[clustering.labels_== k, :][:, clustering.labels_==k1])

#            labels_map_new = {}
#            for i in range(n_clusters):                       
#                # update label map: int->str
#                labels_map_new[i] = self.labels_map.loc[clustering.labels_==i, 'label_names']',')
#                if np.sum(clustering.labels_==i)>1:
#                    print('Merge %s'%labels_map_new[i])
#                # mean of the aggregated cluster means
#                mu_new[:, i] = np.mean(mu[:,clustering.labels_==i], axis=-1)
#                # sum of the aggregated pi's
#                C_new[i, i] = np.sum(np.triu(C[clustering.labels_==i,:][:,clustering.labels_==i]))
#                for j in range(i+1, n_clusters):
#                    C_new[i, j] = np.sum(C[clustering.labels_== i, :][:, clustering.labels_==j])
            C_new = np.triu(C_new,1) + C_new.T

            pi_new = C_new[np.triu_indices(n_clusters)]
            log_pi_new = np.log(pi_new, out=np.ones_like(pi_new)*(-np.inf), where=(pi_new!=0)).reshape((1,-1))
            self.n_states = n_clusters
            self.labels_map = pd.DataFrame.from_dict(
                {i:label for i,label in enumerate(uni_cluster_labels)},
                orient='index', columns=['label_names'], dtype=str
            self.labels = cluster_labels
#            self.labels_map = pd.DataFrame.from_dict(
#                labels_map_new, orient='index', columns=['label_names'], dtype=str
#            )
            self.vae.init_latent_space(self.n_states, mu_new, log_pi_new)
            self.inferer = Inferer(self.n_states)
            self.pi = np.triu(np.ones(self.n_states))
            self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]

    def train(self, stratify = False, test_size = 0.1, random_state: int = 0,
            learning_rate: float = 1e-3, batch_size: int = 256,
            L: int = 1, alpha: float = 0.10, beta: float = 1, gamma: float = 0, phi: float = 1,
            num_epoch: int = 200, num_step_per_epoch: Optional[int] =  None,
            early_stopping_patience: int = 10, early_stopping_tolerance: float = 0.01, 
            early_stopping_relative: bool = True, early_stopping_warmup: int = 0,
            path_to_weights: Optional[str] = None,
            verbose: bool = False, **kwargs):
        '''Train the model.

        stratify : np.array, None, or False
            If an array is provided, or `stratify=None` and `self.labels` is available, then they will be used to perform stratified shuffle splitting. Otherwise, general shuffle splitting is used. Set to `False` if `self.labels` is not intended for stratified shuffle splitting.
        test_size : float or int, optional
            The proportion or size of the test set.
        random_state : int, optional
            The random state for data splitting.
        learning_rate : float, optional  
            The initial learning rate for the Adam optimizer.
        batch_size : int, optional  
            The batch size for training. Default is 256. Set to 32 if number of cells is small (less than 1000)
        L : int, optional  
            The number of MC samples.
        alpha : float, optional  
            The value of alpha in [0,1] to encourage covariate adjustment. Not used if there is no covariates.
        beta : float, optional  
            The value of beta in beta-VAE.
        gamma : float, optional
            The weight of mmd_loss.
        phi : float, optional
            The weight of the Jacobian norm of the encoder.
        num_epoch : int, optional  
            The number of epoch.
        num_step_per_epoch : int, optional 
            The number of step per epoch, it will be inferred from number of cells and batch size if it is None.
        early_stopping_patience : int, optional 
            The maximum number of epochs if there is no improvement.
        early_stopping_tolerance : float, optional 
            The minimum change of loss to be considered as an improvement.
        early_stopping_relative : bool, optional
            Whether monitor the relative change of loss or not.            
        early_stopping_warmup : int, optional 
            The number of warmup epochs.            
        path_to_weights : str, optional 
            The path of weight file to be saved; not saving weight if None.
        **kwargs :  
            Extra key-value arguments for dimension reduction algorithms.        
        if gamma == 0 or self.conditions is None:
            conditions = np.array([np.nan] * self.adata.shape[0])
            conditions = self.conditions

        if stratify is None:
            stratify = self.labels
        elif stratify is False:
            stratify = None    
        id_train, id_test = train_test_split(
        if num_step_per_epoch is None:
            num_step_per_epoch = len(id_train)//batch_size+1
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
        self.train_dataset = train.warp_dataset(self.X_input[id_train].astype(tf.keras.backend.floatx()),
                                                None if c is None else c[id_train],
                                                conditions = conditions[id_train],
                                                pi_cov = self.pi_cov[id_train])
        self.test_dataset = train.warp_dataset(self.X_input[id_test].astype(tf.keras.backend.floatx()),
                                                None if c is None else c[id_test],
                                                conditions = conditions[id_test],
                                                pi_cov = self.pi_cov[id_test])
        self.vae = train.train(
        self.update_z() =
        self.pi = np.triu(np.ones(self.n_states))
        self.pi[self.pi > 0] = tf.nn.softmax(self.vae.latent_space.pi).numpy()[0]
        if path_to_weights is not None:

    def output_pi(self, pi_cov):
        """Evaluate the pi layer at `pi_cov` and arrange the result for a heatmap.

        The output of `self.vae.pilayer` is passed through a softmax and written
        into the upper triangle (diagonal included) of a square matrix whose side
        is `self.vae.n_states`; every other entry stays zero.

        pi_cov :
            Value fed to the pi layer. It is wrapped in a list and given a leading
            axis before the call (presumably a single scalar covariate - verify
            against callers).

        matrix : np.array
            The n_states by n_states matrix holding the softmax output in its
            upper triangle.
        mask : np.array
            Same shape as `matrix`; ones strictly below the diagonal and zeros
            elsewhere, usable to hide the lower triangle when plotting.
        """
        pilayer = self.vae.pilayer
        cov_tensor = tf.constant([pi_cov], dtype=tf.float32)
        cov_tensor = tf.expand_dims(cov_tensor, 0)
        probs = tf.nn.softmax(pilayer(cov_tensor)).numpy()[0]

        # Fill the upper triangle (including the diagonal) row by row.
        size = self.vae.n_states
        heatmap = np.zeros((size, size))
        upper = np.triu_indices(size)
        heatmap[upper] = probs

        lower_mask = np.tril(np.ones_like(heatmap), k=-1)
        return heatmap, lower_mask

    def return_pilayer_weights(self):
        """Return the parameters of the pi layer as a single array.

        The first entry of `self.vae.pilayer.weights` is stacked on top of the
        second entry reshaped to one row. Per the original note, the result has
        dimension dim(pi_cov) + 1 by n_categories and its last row is the biases.

        np.array
            The stacked weights with the second weight as the final row.
        """
        # Read from this instance's autoencoder. In this file `model` is the name
        # that provides `VariationalAutoEncoder`, so `model.vae` would not refer
        # to the fitted network.
        weights = self.vae.pilayer.weights
        kernel = weights[0].numpy()
        bias_row = weights[1].numpy().reshape(1, -1)
        return np.vstack((kernel, bias_row))

    def posterior_estimation(self, batch_size: int = 32, L: int = 50, **kwargs):
        '''Initialize trajectory inference by computing the posterior estimations.        

        batch_size : int, optional
            The batch size when doing inference.
        L : int, optional
            The number of MC samples when doing inference.
        **kwargs :  
            Extra key-value arguments for dimension reduction algorithms.              
        c = None if self.covariates is None else self.covariates.astype(tf.keras.backend.floatx())
        self.test_dataset = train.warp_dataset(self.X_input.astype(tf.keras.backend.floatx()), 
        _, _, self.pc_x,\
            self.cell_position_posterior,self.cell_position_variance,_ = self.vae.inference(self.test_dataset, L=L)
        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_posterior, 1)]
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
        print("New clustering labels saved as 'vitae_new_clustering' in self.adata.obs.")
        return None

    def infer_backbone(self, method: str = 'modified_map', thres = 0.5,
            no_loop: bool = True, cutoff: float = 0,
            visualize: bool = True, color = 'vitae_new_clustering',path_to_fig = None,**kwargs):
        ''' Compute edge scores.

        method : string, optional
            'mean', 'modified_mean', 'map', or 'modified_map'.
        thres : float, optional
            The threshold used for filtering edges \(e_{ij}\) that \((n_{i}+n_{j}+e_{ij})/N<thres\), only applied to mean method.
        no_loop : boolean, optional
            Whether to forbid loops in the graph. If no_loop is true, will prune the graph to contain only the
            maximum spanning tree.
        cutoff : float, optional
            The score threshold for filtering edges with scores less than cutoff.
        visualize: boolean
            whether plot the current trajectory backbone (undirected graph)

        G : nx.Graph
            The weighted graph with weight on each edge indicating its score of existence.
        # build_graph, return graph
        self.backbone = self.inferer.build_graphs(self.cell_position_posterior, self.pc_x,
                method, thres, no_loop, cutoff)
        self.cell_position_projected = self.inferer.modify_wtilde(self.cell_position_posterior, 
        uni_cluster_labels = self.labels_map['label_names'].to_numpy()
        temp_dict = {i:label for i,label in enumerate(uni_cluster_labels)}
        nx.relabel_nodes(self.backbone, temp_dict)
        self.adata.obs['vitae_new_clustering'] = uni_cluster_labels[np.argmax(self.cell_position_projected, 1)]
        self.adata.obs['vitae_new_clustering'] = self.adata.obs['vitae_new_clustering'].astype('category')
        print("'vitae_new_clustering' updated based on the projected cell positions.")

        self.uncertainty = np.sum((self.cell_position_projected - self.cell_position_posterior)**2, axis=-1) \
            + np.sum(self.cell_position_variance, axis=-1)
        self.adata.obs['projection_uncertainty'] = self.uncertainty
        print("Cell projection uncertainties stored as 'projection_uncertainty' in self.adata.obs")
        if visualize:
            self._adata.obs = self.adata.obs.copy()
   = self.plot_backbone(directed = False,color = color, **kwargs)
            if path_to_fig is not None:
        return None

    def select_root(self, days, method: str = 'proportion'):
        '''Order the vertices/states based on cells' collection time information to select the root state.      

        days : np.array
            The day information for selected cells used to determine the root vertex.
            The dtype should be 'int' or 'float'.
        method : str, optional
            'proportion' or 'mean'.
            For 'proportion', the root is the one with maximal proportion of cells from the earliest day.
            For 'mean', the root is the one with earliest mean time among cells associated with it.

        root_info : pd.DataFrame
            A copy of `self.labels_map` with the added columns 'mean_collection_time' and 'earliest_time_prop', sorted by 'mean_collection_time'.
        ## TODO: change return description
        if days is not None and len(days)!=self.X_input.shape[0]:
            raise ValueError("The length of day information ({}) is not "
                "consistent with the number of selected cells ({})!".format(
                    len(days), self.X_input.shape[0]))
        if not hasattr(self, 'cell_position_projected'):
            raise ValueError("Need to call 'infer_backbone' first!")

        collection_time =, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
        earliest_prop =, self.cell_position_projected)/np.sum(self.cell_position_projected, axis = 0)
        root_info = self.labels_map.copy()
        root_info['mean_collection_time'] = collection_time
        root_info['earliest_time_prop'] = earliest_prop
        root_info.sort_values('mean_collection_time', inplace=True)
        return root_info

    def plot_backbone(self, directed: bool = False, 
                      method: str = 'UMAP', color = 'vitae_new_clustering', **kwargs):
        '''Plot the current trajectory backbone on top of the latent embedding.

        Parameters
        ----------
        directed : boolean, optional
            Whether to draw `self.directed_backbone` (with arrow heads and a legend)
            instead of the undirected `self.backbone`.
        method : str, optional
            The dimension reduction method to use. The default is "UMAP".
        color : str, optional
            The key for annotations of observations/cells or variables/genes, e.g., 'ann1'.
            The default is 'vitae_new_clustering'.
        **kwargs :
            Extra key-value arguments forwarded to `self.visualize_latent`.

        Returns
        ----------
        ax : matplotlib axes
            The axes returned by `self.visualize_latent`, with edges and cluster
            centers drawn on it.

        Raises
        ----------
        ValueError
            If `color` is not a string.
        '''
        if not isinstance(color,str):
            raise ValueError('The color argument should be of type str!')
        ax = self.visualize_latent(method = method, color=color, show=False, **kwargs)
        dict_label_num = {j:i for i,j in self.labels_map['label_names'].to_dict().items()}
        uni_cluster_labels = self.adata.obs['vitae_init_clustering'].cat.categories
        cluster_labels = self.adata.obs['vitae_new_clustering'].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        # Center of each cluster in the 2-D embedding, indexed by vertex number.
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l],:] = np.mean(embed_z[cluster_labels==l], axis=0)

        if directed:
            graph = self.directed_backbone
        else:
            graph = self.backbone
        edges = list(graph.edges)
        edge_scores = np.array([d['weight'] for (u,v,d) in graph.edges(data=True)])
        # Rescale the edge weights for use as line widths; nothing to rescale
        # when the graph has no edges.
        if len(edge_scores) > 0:
            score_min, score_max = min(edge_scores), max(edge_scores)
            if score_max - score_min == 0:
                edge_scores = edge_scores/score_max
            else:
                edge_scores = (edge_scores - score_min)/(score_max - score_min)*3

        value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
        y_range = np.min(embed_z[:,1]), np.max(embed_z[:,1], axis=0)
        for edge, score in zip(edges, edge_scores):
            nodes = list(edge)
            # Cells with positive weight on both end vertices of this edge.
            points = embed_z[np.sum(self.cell_position_projected[:, nodes]>0, axis=-1)==2,:]
            points = points[points[:,0].argsort()]
            try:
                x_smooth, y_smooth = _get_smooth_curve(
                    points,
                    embed_mu[nodes, :],
                    y_range)
            except Exception:
                # Fall back to a straight segment between the two centers when
                # the smooth curve cannot be computed.
                x_smooth, y_smooth = embed_mu[nodes, 0], embed_mu[nodes, 1]
            # NOTE(review): line style, color, alpha and zorder of the edge were
            # lost in the damaged original; values follow the arrow styling below.
            ax.plot(x_smooth, y_smooth,
                '-',
                linewidth= 1 + score,
                color='black',
                alpha=0.8,
                path_effects=[pe.Stroke(linewidth=1 + score + 1.5,
                                        foreground='white'), pe.Normal()],
                zorder=1)

            if directed:
                delta_x = embed_mu[edge[1], 0] - x_smooth[-2]
                delta_y = embed_mu[edge[1], 1] - y_smooth[-2]
                length = np.sqrt(delta_x**2 + delta_y**2) / 50 * value_range
                # Short arrow whose head ends at the center of the target vertex.
                ax.arrow(
                        embed_mu[edge[1], 0]-delta_x/length,
                        embed_mu[edge[1], 1]-delta_y/length,
                        delta_x/length,
                        delta_y/length,
                        color='black', alpha=1.0,
                        shape='full', lw=0, length_includes_head=True,
                        head_width=np.maximum(0.01*(1 + score), 0.03) * value_range,
                        zorder=2)

        colors = self._adata.uns['vitae_new_clustering_colors']
        for i,l in enumerate(uni_cluster_labels):
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l]+1,:].T,
                       c=[colors[i]], edgecolors='white', # linewidths=10,  norm=norm,
                       s=250, marker='*', label=l)

        plt.setp(ax, xticks=[], yticks=[])
        # Shrink the axes vertically to leave room for the legend below it.
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                            box.width, box.height * 0.9])
        if directed:
            ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05),
                fancybox=True, shadow=True, ncol=5)

        return ax

    def plot_center(self, color = "vitae_new_clustering", plot_legend = True, legend_add_index = True,
                    method: str = 'UMAP',ncol = 2,font_size = "medium",
                    add_egde = False, add_direct = False,**kwargs):
        '''Plot the center of each cluster in the latent space.

        Parameters
        ----------
        color : str, optional
            The clustering to plot, either "vitae_new_clustering" or
            "vitae_init_clustering". Default is "vitae_new_clustering".
        plot_legend : bool, optional
            Whether to plot the legend. Default is True.
        legend_add_index : bool, optional
            Whether to add the index of each cluster in the legend. Default is True.
        method : str, optional
            The dimension reduction method used for visualization. Default is 'UMAP'.
        ncol : int, optional
            The number of columns in the legend. Default is 2.
        font_size : str, optional
            The font size of the legend. Default is "medium".
        add_egde : bool, optional
            Whether to add the edges between the centers of clusters. Default is False.
        add_direct : bool, optional
            Whether to add the direction of the edges. Default is False.
        **kwargs :
            Extra key-value arguments forwarded to `self.visualize_latent`.

        Raises
        ----------
        ValueError
            If `color` is not one of the two supported clusterings.
        '''
        if color not in ["vitae_new_clustering","vitae_init_clustering"]:
            raise ValueError("Can only plot center of vitae_new_clustering or vitae_init_clustering")
        dict_label_num = {j: i for i, j in self.labels_map['label_names'].to_dict().items()}
        if legend_add_index:
            # Color the cells by vertex index so the index is printed on the data.
            self._adata.obs["index_"+color] = self._adata.obs[color].map(lambda x: dict_label_num[x])
            # NOTE(review): the tail of this call was lost in the damaged original;
            # `legend_fontsize=font_size` is assumed because `font_size` is
            # otherwise unused - confirm.
            ax = self.visualize_latent(method=method, color="index_" + color, show=False, legend_loc="on data",
                                       legend_fontsize=font_size, **kwargs)
            colors = self._adata.uns["index_" + color + '_colors']
        else:
            ax = self.visualize_latent(method=method, color = color, show=False,**kwargs)
            colors = self._adata.uns[color + '_colors']
        uni_cluster_labels = self.adata.obs[color].cat.categories
        cluster_labels = self.adata.obs[color].to_numpy()
        embed_z = self._adata.obsm[self.dict_method_scname[method]]
        # Center of each cluster in the 2-D embedding, indexed by vertex number.
        embed_mu = np.zeros((len(uni_cluster_labels), 2))
        for l in uni_cluster_labels:
            embed_mu[dict_label_num[l], :] = np.mean(embed_z[cluster_labels == l], axis=0)

        # Legend text "<vertex index> : <label name>", ordered by vertex index.
        leg = (self.labels_map.index.astype(str) + " : " + self.labels_map.label_names).values
        for i, l in enumerate(uni_cluster_labels):
            # Look the legend text up by vertex index so it always matches the
            # plotted center, whatever the order of the categories.
            ax.scatter(*embed_mu[dict_label_num[l]:dict_label_num[l] + 1, :].T,
                       c=[colors[i]], edgecolors='white', # linewidths=3,
                       s=250, marker='*', label=leg[dict_label_num[l]])
        if plot_legend:
            ax.legend(loc='center left', bbox_to_anchor=(1, 0.5), ncol=ncol, markerscale=0.8, frameon=False)
        plt.setp(ax, xticks=[], yticks=[])
        box = ax.get_position()
        ax.set_position([box.x0, box.y0 + box.height * 0.1,
                         box.width, box.height * 0.9])
        if add_egde:
            if add_direct:
                graph = self.directed_backbone
            else:
                graph = self.backbone
            edges = list(graph.edges)
            edge_scores = np.array([d['weight'] for (u, v, d) in graph.edges(data=True)])
            # Rescale the edge weights for use as line widths; nothing to
            # rescale when the graph has no edges.
            if len(edge_scores) > 0:
                score_min, score_max = min(edge_scores), max(edge_scores)
                if score_max - score_min == 0:
                    edge_scores = edge_scores / score_max
                else:
                    edge_scores = (edge_scores - score_min) / (score_max - score_min) * 3

            value_range = np.maximum(np.diff(ax.get_xlim())[0], np.diff(ax.get_ylim())[0])
            y_range = np.min(embed_z[:, 1]), np.max(embed_z[:, 1], axis=0)
            for edge, score in zip(edges, edge_scores):
                nodes = list(edge)
                # Cells with positive weight on both end vertices of this edge.
                points = embed_z[np.sum(self.cell_position_projected[:, nodes] > 0, axis=-1) == 2, :]
                points = points[points[:, 0].argsort()]
                try:
                    x_smooth, y_smooth = _get_smooth_curve(
                        points,
                        embed_mu[nodes, :],
                        y_range)
                except Exception:
                    # Fall back to a straight segment between the two centers.
                    x_smooth, y_smooth = embed_mu[nodes, 0], embed_mu[nodes, 1]
                # NOTE(review): line style, color, alpha and zorder of the edge
                # were lost in the damaged original; values follow the arrow
                # styling below.
                ax.plot(x_smooth, y_smooth,
                        '-',
                        linewidth=1 + score,
                        color='black',
                        alpha=0.8,
                        path_effects=[pe.Stroke(linewidth=1 + score + 1.5,
                                                foreground='white'), pe.Normal()],
                        zorder=1)

                if add_direct:
                    delta_x = embed_mu[edge[1], 0] - x_smooth[-2]
                    delta_y = embed_mu[edge[1], 1] - y_smooth[-2]
                    length = np.sqrt(delta_x ** 2 + delta_y ** 2) / 50 * value_range
                    # Short arrow whose head ends at the center of the target vertex.
                    ax.arrow(
                        embed_mu[edge[1], 0] - delta_x / length,
                        embed_mu[edge[1], 1] - delta_y / length,
                        delta_x / length,
                        delta_y / length,
                        color='black', alpha=1.0,
                        shape='full', lw=0, length_includes_head=True,
                        head_width=np.maximum(0.01 * (1 + score), 0.03) * value_range,
                        zorder=2)
        # NOTE(review): the damaged original assigned `ax` to a name that could
        # not be recovered right before returning; restore that assignment here
        # once the target is confirmed.
        return None

    def infer_trajectory(self, root: Union[int,str], digraph = None, color = "pseudotime",
                         visualize: bool = True, path_to_fig = None,  **kwargs):
        '''Infer the trajectory and the pseudotime of cells starting from a root vertex.

        Parameters
        ----------
        root : int or string
            The root of the inferred trajectory. Can provide either an int (vertex index) or string (label name)
        digraph : nx.DiGraph, optional
            The directed graph to be used for trajectory inference. If None, a directed
            graph without loops is derived from the estimated trajectory backbone by a
            depth-first search from the root.
        color : str, optional
            The annotation used to color cells when plotting. The default is "pseudotime".
        visualize: boolean
            Whether plot the current trajectory backbone (directed graph)
        path_to_fig : string, optional  
            The path to save figure, or don't save if it is None.
        **kwargs : dict, optional
            Other keywords arguments for plotting.

        Raises
        ----------
        ValueError
            If `root` is a label name that does not exist, or if `digraph` is not a
            directed acyclic graph over the same nodes as `self.backbone`.
        '''
        if isinstance(root,str):
            if root not in self.labels_map['label_names'].values:
                raise ValueError("Root {} is not in the label names!".format(root))
            root = self.labels_map[self.labels_map['label_names']==root].index[0]

        if digraph is None:
            connected_comps = nx.node_connected_component(self.backbone, root)
            subG = self.backbone.subgraph(connected_comps)

            ## generate directed backbone which contains no loops
            DG = nx.DiGraph(nx.to_directed(self.backbone))
            temp = DG.subgraph(connected_comps)
            # Within the root's component keep only the depth-first tree edges.
            DG.remove_edges_from(temp.edges - nx.dfs_edges(DG, root))
            self.directed_backbone = DG
        else:
            if not nx.is_directed_acyclic_graph(digraph):
                raise ValueError("The graph 'digraph' should be a directed acyclic graph.")
            if set(digraph.nodes) != set(self.backbone.nodes):
                raise ValueError("The nodes in 'digraph' do not match the nodes in 'self.backbone'.")
            self.directed_backbone = digraph

            # node_connected_component is not implemented for directed graphs,
            # so look the component up on an undirected copy.
            connected_comps = nx.node_connected_component(digraph.to_undirected(), root)
            subG = self.backbone.subgraph(connected_comps)

        if len(subG.edges)>0:
            milestone_net = self.inferer.build_milestone_net(subG, root)
            if self.inferer.no_loop is False and milestone_net.shape[0]<len(self.backbone.edges):
                warnings.warn("The directed graph shown is a minimum spanning tree of the estimated trajectory backbone to avoid arbitrary assignment of the directions.")
            self.pseudotime = self.inferer.comp_pseudotime(milestone_net, root, self.cell_position_projected)
        else:
            warnings.warn("There are no connected states for starting from the giving root.")
            # -1 marks cells without a pseudotime.
            self.pseudotime = -np.ones(self._adata.shape[0])

        self.adata.obs['pseudotime'] = self.pseudotime
        print("Pseudotime stored as 'pseudotime' in self.adata.obs")

        if visualize:
            self._adata.obs['pseudotime'] = self.pseudotime
            ax = self.plot_backbone(directed = True, color = color, **kwargs)
            if path_to_fig is not None:
                plt.savefig(path_to_fig)
            # NOTE(review): the statements after the save were lost in the damaged
            # original; showing the figure is assumed - confirm.
            plt.show()

        return None

    def differential_expression_test(self, alpha: float = 0.05, cell_subset = None, order: int = 1):
        '''Differentially gene expression test. All (selected and unselected) genes will be tested 
        Only cells in `cell_subset` will be used, which is useful when one need to
        test differentially expressed genes on a branch of the inferred trajectory.

        Parameters
        ----------
        alpha : float, optional
            The cutoff of p-values.
        cell_subset : np.array, optional
            The subset of cells to be used for testing. If None, all cells will be used.
        order : int, optional
            The maxium order we used for pseudotime in regression.

        Returns
        ----------
        res_df : pandas.DataFrame
            The test results of expressed genes with two columns,
            the estimated coefficients and the adjusted p-values.

        Raises
        ----------
        ReferenceError
            If pseudotime has not been inferred yet.
        ValueError
            If `order` is smaller than 1.
        '''
        if not hasattr(self, 'pseudotime'):
            raise ReferenceError("Pseudotime does not exist! Please run 'infer_trajectory' first.")
        if cell_subset is None:
            cell_subset = np.arange(self.X_input.shape[0])
            print("All cells are selected.")
        if order < 1:
            raise  ValueError("Maximal order of pseudotime in regression must be at least 1.")

        # Prepare X and Y for regression expression ~ rank(PDT) + covariates
        Y = self.adata.X[cell_subset,:]
        pdt_rank = stats.rankdata(self.pseudotime[cell_subset])
        # One column per power of the pseudotime rank: rank, rank^2, ..., rank^order.
        # Each power is taken from the rank itself, so every order appears exactly once.
        X = np.column_stack([pdt_rank**_order for _order in range(1, order+1)])
        # Standardize every polynomial column, then prepend the intercept.
        X = ((X-np.mean(X,axis=0, keepdims=True))/np.std(X, ddof=1, axis=0, keepdims=True))
        X = np.c_[np.ones((X.shape[0],1)), X]
        if self.covariates is not None:
            X = np.c_[X, self.covariates[cell_subset, :]]

        # Columns 1..order of X hold the pseudotime terms under test.
        res_df = DE_test(Y, X, self.adata.var_names, i_test = np.arange(1, order+1), alpha = alpha)
        return res_df[res_df.pvalue_adjusted_1 != 0]


    def evaluate(self, milestone_net, begin_node_true, grouping = None,
                thres: float = 0.5, no_loop: bool = True, cutoff: Optional[float] = None,
                method: str = 'mean', path: Optional[str] = None):
        ''' Evaluate the model.

        Parameters
        ----------
        milestone_net : pd.DataFrame
            The true milestone network. For real data, milestone_net will be a DataFrame of the graph of nodes.
            Eg.

            from|to
            ---|---
            cluster 1 | cluster 1
            cluster 1 | cluster 2

            For synthetic data, milestone_net will be a DataFrame of the (projected)
            positions of cells. The indexes are the orders of cells in the dataset.
            Eg.

            from|to|w
            ---|---|---
            cluster 1 | cluster 1 | 1
            cluster 1 | cluster 2 | 0.1

            The DataFrame passed in is not modified.
        begin_node_true : str or int
            The true begin node of the milestone.
        grouping : np.array, optional
            \([N,]\) The labels. For real data, grouping must be provided.
        thres, cutoff, method, path : optional
            Kept for backward compatibility; they do not influence the result.
        no_loop : boolean, optional
            Whether to compute the pseudotime score. If False, 'PDT score' is NaN.

        Returns
        ----------
        res : pd.DataFrame
            The evaluation result.

        Raises
        ----------
        ValueError
            If the model has no labels, or if neither `milestone_net` nor
            `grouping` is provided.
        '''
        if not hasattr(self, 'labels_map'):
            raise ValueError("No given labels for training.")
        if milestone_net is None and grouping is None:
            raise ValueError("Either 'milestone_net' or 'grouping' must be provided.")

        # Evaluate for the whole dataset will ignore selected_cell_subset.
        if len(self.selected_cell_subset)!=len(self.cell_names):
            warnings.warn("Evaluate for the whole dataset.")

        # this dict is for milestone net cause their labels are not merged
        # all keys of label_map_dict are str
        label_map_dict = dict()
        for i in range(self.labels_map.shape[0]):
            label_mapped = self.labels_map.loc[i]
            ## merged cluster index is connected by comma
            for each in label_mapped.values[0].split(","):
                label_map_dict[each] = i
        if isinstance(begin_node_true, str):
            begin_node_true = label_map_dict[begin_node_true]

        if milestone_net is not None:
            # Work on a copy so the caller's DataFrame keeps its original labels.
            milestone_net = milestone_net.copy()
            # For generated data, grouping information is already in milestone_net
            if 'w' in milestone_net.columns:
                grouping = None
            # Transform the vertex labels to be numeric.
            milestone_net['from'] = [label_map_dict[x] for x in milestone_net["from"]]
            milestone_net['to'] = [label_map_dict[x] for x in milestone_net["to"]]

        # this dict is for potentially merged clusters.
        label_map_dict_for_merged_cluster = dict(zip(self.labels_map["label_names"],self.labels_map.index))
        mapped_labels = np.array([label_map_dict_for_merged_cluster[x] for x in self.labels])
        # The predicted begin node is the vertex center closest (in mean squared
        # distance) to the cells of the true begin node.
        # NOTE(review): the array subtracted here was lost in the damaged original;
        # `self.mu` (the vertex centers) is assumed - confirm.
        begin_node_pred = int(np.argmin(np.mean((
            self.z[mapped_labels==begin_node_true,:,np.newaxis] -
            self.mu[np.newaxis,:,:])**2, axis=(0,1))))

        G = self.backbone
        w = self.cell_position_projected
        pseudotime = self.pseudotime

        # 1. Topology
        # NOTE(review): the graph-construction calls were lost in the damaged
        # original; both graphs are assumed to share the nodes of the backbone.
        G_pred = nx.Graph()
        G_pred.add_nodes_from(G.nodes)
        G_pred.add_edges_from(G.edges)
        nx.set_node_attributes(G_pred, False, 'is_init')
        G_pred.nodes[begin_node_pred]['is_init'] = True

        G_true = nx.Graph()
        G_true.add_nodes_from(G.nodes)
        # if 'grouping' is not provided, assume 'milestone_net' contains proportions
        if grouping is None:
            G_true.add_edges_from(list(
                milestone_net[~pd.isna(milestone_net['w'])].groupby(['from', 'to']).count().index))
        # otherwise, 'milestone_net' indicates edges
        else:
            if milestone_net is not None:
                G_true.add_edges_from(list(
                    milestone_net.groupby(['from', 'to']).count().index))
            grouping = [label_map_dict[x] for x in grouping]
            grouping = np.array(grouping)
        # NOTE(review): rows with from == to describe cells sitting on a vertex,
        # not edges, so the self-loops are dropped - confirm against the original.
        G_true.remove_edges_from(list(nx.selfloop_edges(G_true)))
        nx.set_node_attributes(G_true, False, 'is_init')
        G_true.nodes[begin_node_true]['is_init'] = True
        res = topology(G_true, G_pred)

        # 2. Milestones assignment
        if grouping is None:
            # Assign each cell to the closer end of its edge.
            milestones_true = milestone_net['from'].values.copy()
            closer_to_target = ((milestone_net['from']!=milestone_net['to'])
                                &(milestone_net['w']<0.5))
            milestones_true[closer_to_target] = milestone_net[closer_to_target]['to'].values
        else:
            milestones_true = grouping
        milestones_pred = np.argmax(w, axis=1)
        res['ARI'] = (adjusted_rand_score(milestones_true, milestones_pred) + 1)/2

        if grouping is None:
            n_samples = len(milestone_net)
            # NOTE(review): the second dimension is n_samples as in the original;
            # presumably it only needs to cover the number of vertices - confirm.
            prop = np.zeros((n_samples,n_samples))
            prop[np.arange(n_samples), milestone_net['to']] = 1-milestone_net['w']
            prop[np.arange(n_samples), milestone_net['from']] = np.where(np.isnan(milestone_net['w']), 1, milestone_net['w'])
            res['GRI'] = get_GRI(prop, w)
        else:
            res['GRI'] = get_GRI(grouping, w)

        # 3. Correlation between geodesic distances / Pseudotime
        if no_loop:
            if grouping is None:
                pseudotime_true = milestone_net['from'].values + 1 - milestone_net['w'].values
                pseudotime_true[np.isnan(pseudotime_true)] = milestone_net[pd.isna(milestone_net['w'])]['from'].values
            else:
                pseudotime_true = - np.ones(len(grouping))
                nx.set_edge_attributes(G_true, values = 1, name = 'weight')
                connected_comps = nx.node_connected_component(G_true, begin_node_true)
                subG = G_true.subgraph(connected_comps)
                milestone_net_true = self.inferer.build_milestone_net(subG, begin_node_true)
                if len(milestone_net_true)>0:
                    pseudotime_true[grouping==int(milestone_net_true[0,0])] = 0
                    for i in range(len(milestone_net_true)):
                        pseudotime_true[grouping==int(milestone_net_true[i,1])] = milestone_net_true[i,-1]
            # Only compare cells that received a predicted pseudotime.
            pseudotime_true = pseudotime_true[pseudotime>-1]
            pseudotime_pred = pseudotime[pseudotime>-1]
            res['PDT score'] = (np.corrcoef(pseudotime_true,pseudotime_pred)[0,1]+1)/2
        else:
            res['PDT score'] = np.nan

        # 4. Shape: the cosine-based shape score is currently disabled.
        return res

    def save_model(self, path_to_file: str = 'model.checkpoint',save_adata: bool = False):
        '''Saving model weights.

        Besides the weights, the following side files are written next to
        `path_to_file`: '.label' (clustering labels, if any), '.config' (model
        dimensions and type), '.inference' (inference results, if available)
        and, on request, '.adata.h5ad'.

        Parameters
        ----------
        path_to_file : str, optional
            The path to weight files of pre-trained or trained model
        save_adata : boolean, optional
            Whether to save adata or not.
        '''
        # NOTE(review): the weight-saving call was lost in the damaged original;
        # the Keras-style `save_weights` on `self.vae` is assumed - confirm.
        self.vae.save_weights(path_to_file)

        if hasattr(self, 'labels') and self.labels is not None:
            with open(path_to_file + '.label', 'wb') as f:
                np.save(f, self.labels)

        self.dim_origin = self.X_input.shape[1]
        n_covariates = 0 if self.covariates is None else self.covariates.shape[1]
        with open(path_to_file + '.config', 'wb') as f:
            # Stored as an object array because the entries have mixed types.
            np.save(f, np.array([
                self.dim_origin, self.dimensions, self.dim_latent,
                self.model_type, n_covariates], dtype=object))

        if hasattr(self, 'inferer') and hasattr(self, 'uncertainty'):
            with open(path_to_file + '.inference', 'wb') as f:
                # The order of the entries must match the unpacking in load_model.
                np.save(f, np.array([
                    self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                    self.z,self.cell_position_variance], dtype=object))

        if save_adata:
            self.adata.write(path_to_file + '.adata.h5ad')

    def load_model(self, path_to_file: str = 'model.checkpoint', load_labels: bool = False, load_adata: bool = False):
        '''Load model weights.

        Parameters
        ----------
        path_to_file : str, optional
            The path to weight files of pre trained or trained model
        load_labels : boolean, optional
            Whether to load clustering labels or not.
            If load_labels is True, then the LatentSpace layer will be initialized basd on the model.
            If load_labels is False, then the LatentSpace layer will not be initialized.
        load_adata : boolean, optional
            Whether to load adata or not.

        Raises
        ----------
        AssertionError
            If a required file ('.config', '.label' or '.adata.h5ad') does not exist.
        '''
        if not os.path.exists(path_to_file + '.config'):
            raise AssertionError('Config file not exist!')
        if load_labels and not os.path.exists(path_to_file + '.label'):
            raise AssertionError('Label file not exist!')

        # NOTE: the side files are object arrays and need allow_pickle=True, which
        # can execute arbitrary code; only load checkpoints from trusted sources.
        with open(path_to_file + '.config', 'rb') as f:
            [self.dim_origin, self.dimensions,
             self.dim_latent, self.model_type, cov_dim] = np.load(f, allow_pickle=True)
        self.vae = model.VariationalAutoEncoder(
            self.dim_origin, self.dimensions,
            self.dim_latent, self.model_type, False if cov_dim == 0 else True
        )

        if load_labels:
            with open(path_to_file + '.label', 'rb') as f:
                cluster_labels = np.load(f, allow_pickle=True)
            self.init_latent_space(cluster_labels, dist_thres=0)
            if os.path.exists(path_to_file + '.inference'):
                with open(path_to_file + '.inference', 'rb') as f:
                    arr = np.load(f, allow_pickle=True)
                # Inference files with 8 entries additionally hold D_JS.
                if len(arr) == 8:
                    [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                     self.D_JS, self.z,self.cell_position_variance] = arr
                else:
                    [self.pi, self.mu, self.pc_x, self.cell_position_posterior, self.uncertainty,
                     self.z,self.cell_position_variance] = arr
                self._adata_z = sc.AnnData(self.z)
        ## initialize the weight of encoder and decoder
        self.vae.encoder(np.zeros((1, self.dim_origin + cov_dim)))
        self.vae.decoder(np.expand_dims(np.zeros((1,self.dim_latent + cov_dim)),1))

        # NOTE(review): the weight-loading call was lost in the damaged original;
        # the Keras-style `load_weights` on `self.vae` is assumed - confirm. The
        # latent representation `self.z` is only restored from the '.inference'
        # file above; recompute it after loading if that file is absent.
        self.vae.load_weights(path_to_file)

        if load_adata:
            if not os.path.exists(path_to_file + '.adata.h5ad'):
                raise AssertionError('AnnData file not exist!')
            self.adata = sc.read_h5ad(path_to_file + '.adata.h5ad')
            self._adata.obs = self.adata.obs.copy()


