Source code for congas.models.HmmSimple

import pyro
import pyro.distributions as dist
import torch
from congas.models.Model import Model
from pyro.ops.indexing import Vindex
from pyro import poutine
from pyro.infer.autoguide import AutoDelta



[docs]class HmmSimple(Model): """ Simple Hmm, models the CNV event as a Categorical variable. It does not cluster the data Model parameters: T = max number of clusters (default = 6) init_probs = prior probs for initial state CNV probabilities (default=torch.tensor([0.1,0.1,0.2,0.3,0.2,0.1])) hidden_dim = hidden dimensions (should be len(probs)) theta_scale = scale for the normalization factor variable (default = 3) theta_rate = rate for the normalization factor variable (default = 1) batch_size = batch size (default = None) t = probability of remaining in the same state (default=0.1) """ params = {'init_probs': torch.tensor([0.1, 0.1, 0.2, 0.3, 0.2, 0.1]), 'hidden_dim': 6, 'theta_scale': 9, 'theta_rate': 3, 'batch_size': None , 't': 0.1} data_name = set(['data', 'mu', 'pld', 'segments']) def __init__(self, data_dict): self._params = self.params.copy() self._data = None super().__init__(data_dict, self.data_name)
[docs] def model(self, *args, **kwargs): I, N = self._data['data'].shape batch = N if self._params['batch_size'] else self._params['batch_size'] probs_z = pyro.sample("cnv_probs", dist.Dirichlet(self._params['t'] * torch.eye(self._params['hidden_dim']) + (1- self._params['t'])) .to_event(1)) pi = pyro.sample("pi", dist.Dirichlet(self._params['init_probs'])) z = pyro.sample("init_state", dist.Categorical(pi), infer={"enumerate": "parallel"}) with pyro.plate("data2", N, batch): theta = pyro.sample('norm_factor', dist.Gamma(self._params['theta_scale'], self._params['theta_rate'])) for i in pyro.markov(range(I)): z = pyro.sample("z_{}".format(i), dist.Categorical(Vindex(probs_z)[z]), infer={"enumerate": "parallel"}) with pyro.plate('data_{}'.format(i), N, batch): pyro.sample('obs_{}'.format(i), dist.Poisson((z * theta * self._data['mu'][i]) + 1e-8), obs=self._data['data'][i,:])
[docs] def guide(self,MAP = False,*args, **kwargs): return AutoDelta(poutine.block(self.model, expose=['pi', 'norm_factor', 'cnv_probs']), init_loc_fn=self.init_fn())
[docs] def init_fn(self): def init_function(site): I, N = self._data['data'].shape if site["name"] == "cnv_probs": return (self._params['t'] * torch.eye(self._params['hidden_dim']) + (1- self._params['t'])) if site["name"] == "norm_factor": return torch.mean(self._data['data'] / (2 * self._data['mu'].reshape(self._data['data'].shape[0],1)), axis=0) if site["name"] == "pi": return self._params['init_probs'] raise ValueError(site["name"]) return init_function