SHOGUN  6.1.3
CrossValidationStorage.h
Go to the documentation of this file.
1 /*
2 * BSD 3-Clause License
3 *
4 * Copyright (c) 2017, Shogun-Toolbox e.V. <shogun-team@shogun-toolbox.org>
5 * All rights reserved.
6 *
7 * Redistribution and use in source and binary forms, with or without
8 * modification, are permitted provided that the following conditions are met:
9 *
10 * * Redistributions of source code must retain the above copyright notice, this
11 * list of conditions and the following disclaimer.
12 *
13 * * Redistributions in binary form must reproduce the above copyright notice,
14 * this list of conditions and the following disclaimer in the documentation
15 * and/or other materials provided with the distribution.
16 *
17 * * Neither the name of the copyright holder nor the names of its
18 * contributors may be used to endorse or promote products derived from
19 * this software without specific prior written permission.
20 *
21 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 *
32 * Written (W) 2017 Giovanni De Toni
33 *
34 */
35 
36 #ifndef SHOGUN_CROSSVALIDATIONSTORAGE_H
37 #define SHOGUN_CROSSVALIDATIONSTORAGE_H
38 
39 #include <shogun/base/SGObject.h>
40 #include <shogun/lib/SGVector.h>
41 #include <vector>
42 
43 namespace shogun
44 {
45 
46  class CMachine;
47  class CLabels;
48  class CEvaluation;
49 
54  {
55  public:
58 
63  virtual void set_run_index(index_t run_index);
64 
69  virtual void set_fold_index(index_t fold_index);
70 
75  virtual void set_train_indices(SGVector<index_t> indices);
76 
81  virtual void set_test_indices(SGVector<index_t> indices);
82 
87  virtual void set_trained_machine(CMachine* machine);
88 
93  virtual void set_test_result(CLabels* results);
94 
99  virtual void set_test_true_result(CLabels* results);
100 
103  virtual void post_update_results();
104 
109  virtual void set_evaluation_result(float64_t result);
110 
116 
122 
127  const SGVector<index_t>& get_train_indices() const;
128 
133  const SGVector<index_t>& get_test_indices() const;
134 
139  CMachine* get_trained_machine() const;
140 
145  CLabels* get_test_result() const;
146 
151  CLabels* get_test_true_result() const;
152 
158 
164  bool operator==(const CrossValidationFoldStorage& rhs) const;
165 
/**
 * Reports the name of this class as a C string.
 *
 * The function is virtual, so a derived class may supply its own name.
 *
 * @return pointer to the string literal "CrossValidationFoldStorage"
 */
170  virtual const char* get_name() const
171  {
172  return "CrossValidationFoldStorage";
173  };
174 
175  protected:
178 
181 
184 
187 
190 
193 
196 
199  };
200 
205  {
206  public:
209 
211  virtual ~CrossValidationStorage();
212 
/**
 * Reports the name of this class as a C string.
 *
 * The function is virtual, so a derived class may supply its own name.
 *
 * @return pointer to the string literal "CrossValidationStorage"
 */
217  virtual const char* get_name() const
218  {
219  return "CrossValidationStorage";
220  };
221 
225  virtual void set_num_runs(index_t num_runs);
226 
230  virtual void set_num_folds(index_t num_folds);
231 
235  virtual void set_expose_labels(CLabels* labels);
236 
238  virtual void post_init();
239 
244  virtual void append_fold_result(CrossValidationFoldStorage* result);
245 
250  index_t get_num_runs() const;
251 
256  index_t get_num_folds() const;
257 
262  CLabels* get_expose_labels() const;
263 
269  CrossValidationFoldStorage* get_fold(int fold) const;
270 
276  bool operator==(const CrossValidationStorage& rhs) const;
277 
278  protected:
281 
284 
287 
289  std::vector<CrossValidationFoldStorage*> m_folds_results;
290  };
291 }
292 
293 #endif // SHOGUN_CROSSVALIDATIONSTORAGE_H
virtual void set_test_true_result(CLabels *results)
virtual void set_run_index(index_t run_index)
int32_t index_t
Definition: common.h:72
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual const char * get_name() const
virtual void set_test_result(CLabels *results)
A generic learning machine interface.
Definition: Machine.h:151
virtual void set_train_indices(SGVector< index_t > indices)
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:124
double float64_t
Definition: common.h:60
virtual void set_fold_index(index_t fold_index)
virtual void set_test_indices(SGVector< index_t > indices)
virtual void set_trained_machine(CMachine *machine)
virtual void set_evaluation_result(float64_t result)
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
const SGVector< index_t > & get_test_indices() const
virtual const char * get_name() const
std::vector< CrossValidationFoldStorage * > m_folds_results
const SGVector< index_t > & get_train_indices() const
bool operator==(const CrossValidationFoldStorage &rhs) const

SHOGUN Machine Learning Toolbox - Documentation