SHOGUN  6.1.3
CrossValidationStorage.cpp
Go to the documentation of this file.
1 /*
2 * BSD 3-Clause License
3 *
4 * Copyright (c) 2017, Shogun-Toolbox e.V. <shogun-team@shogun-toolbox.org>
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions are met:
9 *
10 * * Redistributions of source code must retain the above copyright notice, this
11 * list of conditions and the following disclaimer.
12 *
13 * * Redistributions in binary form must reproduce the above copyright notice,
14 * this list of conditions and the following disclaimer in the documentation
15 * and/or other materials provided with the distribution.
16 *
17 * * Neither the name of the copyright holder nor the names of its
18 * contributors may be used to endorse or promote products derived from
19 * this software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 *
32 * Written (W) 2017 Giovanni De Toni
33 *
34 */
35 
36 #include "CrossValidationStorage.h"
37 #include <shogun/labels/Labels.h>
38 #include <shogun/machine/Machine.h>
39 
40 using namespace shogun;
41 
43 {
46  m_trained_machine = NULL;
47  m_test_result = NULL;
48  m_test_true_result = NULL;
49 
50  SG_ADD(
51  &m_current_run_index, "m_current_run_index",
52  "The current run index of this fold", MS_AVAILABLE)
53  SG_ADD(
54  &m_current_fold_index, "m_current_fold_index", "The current fold index",
56  SG_ADD(
57  (CSGObject**)&m_trained_machine, "m_trained_machine",
58  "The machine trained by this fold", MS_AVAILABLE)
59  SG_ADD(
60  (CSGObject**)&m_test_result, "m_test_result",
61  "The test result of this fold", MS_AVAILABLE)
62  SG_ADD(
63  (CSGObject**)&m_test_true_result, "m_test_true_result",
64  "The true test result for this fold", MS_AVAILABLE)
65 }
66 
68 {
72 }
73 
75 {
76  m_current_run_index = run_index;
77 }
78 
80 {
81  m_current_fold_index = fold_index;
82 }
83 
85 {
86  m_train_indices = indices;
87 }
88 
90 {
91  m_test_indices = indices;
92 }
93 
95 {
96  SG_REF(machine)
98  m_trained_machine = machine;
99 }
100 
102 {
103  SG_REF(results)
105  m_test_result = results;
106 }
107 
109 {
110  SG_REF(results)
112  m_test_true_result = results;
113 }
114 
116 {
117 }
118 
120 {
121  m_evaluation_result = result;
122 }
123 
125 {
126  return m_current_run_index;
127 }
128 
130 {
131  return m_current_fold_index;
132 }
133 
135 {
136  return m_train_indices;
137 }
138 
140 {
141  return m_test_indices;
142 }
143 
145 {
146  return m_trained_machine;
147 }
148 
150 {
151  return m_test_result;
152 }
153 
155 {
156  return m_test_true_result;
157 }
158 
160 {
161  return m_evaluation_result;
162 }
163 
165 {
166  REQUIRE(
167  fold < get_num_folds(), "The fold number must be less than %i",
168  get_num_folds())
169 
170  CrossValidationFoldStorage* fld = m_folds_results[fold];
171  SG_REF(fld);
172  return fld;
173 }
174 
177 {
180  // m_train_indices.equals(rhs.m_train_indices) &&
181  // m_test_indices.equals(rhs.m_test_indices) &&
186 }
187 
191 {
192  m_num_runs = 0;
193  m_num_folds = 0;
194  m_expose_labels = NULL;
195 
196  SG_ADD(
197  &m_num_runs, "m_num_runs", "The total number of cross-validation runs",
198  MS_AVAILABLE)
199  SG_ADD(
200  &m_num_folds, "m_num_folds",
201  "The total number of cross-validation folds", MS_AVAILABLE)
202  SG_ADD(
203  (CSGObject**)&m_expose_labels, "m_expose_labels",
204  "The labels used for this cross-validation", MS_AVAILABLE)
205 }
206 
208 {
210  for (auto i : m_folds_results)
211  SG_UNREF(i)
212 }
213 
215 {
216  m_num_runs = num_runs;
217 }
218 
220 {
221  m_num_folds = num_folds;
222 }
223 
225 {
226  SG_REF(labels)
228  m_expose_labels = labels;
229 }
230 
232 {
233 }
234 
236 {
237  return m_num_runs;
238 }
239 
241 {
242  return m_num_folds;
243 }
244 
246 {
247  return m_expose_labels;
248 }
249 
252 {
253  SG_REF(result);
254  m_folds_results.push_back(result);
255 }
256 
258 {
259  auto member_vars = m_num_runs == rhs.m_num_runs &&
260  m_num_folds == rhs.m_num_folds &&
262 
263  if (!member_vars)
264  return member_vars;
265 
266  if (rhs.m_folds_results.size() != m_folds_results.size())
267  return false;
268  for (index_t i = 0; i < m_folds_results.size(); i++)
269  {
270  if (!(m_folds_results[i] == rhs.m_folds_results[i]))
271  return false;
272  }
273  return member_vars;
274 }
virtual void set_test_true_result(CLabels *results)
CrossValidationFoldStorage * get_fold(int fold) const
virtual void set_run_index(index_t run_index)
int32_t index_t
Definition: common.h:72
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
#define REQUIRE(x,...)
Definition: SGIO.h:181
virtual void set_test_result(CLabels *results)
#define SG_REF(x)
Definition: SGObject.h:52
A generic learning machine interface.
Definition: Machine.h:151
virtual void set_train_indices(SGVector< index_t > indices)
virtual void append_fold_result(CrossValidationFoldStorage *result)
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:124
double float64_t
Definition: common.h:60
virtual void set_fold_index(index_t fold_index)
virtual void set_test_indices(SGVector< index_t > indices)
virtual void set_trained_machine(CMachine *machine)
virtual void set_expose_labels(CLabels *labels)
virtual bool equals(CSGObject *other, float64_t accuracy=0.0, bool tolerant=false)
Definition: SGObject.cpp:656
virtual void set_evaluation_result(float64_t result)
bool operator==(const CrossValidationStorage &rhs) const
#define SG_UNREF(x)
Definition: SGObject.h:53
all classes and functions are contained in the shogun namespace
Definition: class_list.h:18
const SGVector< index_t > & get_test_indices() const
virtual void set_num_folds(index_t num_folds)
std::vector< CrossValidationFoldStorage * > m_folds_results
#define SG_ADD(...)
Definition: SGObject.h:93
virtual void set_num_runs(index_t num_runs)
const SGVector< index_t > & get_train_indices() const
bool operator==(const CrossValidationFoldStorage &rhs) const

SHOGUN Machine Learning Toolbox - Documentation