# Source code for maboss.results.probtrajresult

"""
Class that contains the results of a MaBoSS simulation.
"""

import pandas as pd
import numpy as np

class ProbTrajResult(object):
    
    def __init__(self, output_nodes=None):

        self.output_nodes = output_nodes    
        
        self.state_probtraj = None
        self.state_probtraj_errors = None
        self.state_probtraj_full = None
        
        self.nd_probtraj = None
        self.nd_probtraj_error = None
        
        self.entropy_probtraj = None
        self.entropy_probtraj_error = None
        
        self.last_states_probtraj = None
        self.last_nodes_probtraj = None
        
        self._raw_data = None
        self._first_state_index = None

        self._raw_states = None
        self._raw_probas = None
        self._raw_errors = None
        self._raw_entropy = None
        self._raw_last_data = None

        self.indexes = None
        self.states = None
        self.nodes = None
        self.states_indexes = None
        self.nodes_indexes = None

    def get_states_probtraj(self, prob_cutoff=None):
        """
            Returns the state probability vs time, as a pandas dataframe.

            :param float prob_cutoff: returns only the states with proba > cutoff
        """
        if self.state_probtraj is None:

            raw_states = self._get_raw_states()
            raw_probas = self._get_raw_probas()
            indexes, states = self._get_indexes()
            states_indexes = self._get_states_indexes()
            
            new_data = np.zeros((len(raw_probas), len(states)))
            for i, t_probas in enumerate(raw_probas):
                for j, proba in enumerate(t_probas):
                    new_data[i, states_indexes[raw_states[i][j]]] = proba    
            
            self.state_probtraj = pd.DataFrame(
                data=new_data,
                columns=states,
                index=indexes

            )
            self.state_probtraj.sort_index(axis=1, inplace=True)

        if prob_cutoff is not None:
            maxs = self.state_probtraj.max(axis=0)
            return self.state_probtraj[maxs[maxs>prob_cutoff].index]

        return self.state_probtraj

    def get_states_probtraj_errors(self):
        """
            Returns the state probability error vs time, as a pandas dataframe.
        """
        if self.state_probtraj_errors is None:
            
            raw_states = self._get_raw_states()
            raw_errors = self._get_raw_errors()
            indexes, states = self._get_indexes()
            states_indexes = self._get_states_indexes()

            new_data = np.zeros((len(raw_errors), len(states)))
            for i, t_errors in enumerate(raw_errors):
                for j, error in enumerate(t_errors):
                    new_data[i, states_indexes[raw_states[i][j]]] = error    
            
            self.state_probtraj_errors = pd.DataFrame(
                data=new_data,
                columns=states,
                index=indexes

            )
            self.state_probtraj_errors.sort_index(axis=1, inplace=True)

        return self.state_probtraj_errors

    def get_nodes_probtraj(self, prob_cutoff=None):
        """
            Returns the node probability vs time, as a pandas dataframe.

            :param float prob_cutoff: returns only the nodes with proba > cutoff
        """
        if self.nd_probtraj is None:

            raw_states = self._get_raw_states()
            raw_probas = self._get_raw_probas()
            indexes, states = self._get_indexes()
            nodes = self.output_nodes if self.output_nodes is not None else self._get_nodes()
            nodes_indexes = self._get_nodes_indexes()

            new_probs = np.zeros((len(indexes), len(nodes)))
            for i, t_probas in enumerate(raw_probas):
                for j, proba in enumerate(t_probas):
                    if raw_states[i][j] != "<nil>":
                        for node in raw_states[i][j].split(" -- "):
                            new_probs[i, nodes_indexes[node]] += proba

            self.nd_probtraj = pd.DataFrame(new_probs, columns=nodes, index=indexes)
            self.nd_probtraj.sort_index(axis=1, inplace=True)

        if prob_cutoff is not None:
            maxs = self.nd_probtraj.max(axis=0)
            return self.nd_probtraj[maxs[maxs>prob_cutoff].index]

        return self.nd_probtraj

    def get_nodes_probtraj_error(self):
        """
            Returns the node probability error vs time, as a pandas dataframe.
        """
        if self.nd_probtraj_error is None:

            raw_states = self._get_raw_states()
            raw_errors = self._get_raw_errors()
            indexes, states = self._get_indexes()
            nodes = self.output_nodes if self.output_nodes is not None else self._get_nodes()
            nodes_indexes = self._get_nodes_indexes()
            
            new_errors = np.zeros((len(indexes), len(nodes)))
            for i, t_raw_errors in enumerate(raw_errors):
                for j, error in enumerate(t_raw_errors):
                    if raw_states[i][j] != "<nil>":
                        for node in raw_states[i][j].split(" -- "):
                            new_errors[i, nodes_indexes[node]] += error

            self.nd_probtraj_error = pd.DataFrame(new_errors, columns=nodes, index=indexes)
            self.nd_probtraj_error.sort_index(axis=1, inplace=True)

        return self.nd_probtraj_error

    def get_states_probtraj_full(self, prob_cutoff=None):
        if self.state_probtraj_full is None:

            raw_states = self._get_raw_states()
            raw_probas = self._get_raw_probas()
            raw_errors = self._get_raw_errors()
            raw_entropy = self._get_raw_entropy()
            indexes, states = self._get_indexes()
            states_indexes = self._get_states_indexes()

            full_cols = ["TH", "ErrorTH", "H"]
            for col in states:
                full_cols.append("Prob[%s]" % col)
                full_cols.append("ErrProb[%s]" % col)

            new_data = np.zeros((len(indexes), len(full_cols)))

            for i, t_entropy in enumerate(raw_entropy):
                new_data[i, 0:3] = t_entropy

            for i, t_probas in enumerate(raw_probas):
                for j, proba in enumerate(t_probas):
                    new_data[i, 3+(states_indexes[raw_states[i][j]]*2)] = proba

            for i, t_errors in enumerate(raw_errors):
                for j, error in enumerate(t_errors):
                    new_data[i, 4+(states_indexes[raw_states[i][j]]*2)] = error

            self.state_probtraj_full = pd.DataFrame(new_data, columns=full_cols, index=indexes)

        if prob_cutoff is not None:
            maxs = self.state_probtraj_full.max(axis=0)
            cols = ["TH", "ErrorTH", "H"]

            for state in maxs[maxs > prob_cutoff].index:
                if state.startswith("Prob["):
                    cols.append(state)
                    cols.append("Err%s" % state)

            return self.state_probtraj_full[cols]            
      
        return self.state_probtraj_full

    def get_last_states_probtraj(self):
        """
            Returns the asymptotic state probability, as a pandas dataframe.
        """
        if self.last_states_probtraj is None:
            
            data, first_col = self._get_raw_last_data()

            states = [s for s in data[first_col::3]]
            probs = np.array([float(v) for v in data[first_col+1::3]])
            
            self.last_states_probtraj = pd.DataFrame([probs], columns=states, index=[data[0]])
            self.last_states_probtraj.sort_index(axis=1, inplace=True)
             
        return self.last_states_probtraj

    def get_last_nodes_probtraj(self):
        """
            Returns the asymptotic node probability, as a pandas dataframe.
        """
        if self.last_nodes_probtraj is None:
            data, first_col = self._get_raw_last_data()
              
            raw_states = [s for s in data[first_col::3]]
            raw_probs = np.array([float(v) for v in data[first_col+1::3]])

            nodes = set()
            for state in raw_states:
                if state != "<nil>":
                    nodes.update([node for node in state.split(" -- ")])
            nodes = list(nodes) 
            nodes_indexes = {node:index for index, node in enumerate(nodes)}

            new_probas = np.zeros((1, len(nodes)))
            for i, proba in enumerate(raw_probs):
                if raw_states[i] != "<nil>":
                    for node in raw_states[i].split(" -- "):
                        new_probas[0, nodes_indexes[node]] += proba

            self.last_nodes_probtraj = pd.DataFrame(new_probas, columns=nodes, index=[data[0]])
            self.last_nodes_probtraj.sort_index(axis=1, inplace=True)

        return self.last_nodes_probtraj

    def get_entropy_trajectory(self):
        """
            Returns the entropy vs time, as a pandas dataframe.
        """
        if self.entropy_probtraj is None:

            raw_entropy = self._get_raw_entropy()
            indexes, _ = self._get_indexes()
            
            new_data = np.zeros((len(raw_entropy), 2))
            for i, entropy in enumerate(raw_entropy):
                new_data[i, 0] = entropy[0]
                new_data[i, 1] = entropy[2]
            
            self.entropy_probtraj = pd.DataFrame(
                data=new_data,
                columns=["TH", "H"],
                index=indexes
            )

        return self.entropy_probtraj

    def get_entropy_trajectory_error(self):
        """
            Returns the entropy error vs time, as a pandas dataframe.
        """
        if self.entropy_probtraj_error is None:

            raw_entropy = self._get_raw_entropy()
            indexes, _ = self._get_indexes()
            
            new_data = np.zeros((len(raw_entropy), 2))
            for i, entropy in enumerate(raw_entropy):
                new_data[i, 0] = entropy[1]
                new_data[i, 1] = entropy[2]
            
            self.entropy_probtraj_error = pd.DataFrame(
                data=new_data,
                columns=["ErrorTH", "H"],
                index=indexes
            )

        return self.entropy_probtraj_error

    def _get_probtraj_fd(self):
        return open(self.get_probtraj_file(), 'r')

    def _get_raw_data(self):
    
        if self._raw_data is None:
            with self._get_probtraj_fd() as probtraj:

                raw_lines = probtraj.readlines()

                if self._first_state_index is None:                    
                    self._first_state_index = next(i for i, col in enumerate(raw_lines[0].strip("\n").split("\t")) if col == "State")

                self._raw_data = [line.strip("\n").split("\t") for line in raw_lines[1:]]
        
        return self._raw_data, self._first_state_index
        
    def _get_raw_last_data(self):
    
        if self._raw_last_data is None:
            with self._get_probtraj_fd() as probtraj:

                if self._first_state_index is None:
                    first_line = probtraj.readline()
                    self._first_state_index = next(i for i, col in enumerate(first_line.strip("\n").split("\t")) if col == "State")

                last_line = probtraj.readlines()[-1]
                self._raw_last_data = last_line.strip("\n").split("\t")
        
        return self._raw_last_data, self._first_state_index

    def _get_raw_states(self):
        if self._raw_states is None:
            data, first_state_index = self._get_raw_data()
            self._raw_states = [[s for s in t_data[first_state_index::3]] for t_data in data]
        return self._raw_states

    def _get_raw_probas(self):
        if self._raw_probas is None:
            data, first_state_index = self._get_raw_data()
            self._raw_probas = [[np.float64(p) for p in t_data[first_state_index+1::3]] for t_data in data]
        return self._raw_probas

    def _get_raw_errors(self):
        if self._raw_errors is None:
            data, first_state_index = self._get_raw_data()
            self._raw_errors = [[np.float64(p) for p in t_data[first_state_index+2::3]] for t_data in data]
        return self._raw_errors
        
    def _get_raw_entropy(self):
        if self._raw_entropy is None:
            data, _ = self._get_raw_data()
            self._raw_entropy = [[np.float64(p) for p in t_data[1:4]] for t_data in data]
        return self._raw_entropy

    def _get_indexes(self):

        if self.indexes is None:
            data, _ = self._get_raw_data()
            self.indexes = [float(t_data[0]) for t_data in data] 
            
        if self.states is None:
            self.states = set()
            for t_states in self._get_raw_states():
                self.states.update(t_states)
            self.states = list(self.states)

        return self.indexes, self.states
    
    def _get_nodes(self):
        
        if self.nodes is None:
            self.nodes = set()
            for t_states in self._get_raw_states():
                for state in t_states:
                    if state != "<nil>":
                        self.nodes.update([node for node in state.split(" -- ")])
            self.nodes = list(self.nodes)
        
        return self.nodes

    def _get_states_indexes(self):
        if self.states_indexes is None:
            _, states = self._get_indexes()
            self.states_indexes = {state:index for index, state in enumerate(states)}
        return self.states_indexes

    def _get_nodes_indexes(self):
        if self.nodes_indexes is None:
            nodes = self.output_nodes if self.output_nodes is not None else self._get_nodes()
            self.nodes_indexes = {node:index for index, node in enumerate(nodes)}
        return self.nodes_indexes