-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathexample_SI_ABC.m
More file actions
166 lines (130 loc) · 6.01 KB
/
example_SI_ABC.m
File metadata and controls
166 lines (130 loc) · 6.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
%% SSIT/Examples/example_SI_ABC.m
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Approximate Bayesian Computation (ABC)
% in the SSIT using `runABCsearch'
%
% This example:
% 1. Loads a template model for scRNA-seq genes with SSA solution scheme.
% 2. Associates the template model with scRNA-seq data for gene TSC22D3.
% 3. Defines a prior over parameters.
% 4. Runs ABC via Metropolis–Hastings using 'cdf_one_norm' loss.
% 5. Visualizes the ABC results.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Configure the SSA solver
% Start from the scRNA-seq template model. Model_Template is not created in
% this script, so it must already be available when this section runs.
scRNAseq = Model_Template;
% Select the SSA solution scheme:
scRNAseq.solutionScheme = 'SSA';
% Simulations per experiment (kept small so the demo runs quickly):
scRNAseq.ssaOptions.nSimsPerExpt = 100;
% Prepend t = -100 to the time span so the simulation can equilibrate
% (burn in) before the original time points:
scRNAseq.tSpan = horzcat(-100, scRNAseq.tSpan);
% Spread the SSA runs over multiple cores:
scRNAseq.ssaOptions.useParallel = true;
%% Attach the scRNA-seq measurements for gene TSC22D3
% loadData returns the model with the data attached; the second argument
% pairs 'rna' with 'TSC22D3' (presumably model species -> data column;
% verify against loadData).
scRNAseq = scRNAseq.loadData( ...
    'data/Raw_DEX_UpRegulatedGenes_ForSSIT.csv', {'rna','TSC22D3'});
% Indices of the model parameters that will be fitted:
fitpars = 1:9;
scRNAseq.fittingOptions.modelVarsToFit = fitpars;
%% Prior over parameters (logPriorLoss)
% logPriorLoss is expected to return a *loss* (a positive penalty; smaller
% is better). A quadratic penalty in log10-parameter space, which
% corresponds to a log-normal prior, is a convenient choice.
%
% Current values of the fitted parameters (second column of the model's
% parameter cell array):
theta0 = cell2mat(scRNAseq.parameters(fitpars,2));
% Prior mean in log10-space, forced to a column vector:
log10_mu = log10(reshape(theta0, [], 1));
% Prior standard deviation in log10-space: 2 for every fitted parameter.
log10_sigma = repmat(2, size(log10_mu));
% Leave the prior loss empty to use the default,
% @(x)allFitOptions.obj(exp(x)):
logPriorLoss = [];
%% ABC / MCMC settings
% runABCsearch forwards 'fitOptions' to maximizeLikelihood, which is run
% with the 'MetropolisHastings' algorithm. Tune these values to the size of
% your problem.
fitOptions = struct( ...
    'numberOfSamples', 500, ...   total number of MH iterations
    'burnIn',          10,  ...   burn-in samples to discard
    'thin',            1);      % keep every nth sample
% Width of the random-walk proposal:
proposalWidthScale = 0.5;
% Proposal distribution: Gaussian perturbation of the current point.
fitOptions.proposalDistribution = @(x) x + proposalWidthScale*randn(size(x));
% Log prior: quadratic penalty around log10_mu with spread log10_sigma.
% NOTE(review): log10_mu is in log10 units -- confirm that the sampler
% passes log10(parameters) to logPrior (rather than natural-log or linear
% values).
fitOptions.logPrior = @(logPars) -sum((logPars-log10_mu).^2./(2*log10_sigma.^2));
% Starting point for the search. Passing parGuess = [] would fall back to
% the current Model.parameters, which is what is spelled out here:
parGuess = cell2mat(scRNAseq.parameters(fitpars,2));
% Loss function used by ABC ('cdf_one_norm' is also the default):
lossFunction = 'cdf_one_norm';
% Downsample the SSA trajectories to enforce independence:
enforceIndependence = true;
%% Run the ABC search
% The search will:
%   * repeatedly simulate SSA trajectories,
%   * compute a CDF-based loss against the data,
%   * add the prior penalty, and
%   * perform MH sampling to approximate the posterior.
%
% Values returned by runABCsearch, in order:
%   parsABC     - "best" (minimum-loss) parameter set found
%   minimumLoss - value of the loss at that point
%   ResultsABC  - MH/ABC diagnostics and chains
%   scRNAseq    - the model updated with the best parameters (overwrites
%                 the input model variable)
%
% Compile and store the reaction propensities before simulating:
scRNAseq = scRNAseq.formPropensitiesGeneral('scRNAseq');
[parsABC, minimumLoss, ResultsABC, scRNAseq] = scRNAseq.runABCsearch( ...
    parGuess, lossFunction, logPriorLoss, fitOptions, enforceIndependence);
% Report the outcome; parameters are printed as a single row.
fprintf('ABC completed.\n');
fprintf('Minimum loss value: %g\n', minimumLoss);
disp('Best-fit parameters (ABC):');
disp(reshape(parsABC, 1, []));
%% Inspect ABC results:
% The 'ResultsABC' struct is returned by maximizeLikelihood with the
% 'MetropolisHastings' algorithm.
%   ResultsABC.mhSamples    - MCMC chain of parameter samples
%   ResultsABC.mhValue      - corresponding loss values
%   ResultsABC.mhAcceptance - MH acceptance fraction
% Below we show a simple marginal histogram for each fitted parameter, with
% vertical lines for the initial guess (blue), the ABC best fit (red) and,
% when available, the reference model Model_TSC22D3 (green).
if isfield(ResultsABC, 'mhSamples')
    parChain = ResultsABC.mhSamples;   % size: [numberOfSamples x nPars]
    nPars = size(parChain, 2);
    % Model_TSC22D3 is not created by this script; only draw its reference
    % line if it is already in the workspace, instead of failing with an
    % undefined-variable error on the first subplot.
    hasReferenceModel = (exist('Model_TSC22D3', 'var') == 1);
    figure;
    for k = 1:nPars
        subplot(ceil(nPars/2), 2, k);
        histogram(parChain(:,k), 40, 'Normalization', 'pdf');
        hold on;
        % NOTE(review): the lines below assume the chain is stored in the
        % same units as parGuess/parsABC -- TODO confirm.
        xline(parGuess(k), 'b', 'LineWidth', 1.5);
        xline(parsABC(k), 'r', 'LineWidth', 1.5);
        if hasReferenceModel
            % Column k of the chain is fitted parameter fitpars(k), so look
            % up the reference value through fitpars (as parGuess does).
            xline(cell2mat(Model_TSC22D3.parameters(fitpars(k),2)), ...
                'g', 'LineWidth', 1.5);
        end
        title(sprintf('Parameter %d', k));
        xlabel('\theta_k');
        ylabel('Posterior density (approx.)');
    end
    sgtitle('ABC posterior marginals (approximate)');
else
    warning('ResultsABC.mhSamples not found.');
end
%% Compare initial vs. final (ABC) parameter losses
% (experimental data vs. data simulated by the SSA)
%
% Normalisation constants for the average loss:
nConds = 1;     % replicate groups / dose groups etc., if applicable
nSpecies = 1;   % number of species being fit (in this case, only mRNA)
nTimes = sum(scRNAseq.fittingOptions.timesToFit);
% Minimum loss found by the ABC run, averaged over the fitted CDFs:
minLoss = minimumLoss;
avgLossPerCDF = minLoss / (nTimes * nSpecies * nConds);
fprintf('Average CDF L1 discrepancy per time/species: %.4f\n',avgLossPerCDF);
% Loss at the starting parameters (theta0) versus the ABC minimum:
L_min = minimumLoss;
L_init = scRNAseq.computeLossFunctionSSA(lossFunction, theta0, ...
    enforceIndependence);
fprintf('Initial loss: %.3f, Final (min) loss: %.3f\n', L_init, L_min);
fprintf('Relative improvement: %.1f%%\n', 100 * (L_init - L_min)/L_init);
%% Compare ABC posterior sample to MLE
% TODO: Overlay parameter values and compute predictive distributions.
%
% The MLE reference model (Model_TSC22D3) is not created by this script.
% Use it if it is already in the workspace; otherwise try to load it from
% 'seqModels/Model_TSC22D3.mat' (assumed to contain a variable named
% Model_TSC22D3 -- TODO confirm). If neither is available, skip the
% comparison with a warning instead of failing on an undefined variable.
if exist('Model_TSC22D3', 'var') ~= 1 && ...
        exist('seqModels/Model_TSC22D3.mat', 'file') == 2
    load('seqModels/Model_TSC22D3.mat', 'Model_TSC22D3');
end
if exist('Model_TSC22D3', 'var') == 1
    % Take the same parameters that were fitted (fitpars), so the vector
    % matches theta0/parGuess used for the other loss evaluations.
    theta_TSC22D3 = cell2mat(Model_TSC22D3.parameters(fitpars,2));
    L_MLE = scRNAseq.computeLossFunctionSSA(lossFunction,...
        theta_TSC22D3, enforceIndependence);
    L_min = minimumLoss;
    fprintf('MLE loss: %.3f, Final (min) ABC loss: %.3f\n', L_MLE, L_min);
    fprintf('Relative improvement: %.1f%%\n', 100 * (L_MLE - L_min)/L_MLE);
else
    warning('Model_TSC22D3 not available; skipping comparison to MLE.');
end