-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathfinal_classify.asv
More file actions
70 lines (55 loc) · 1.96 KB
/
final_classify.asv
File metadata and controls
70 lines (55 loc) · 1.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
function [OutClass] = final_classify(I,V,H,H_label,B_LLC,Coeff,database)
% FINAL_CLASSIFY  Classify one test input with a linear SVM trained on Coeff.
%
%   OutClass = final_classify(I,V,H,H_label,B_LLC,Coeff,database)
%
%   Pipeline (every stage is delegated to a helper defined elsewhere):
%     1. ExtractSIFT        - descriptors extracted from I
%     2. SPM_max_pooling    - pooled feature H_test using V
%     3. LLC_coding_appr    - coding of H_test against B_LLC
%     4. li2nsvm_multiclass_lbfgs / li2nsvm_multiclass_fwd
%                           - train on (Coeff, H_label), predict the test code
%
%   Inputs:
%     I        - test input handed to ExtractSIFT
%                (presumably an image or image path - confirm against ExtractSIFT)
%     V        - second argument of SPM_max_pooling
%                (presumably the visual codebook - TODO confirm)
%     H        - not used by this function; kept so existing callers still work
%     H_label  - training labels, one per column of Coeff
%     B_LLC    - basis passed (transposed) to LLC_coding_appr
%     Coeff    - training features, one sample per column
%     database - struct whose cell array field cname is indexed by the
%                predicted label to display the class name
%
%   Output:
%     OutClass - label(s) returned by li2nsvm_multiclass_fwd for the test code
%
%   Side effects: prints progress text and the predicted class name.

% --- Feature extraction parameters ---------------------------------------
gridSpacing    = 6;
patchSize      = 16;
maxImSize      = 300;
nrml_threshold = 1;

% --- Pooling / coding / SVM parameters -----------------------------------
pyramid  = [1 2 4];   % Levels in SPM
gamma    = 0.15;
lambda   = 0.1;       % regularisation weight handed to the SVM trainer
tradeoff = 0.55;
llcParam = 900;       % third argument of LLC_coding_appr
                      % NOTE(review): the original also set knn = 20 but never
                      % passed it anywhere; confirm 900 is the intended value.

% --- Build the test representation ---------------------------------------
[feaTest] = ExtractSIFT(I, gridSpacing, patchSize, maxImSize, nrml_threshold);
H_test = SPM_max_pooling(feaTest, V, pyramid, gamma);

% NOTE(review): neither output of the RPCA decomposition is consumed below.
% The call is kept because its side effects are not visible from this file;
% drop it if exact_alm_rpca is confirmed to be pure.
[N_test, L_test] = exact_alm_rpca(H_test, tradeoff); %#ok<ASGLU>

% LLC_coding_appr is called on transposed inputs; transpose the result back
% so the test code has the same orientation as the columns of Coeff.
Coeff_test = LLC_coding_appr(B_LLC', H_test', llcParam);
Coeff_test = Coeff_test';

%% Classification using linear SVM
fprintf('Classification...\n')
fprintf('Trial: ');

% Train on every available sample. The former per-class loop was dead code:
% it ran for a hard-coded 18 classes (erroring when H_label had fewer),
% and the random permutation it computed was never used.
TrainFeat  = Coeff;
TrainLabel = H_label;
TestFeat   = Coeff_test;

[w, b, class_name] = li2nsvm_multiclass_lbfgs(TrainFeat', TrainLabel', lambda);
[C, Y] = li2nsvm_multiclass_fwd(TestFeat', w, b, class_name); %#ok<ASGLU>

OutClass = C;

fprintf('\n')
% Brace indexing assumes C is a single label; a vector C would expand to
% several arguments and make disp fail.
disp(database.cname{C});
end