-
Notifications
You must be signed in to change notification settings - Fork 2
/
DeepInsight3D.asv
53 lines (47 loc) · 1.78 KB
/
DeepInsight3D.asv
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
function [AUC,C,Accuracy,ValErr] = DeepInsight3D(DSETnum,Parm)
%DEEPINSIGHT3D Prepare the data, train the CNN and log the results.
%   [AUC,C,Accuracy,ValErr] = DeepInsight3D(DSETnum,Parm)
%
%   Inputs
%     DSETnum - number of the dataset; e.g. dataset1.mat, dataset2.mat, ...
%               (not referenced inside this function; kept so existing
%               callers keep working)
%     Parm    - parameter container. Fields read here: FileRun, SnowFall,
%               Method, Dist (optional), UsePrevModel, NetName,
%               ObjFcnMeasure and Stage. It is also passed unchanged to
%               func_Prepare_Data and func_TrainModel.
%
%   Outputs (Accuracy, ValErr and AUC are stored at index Parm.Stage)
%     AUC      - AUC as returned by func_TrainModel; intended for 2-class
%                problems and only written to the log when C has two rows
%     C        - confusion matrix of the test set
%     Accuracy - test accuracy
%     ValErr   - validation error of the validation set
%
%   Side effects: closes all figures, appends a run summary to
%   'DeepInsight3D_Results.txt' in the current folder and prints progress
%   to the command window.
%
% contact: alok.fj@gmail.com

close all;

% Open the results log in append mode and fail early with a clear message
% if it cannot be opened (fopen returns -1 instead of raising an error).
[fid2,openMsg] = fopen('DeepInsight3D_Results.txt','a+');
if fid2 == -1
    error('DeepInsight3D:cannotOpenLog', ...
        'Cannot open DeepInsight3D_Results.txt: %s',openMsg);
end
% Close the log when this function exits, including exits caused by an
% error in data preparation or training.
closeLog = onCleanup(@() fclose(fid2)); %#ok<NASGU>

% Run header
fprintf(fid2,'\n');
fprintf(fid2,'%s',Parm.FileRun);
fprintf(fid2,'\n');
fprintf(fid2,'SnowFall: %d\n',Parm.SnowFall);
fprintf(fid2,'Method: %s\n',Parm.Method);
if any(strcmp('Dist',fieldnames(Parm)))
    fprintf(fid2,'Distance: %s\n',Parm.Dist);
else
    fprintf(fid2,'Distance is not applicable or Deafult\n');
end
fprintf(fid2,'Use Previous Model: %s\n',Parm.UsePrevModel);

% Convert tabular data to image. None of the returned values are used
% here; five outputs are still requested so the callee sees the same
% nargout as before.
[~,~,~,~,~] = func_Prepare_Data(Parm);

% Run CNN net
disp('Training model begins: Net1');
[Accuracy(Parm.Stage),ValErr(Parm.Stage),Momentum(Parm.Stage),L2Reg(Parm.Stage),...
    InitLR(Parm.Stage),AUC(Parm.Stage),C,~] = func_TrainModel(Parm);

fprintf(fid2,'Net: %s\n',Parm.NetName);
fprintf(fid2,'ObjFcnMeasure: %s\n',Parm.ObjFcnMeasure);
fprintf('Stage: %d; Test Accuracy: %6.4f; ValErr: %4.4f; \n',Parm.Stage,Accuracy(Parm.Stage),ValErr(Parm.Stage));
fprintf('Momentum: %g; L2Regularization: %g; InitLearnRate: %g\n',Momentum(Parm.Stage),L2Reg(Parm.Stage),InitLR(Parm.Stage));

% AUC is only logged for a two-row (2-class) confusion matrix.
if size(C,1)==2
    fprintf(fid2,'AUC: %6.4f; \n',AUC(Parm.Stage));
end

fprintf('Confusion Matrix\n');
fprintf(fid2,'\nConfusionMatrix\n');
% Write one confusion-matrix row per line. Iterate over the number of ROWS
% (size(C,1)) because the loop variable indexes rows.
for nC=1:size(C,1)
    fprintf(fid2,'%d\t',C(nC,:));
    fprintf(fid2,'\n');
end

disp('Training model ends');
fprintf('\n');

% NOTE(review): Parm is not an output of this function, so if Parm is an
% ordinary (value-type) struct this increment is not visible to the
% caller. Kept as in the original; confirm whether the caller is expected
% to advance the stage itself.
Parm.Stage=Parm.Stage+1; %#ok<NASGU>
end