From 3030e862eb32d5e2aefbbb626b14cd91d75f36a4 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Sat, 7 Jul 2018 21:27:00 +0800 Subject: [PATCH] Register parameters of SVM --- shogun/classifier/svm/SVMLin.cpp | 14 ++++++++++++++ shogun/classifier/svm/SVMLin.h | 3 +++ shogun/classifier/svm/SVMSGD.cpp | 28 +++++++++++++++------------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/shogun/classifier/svm/SVMLin.cpp b/shogun/classifier/svm/SVMLin.cpp index 388cf2f..10192d4 100644 --- a/shogun/classifier/svm/SVMLin.cpp +++ b/shogun/classifier/svm/SVMLin.cpp @@ -22,6 +22,7 @@ using namespace shogun; CSVMLin::CSVMLin() : CLinearMachine(), C1(1), C2(1), epsilon(1e-5), use_bias(true) { + init(); } CSVMLin::CSVMLin( @@ -30,6 +31,8 @@ CSVMLin::CSVMLin( { set_features(traindat); set_labels(trainlab); + + init(); } @@ -37,6 +40,17 @@ CSVMLin::~CSVMLin() { } +void CSVMLin::init() +{ + SG_ADD( + &use_bias, "use_bias", "Indicates if bias is used.", MS_NOT_AVAILABLE); + SG_ADD( + &C1, "C1", "C constant for negatively labeled examples.", MS_AVAILABLE); + SG_ADD( + &C2, "C2", "C constant for positively labeled examples.", MS_AVAILABLE); + SG_ADD(&epsilon, "epsilon", "Convergence precision.", MS_NOT_AVAILABLE); +} + bool CSVMLin::train_machine(CFeatures* data) { ASSERT(m_labels) diff --git a/shogun/classifier/svm/SVMLin.h b/shogun/classifier/svm/SVMLin.h index 950305f..0bc063c 100644 --- a/shogun/classifier/svm/SVMLin.h +++ b/shogun/classifier/svm/SVMLin.h @@ -104,6 +104,9 @@ class CSVMLin : public CLinearMachine */ virtual bool train_machine(CFeatures* data=NULL); + /** set up parameters */ + void init(); + protected: /** C1 */ float64_t C1; diff --git a/shogun/classifier/svm/SVMSGD.cpp b/shogun/classifier/svm/SVMSGD.cpp index bad90f8..885d58b 100644 --- a/shogun/classifier/svm/SVMSGD.cpp +++ b/shogun/classifier/svm/SVMSGD.cpp @@ -71,8 +71,7 @@ void CSVMSGD::set_loss_function(CLossFunction* loss_func) bool CSVMSGD::train_machine(CFeatures* data) { // allocate memory for w and initialize everyting w and bias with 0 - ASSERT(m_labels) - ASSERT(m_labels->get_label_type() == LT_BINARY) + auto labels = binary_labels(m_labels); if (data) { @@ -83,7 +82,7 @@ bool CSVMSGD::train_machine(CFeatures* data) ASSERT(features) - int32_t num_train_labels=m_labels->get_num_labels(); + int32_t num_train_labels = labels->get_num_labels(); int32_t num_vec=features->get_num_vectors(); ASSERT(num_vec==num_train_labels) @@ -122,7 +121,7 @@ bool CSVMSGD::train_machine(CFeatures* data) for (int32_t i=0; iget_label(i); + float64_t y = labels->get_label(i); float64_t z = y * (features->dense_dot(i, w.vector, w.vlen) + bias); if (z < 1 || is_log_loss) @@ -214,13 +213,16 @@ void CSVMSGD::init() loss=new CHingeLoss(); SG_REF(loss); - m_parameters->add(&C1, "C1", "Cost constant 1."); - m_parameters->add(&C2, "C2", "Cost constant 2."); - m_parameters->add(&wscale, "wscale", "W scale"); - m_parameters->add(&bscale, "bscale", "b scale"); - m_parameters->add(&epochs, "epochs", "epochs"); - m_parameters->add(&skip, "skip", "skip"); - m_parameters->add(&count, "count", "count"); - m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used."); - m_parameters->add(&use_regularized_bias, "use_regularized_bias", "Indicates if bias is regularized."); + SG_ADD(&C1, "C1", "Cost constant 1.", MS_AVAILABLE); + SG_ADD(&C2, "C2", "Cost constant 2.", MS_AVAILABLE); + SG_ADD(&wscale, "wscale", "W scale", MS_NOT_AVAILABLE); + SG_ADD(&bscale, "bscale", "b scale", MS_NOT_AVAILABLE); + SG_ADD(&epochs, "epochs", "epochs", MS_NOT_AVAILABLE); + SG_ADD(&skip, "skip", "skip", MS_NOT_AVAILABLE); + SG_ADD(&count, "count", "count", MS_NOT_AVAILABLE); + SG_ADD( + &use_bias, "use_bias", "Indicates if bias is used.", MS_NOT_AVAILABLE); + SG_ADD( + &use_regularized_bias, "use_regularized_bias", + "Indicates if bias is regularized.", MS_NOT_AVAILABLE); }