SHOGUN  6.1.3
Machine.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Written (W) 2011-2012 Heiko Strathmann
9  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
10  */
11 
12 #ifndef _MACHINE_H__
13 #define _MACHINE_H__
14 
15 #include <shogun/lib/config.h>
16 
17 #include <shogun/lib/common.h>
18 #include <shogun/base/SGObject.h>
25 
26 #include <condition_variable>
27 #include <mutex>
28 
29 namespace shogun
30 {
31 
32 class CFeatures;
33 class CLabels;
34 
37 {
38  CT_NONE = 0,
39  CT_LIGHT = 10,
41  CT_LIBSVM = 20,
44  CT_MPD = 50,
45  CT_GPBT = 60,
49  CT_LDA = 100,
50  CT_LPM = 110,
51  CT_LPBOOST = 120,
52  CT_KNN = 130,
53  CT_SVMLIN=140,
55  CT_GNPPSVM = 160,
56  CT_GMNPSVM = 170,
57  CT_SVMPERF = 200,
58  CT_LIBSVR = 210,
59  CT_SVRLIGHT = 220,
60  CT_LIBLINEAR = 230,
61  CT_KMEANS = 240,
63  CT_SVMOCAS = 260,
64  CT_WDSVMOCAS = 270,
65  CT_SVMSGD = 280,
71  CT_DASVM = 340,
72  CT_LARANK = 350,
76  CT_SGDQN = 390,
80  CT_QDA = 430,
81  CT_NEWTONSVM = 440,
83  CT_LARS = 460,
89  CT_CCSOSVM = 520,
94  CT_BAGGING = 570,
95  CT_FWSOSVM = 580,
96  CT_BCFWSOSVM = 590,
98 };
99 
102 {
110 };
111 
114 {
121 };
122 
123 #define MACHINE_PROBLEM_TYPE(PT) \
124  \
127  virtual EProblemType get_machine_problem_type() const { return PT; }
128 
/** Cooperative cancel/pause checkpoint, meant to be placed inside a loop body.
 *
 * If cancel_computation() returns true, the rest of the current loop
 * iteration is skipped with `continue`; otherwise pause_computation() is
 * called. The expansion is a single if/else statement, so the macro behaves
 * correctly even as the body of an unbraced `if`/`for`/`while` (a
 * two-statement expansion would run pause_computation() unconditionally
 * there). It cannot be wrapped in do { ... } while (0) because `continue`
 * must apply to the caller's enclosing loop, not to the wrapper.
 */
#define COMPUTATION_CONTROLLERS \
	if (cancel_computation()) \
		continue; \
	else \
		pause_computation();
133 
151 class CMachine : public CSGObject
152 {
153  public:
155  CMachine();
156 
158  virtual ~CMachine();
159 
169  virtual bool train(CFeatures* data=NULL);
170 
177  virtual CLabels* apply(CFeatures* data=NULL);
178 
180  virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
182  virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
184  virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
186  virtual CStructuredLabels* apply_structured(CFeatures* data=NULL);
188  virtual CLatentLabels* apply_latent(CFeatures* data=NULL);
189 
194  virtual void set_labels(CLabels* lab);
195 
200  virtual CLabels* get_labels();
201 
207 
213 
219 
224  void set_solver_type(ESolverType st);
225 
231 
237  virtual void set_store_model_features(bool store_model);
238 
239 #ifndef SWIG // SWIG should skip this part
240 
248  virtual bool train_locked(SGVector<index_t> indices)
249  {
250  SG_ERROR("train_locked(SGVector<index_t>) is not yet implemented "
251  "for %s\n", get_name());
252  return false;
253  }
254 #endif // SWIG // SWIG should skip this part
255 
257  virtual float64_t apply_one(int32_t i)
258  {
260  return 0.0;
261  }
262 
263 #ifndef SWIG // SWIG should skip this part
264 
269  virtual CLabels* apply_locked(SGVector<index_t> indices);
270 
273  SGVector<index_t> indices);
276  SGVector<index_t> indices);
279  SGVector<index_t> indices);
282  SGVector<index_t> indices);
285  SGVector<index_t> indices);
286 #endif // SWIG // SWIG should skip this part
287 
296  virtual void data_lock(CLabels* labs, CFeatures* features);
297 
299  virtual void post_lock(CLabels* labs, CFeatures* features) { };
300 
302  virtual void data_unlock();
303 
305  virtual bool supports_locking() const { return false; }
306 
308  bool is_data_locked() const { return m_data_locked; }
309 
312  {
314  return PT_BINARY;
315  }
316 
317 #ifndef SWIG
318 
320  {
321  return m_cancel_computation.load();
322  }
323 #endif
324 
325 #ifndef SWIG
326 
328  {
329  if (m_pause_computation_flag.load())
330  {
331  std::unique_lock<std::mutex> lck(m_mutex);
332  while (m_pause_computation_flag.load())
333  m_pause_computation.wait(lck);
334  }
335  }
336 #endif
337 
338 #ifndef SWIG
339 
341  {
342  std::unique_lock<std::mutex> lck(m_mutex);
343  m_pause_computation_flag = false;
344  m_pause_computation.notify_all();
345  }
346 #endif
347 
348  virtual const char* get_name() const { return "Machine"; }
349 
350  protected:
361  virtual bool train_machine(CFeatures* data=NULL)
362  {
363  SG_ERROR("train_machine is not yet implemented for %s!\n",
364  get_name());
365  return false;
366  }
367 
378  virtual void store_model_features()
379  {
380  SG_ERROR("Model storage and therefore unlocked Cross-Validation and"
381  " Model-Selection is not supported for %s. Locked may"
382  " work though.\n", get_name());
383  }
384 
391  virtual bool is_label_valid(CLabels *lab) const
392  {
393  return true;
394  }
395 
397  virtual bool train_require_labels() const { return true; }
398 
400  rxcpp::subscription connect_to_signal_handler();
401 
404  {
405  m_cancel_computation = false;
406  m_pause_computation_flag = false;
407  }
408 
411  virtual void on_next()
412  {
413  m_cancel_computation.store(true);
414  }
415 
418  virtual void on_pause()
419  {
420  m_pause_computation_flag.store(true);
421  /* Here there should be the actual code*/
423  }
424 
427  virtual void on_complete()
428  {
429  }
430 
431  protected:
434 
437 
440 
443 
446 
448  std::atomic<bool> m_cancel_computation;
449 
451  std::atomic<bool> m_pause_computation_flag;
452 
454  std::condition_variable m_pause_computation;
455 
457  std::mutex m_mutex;
458 };
459 }
460 #endif // _MACHINE_H__
virtual float64_t apply_one(int32_t i)
Definition: Machine.h:257
std::atomic< bool > m_pause_computation_flag
Definition: Machine.h:451
EMachineType
Definition: Machine.h:36
void set_max_train_time(float64_t t)
Definition: Machine.cpp:89
Base class of the labels used in Structured Output (SO) problems.
Real Labels are real-valued labels.
virtual CLabels * apply_locked(SGVector< index_t > indices)
Definition: Machine.cpp:194
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
ESolverType
Definition: Machine.h:101
float64_t m_max_train_time
Definition: Machine.h:433
CLabels * m_labels
Definition: Machine.h:436
void reset_computation_variables()
Definition: Machine.h:403
#define SG_ERROR(...)
Definition: SGIO.h:128
#define SG_NOTIMPLEMENTED
Definition: SGIO.h:138
ESolverType m_solver_type
Definition: Machine.h:439
bool m_data_locked
Definition: Machine.h:445
virtual CStructuredLabels * apply_locked_structured(SGVector< index_t > indices)
Definition: Machine.cpp:266
virtual bool train_machine(CFeatures *data=NULL)
Definition: Machine.h:361
virtual void on_complete()
Definition: Machine.h:427
bool m_store_model_features
Definition: Machine.h:442
std::atomic< bool > m_cancel_computation
Definition: Machine.h:448
virtual const char * get_name() const
Definition: Machine.h:348
virtual bool train_locked(SGVector< index_t > indices)
Definition: Machine.h:248
A generic learning machine interface.
Definition: Machine.h:151
#define SG_FORCED_INLINE
Definition: common.h:91
SG_FORCED_INLINE void resume_computation()
Definition: Machine.h:340
std::condition_variable m_pause_computation
Definition: Machine.h:454
Multiclass Labels for multi-class classification.
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
Definition: Machine.cpp:215
virtual void on_next()
Definition: Machine.h:411
virtual void set_store_model_features(bool store_model)
Definition: Machine.cpp:114
EProblemType
Definition: Machine.h:113
virtual ~CMachine()
Definition: Machine.cpp:38
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:124
rxcpp::subscription connect_to_signal_handler()
Definition: Machine.cpp:280
double float64_t
Definition: common.h:60
virtual CRegressionLabels * apply_regression(CFeatures *data=NULL)
Definition: Machine.cpp:221
virtual void data_unlock()
Definition: Machine.cpp:150
virtual void data_lock(CLabels *labs, CFeatures *features)
Definition: Machine.cpp:119
virtual CLabels * get_labels()
Definition: Machine.cpp:83
float64_t get_max_train_time()
Definition: Machine.cpp:94
ESolverType get_solver_type()
Definition: Machine.cpp:109
virtual CLatentLabels * apply_latent(CFeatures *data=NULL)
Definition: Machine.cpp:239
virtual EMachineType get_classifier_type()
Definition: Machine.cpp:99
virtual EProblemType get_machine_problem_type() const
Definition: Machine.h:311
virtual CRegressionLabels * apply_locked_regression(SGVector< index_t > indices)
Definition: Machine.cpp:252
virtual void store_model_features()
Definition: Machine.h:378
virtual bool supports_locking() const
Definition: Machine.h:305
virtual CMulticlassLabels * apply_locked_multiclass(SGVector< index_t > indices)
Definition: Machine.cpp:259
SG_FORCED_INLINE bool cancel_computation() const
Definition: Machine.h:319
virtual CStructuredLabels * apply_structured(CFeatures *data=NULL)
Definition: Machine.cpp:233
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
virtual void post_lock(CLabels *labs, CFeatures *features)
Definition: Machine.h:299
virtual bool is_label_valid(CLabels *lab) const
Definition: Machine.h:391
The class Features is the base class of all feature objects.
Definition: Features.h:69
virtual CBinaryLabels * apply_locked_binary(SGVector< index_t > indices)
Definition: Machine.cpp:245
virtual bool train(CFeatures *data=NULL)
Definition: Machine.cpp:43
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
std::mutex m_mutex
Definition: Machine.h:457
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
Definition: Machine.cpp:227
virtual bool train_require_labels() const
Definition: Machine.h:397
virtual CLatentLabels * apply_locked_latent(SGVector< index_t > indices)
Definition: Machine.cpp:273
virtual void set_labels(CLabels *lab)
Definition: Machine.cpp:72
SG_FORCED_INLINE void pause_computation()
Definition: Machine.h:327
Abstract class for latent labels. As latent labels always depend on the given application, this class only defines the API that the user has to implement for latent labels.
Definition: LatentLabels.h:26
bool is_data_locked() const
Definition: Machine.h:308
void set_solver_type(ESolverType st)
Definition: Machine.cpp:104
virtual void on_pause()
Definition: Machine.h:418
virtual CLabels * apply(CFeatures *data=NULL)
Definition: Machine.cpp:159

SHOGUN Machine Learning Toolbox - Documentation