SHOGUN  6.1.3
CGMShiftedFamilySolver.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Soumyajit De
8  */
9
10 #include <shogun/lib/common.h>
11
12
13 #include <shogun/lib/SGVector.h>
14 #include <shogun/lib/Time.h>
17
20
21 using namespace Eigen;
22
23 namespace shogun
24 {
25
// Default constructor; its body is empty.
// NOTE(review): listing line 27 is absent from this extract, so any
// base-class/member initializer list that belongs here cannot be seen.
26 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver()
28 {
29 }
30
// NOTE(review): the signature of this empty body (listing lines 31-32) was
// lost in extraction -- presumably a second constructor; confirm against the
// upstream file.
33 {
34 }
35
// NOTE(review): the signature of this empty body (listing line 36) was lost
// in extraction -- presumably the destructor; confirm against the upstream
// file.
37 {
38 }
39
// Solves the un-shifted system by delegating to solve_shifted_weighted()
// with a single shift of 0.0 and a single weight of 1.0, then returning the
// real part of the complex result via get_real().
//
// NOTE(review): the function header (listing lines 40-41) is missing from
// this extract. The reference list at the end of the listing declares it as
//   virtual SGVector<float64_t> solve(CLinearOperator<float64_t>* A,
//                                     SGVector<float64_t> b)
// which matches the use of A and b below -- confirm against the upstream file.
42 {
 // one shift (zero) so that the shifted solver solves the plain system
43  SGVector<complex128_t> shifts(1);
44  shifts[0]=0.0;
 // one weight (one) so that the weighted sum is just that single solution
45  SGVector<complex128_t> weights(1);
46  weights[0]=1.0;
47
 // the shifted solver returns a complex vector; only its real part is returned
48  return solve_shifted_weighted(A, b, shifts, weights).get_real();
49 }
50
// Runs a CG-M style iteration on the operator A with right-hand side b for
// every shift at once, keeping one solution column and one direction column
// per shift, and returns the weighted sum of the per-shift solutions:
//   result = sum_i weights[i] * x_sh.col(i)
//
// The checks at the top require A to be non-NULL, A->get_dimension() to
// equal b.vlen, shifts and weights to be initialized, and shifts.vlen to
// equal weights.vlen.
//
// A warning is logged if the iterator reports that the iteration did not
// succeed; the (possibly unconverged) result is still returned.
//
// NOTE(review): the function header (listing lines 51-53) is missing from
// this extract. The reference list at the end of the listing declares it as
//   virtual SGVector<complex128_t> solve_shifted_weighted(
//       CLinearOperator<float64_t>* A, SGVector<float64_t> b,
//       SGVector<complex128_t> shifts, SGVector<complex128_t> weights)
// which matches the names used below -- confirm against the upstream file.
54 {
55  SG_DEBUG("Entering\n");
56
57  // sanity check
58  REQUIRE(A, "Operator is NULL!\n");
59  REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! [%d vs %d]\n",
60  A->get_dimension(), b.vlen);
61  REQUIRE(shifts.vector,"Shifts are not initialized!\n");
62  REQUIRE(weights.vector,"Weights are not initialized!\n");
63  REQUIRE(shifts.vlen==weights.vlen, "Number of shifts and number of "
64  "weights are not equal! [%d vs %d]\n", shifts.vlen, weights.vlen);
65
66  // the solution matrix, one column per shift, initial guess 0 for all
67  MatrixXcd x_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
 // the shifted search directions, one column per shift (overwritten below)
68  MatrixXcd p_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
69
70  // non-shifted direction
 // NOTE(review): the declaration of p_ (listing line 71) is missing from this
 // extract; p_ is mapped as a VectorXd below and passed to A->apply(), so it
 // is presumably an SGVector<float64_t> of length b.vlen -- confirm upstream.
72
73  // the rest of the part hinges on eigen3 for computing norms
 // b_map and p are Eigen views over the buffers of b and p_ (no copy), so
 // writing to p also updates p_, which is what A->apply(p_) consumes
74  Map<VectorXd> b_map(b.vector, b.vlen);
75  Map<VectorXd> p(p_.vector, p_.vlen);
76
77  // residual r_i=b-Ax_i, here x_0=[0], so r_0=b
78  VectorXd r=b_map;
79
80  // initial direction is same as residual
81  p=r;
 // every shifted direction also starts as the (complex-cast) residual
82  p_sh=r.replicate(1, shifts.vlen).cast<complex128_t>();
83
84  // non shifted initializers
85  float64_t r_norm2=r.dot(r);
86  float64_t beta_old=1.0;
87  float64_t alpha=1.0;
88
89  // shifted quantities
90  SGVector<complex128_t> alpha_sh(shifts.vlen);
91  SGVector<complex128_t> beta_sh(shifts.vlen);
92  SGVector<complex128_t> zeta_sh_old(shifts.vlen);
93  SGVector<complex128_t> zeta_sh_cur(shifts.vlen);
94  SGVector<complex128_t> zeta_sh_new(shifts.vlen);
95
96  // shifted initializers
97  zeta_sh_old.set_const(1.0);
98  zeta_sh_cur.set_const(1.0);
99
100  // the iterator for this iterative solver
 // NOTE(review): the declaration of the iterator `it` (listing lines 101-102)
 // is missing from this extract; its type and constructor arguments (limits,
 // tolerances) cannot be seen here -- confirm upstream.
103
104  // start the timer
105  CTime time;
106  time.start();
107
108  // set the residuals to zero
109  if (m_store_residuals)
110  m_residuals.set_const(0.0);
111
112  // CG iteration begins
113  for (it.begin(r); !it.end(r); ++it)
114  {
115
116  SG_DEBUG("CG iteration %d, residual norm %f\n",
117  it.get_iter_info().iteration_count,
118  it.get_iter_info().residual_norm);
119
 // record this iteration's residual norm, indexed by iteration count
 // NOTE(review): assumes m_residuals is sized to hold every iteration the
 // iterator allows; its allocation is not visible here -- confirm
120  if (m_store_residuals)
121  {
122  m_residuals[it.get_iter_info().iteration_count]
123  =it.get_iter_info().residual_norm;
124  }
125
126  // apply linear operator to the direction vector
127  SGVector<float64_t> Ap_=A->apply(p_);
128  Map<VectorXd> Ap(Ap_.vector, Ap_.vlen);
129
130  // compute p^{T}Ap, if zero, failure
 // (exact comparison: only an exactly-zero value, which would make the
 // division below ill-defined, ends the loop here)
131  float64_t p_dot_Ap=p.dot(Ap);
132  if (p_dot_Ap==0.0)
133  break;
134
135  // compute the beta parameter of CG_M
 // note the sign: beta is negative of the usual CG step length, which is
 // why the updates below use x -= beta_sh*p and r += beta*Ap
136  float64_t beta=-r_norm2/p_dot_Ap;
137
138  // compute the zeta-shifted parameter of CG_M
139  compute_zeta_sh_new(zeta_sh_old, zeta_sh_cur, shifts, beta_old, beta,
140  alpha, zeta_sh_new);
141
142  // compute beta-shifted parameter of CG_M
143  compute_beta_sh(zeta_sh_new, zeta_sh_cur, beta, beta_sh);
144
145  // update the solution vector and residual
146  for (index_t i=0; i<shifts.vlen; ++i)
147  x_sh.col(i)-=beta_sh[i]*p_sh.col(i);
148
149  // r_{i}=r_{i-1}+\beta_{i}Ap
150  r+=beta*Ap;
151
152  // compute new ||r||_{2}, if zero, converged
 // (exact comparison again; this also guards the division by r_norm2 on the
 // next pass through the loop)
153  float64_t r_norm2_i=r.dot(r);
154  if (r_norm2_i==0.0)
155  break;
156
157  // compute the alpha parameter of CG_M
158  alpha=r_norm2_i/r_norm2;
159
160  // update ||r||_{2}
161  r_norm2=r_norm2_i;
162
163  // update direction
164  p=r+alpha*p;
165
 // compute alpha-shifted parameter of CG_M from the new/current zeta values
166  compute_alpha_sh(zeta_sh_new, zeta_sh_cur, beta_sh, beta, alpha, alpha_sh);
167
 // update each shifted direction: p_sh_i = alpha_sh_i*p_sh_i + zeta_sh_new_i*r
168  for (index_t i=0; i<shifts.vlen; ++i)
169  {
170  p_sh.col(i)*=alpha_sh[i];
171  p_sh.col(i)+=zeta_sh_new[i]*r;
172  }
173
174  // update parameters
 // rotate the zeta history (old <- cur <- new) and remember beta
175  for (index_t i=0; i<shifts.vlen; ++i)
176  {
177  zeta_sh_old[i]=zeta_sh_cur[i];
178  zeta_sh_cur[i]=zeta_sh_new[i];
179  }
180  beta_old=beta;
181  }
182
 // time elapsed since time.start() above, used only for the log message
183  float64_t elapsed=time.cur_time_diff();
184
185  if (!it.succeeded(r))
186  SG_WARNING("Did not converge!\n");
187
188  SG_INFO("Iteration took %d times, residual norm=%.20lf, time elapsed=%f\n",
189  it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed);
190
191  // compute the final result vector multiplied by weights
 // x is an Eigen view over result's buffer, so accumulating into x fills
 // the SGVector that is returned
192  SGVector<complex128_t> result(b.vlen);
193  result.set_const(0.0);
194  Map<VectorXcd> x(result.vector, result.vlen);
195
196  for (index_t i=0; i<x_sh.cols(); ++i)
197  x+=x_sh.col(i)*weights[i];
198
199  SG_DEBUG("Leaving\n");
200  return result;
201 }
202
203 }
Class Time that implements a stopwatch based on either cpu time or wall clock time.
Definition: Time.h:42
#define SG_INFO(...)
Definition: SGIO.h:117
std::complex< float64_t > complex128_t
Definition: common.h:77
void compute_zeta_sh_new(const SGVector< complex128_t > &zeta_sh_old, const SGVector< complex128_t > &zeta_sh_cur, const SGVector< complex128_t > &shifts, const float64_t &beta_old, const float64_t &beta_cur, const float64_t &alpha, SGVector< complex128_t > &zeta_sh_new)
void begin(const VectorXt &residual)
int32_t index_t
Definition: common.h:72
Definition: SGMatrix.h:25
#define REQUIRE(x,...)
Definition: SGIO.h:181
const bool end(const VectorXt &residual)
void compute_beta_sh(const SGVector< complex128_t > &zeta_sh_new, const SGVector< complex128_t > &zeta_sh_cur, const float64_t &beta_cur, SGVector< complex128_t > &beta_sh)
float64_t cur_time_diff(bool verbose=false)
Definition: Time.cpp:68
template class that is used as an iterator for an iterative linear solver. In the iteration of solvin...
double float64_t
Definition: common.h:60
float64_t start(bool verbose=false)
Definition: Time.cpp:59
virtual SGVector< T > apply(SGVector< T > b) const =0
void set_const(T const_elem)
Definition: SGVector.cpp:199
const index_t get_dimension() const
#define SG_DEBUG(...)
Definition: SGIO.h:106
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
SGVector< float64_t > get_real()
Definition: SGVector.cpp:940
const bool succeeded(const VectorXt &residual)
virtual SGVector< float64_t > solve(CLinearOperator< float64_t > *A, SGVector< float64_t > b)
#define SG_WARNING(...)
Definition: SGIO.h:127
virtual SGVector< complex128_t > solve_shifted_weighted(CLinearOperator< float64_t > *A, SGVector< float64_t > b, SGVector< complex128_t > shifts, SGVector< complex128_t > weights)
void compute_alpha_sh(const SGVector< complex128_t > &zeta_sh_cur, const SGVector< complex128_t > &zeta_sh_old, const SGVector< complex128_t > &beta_sh_old, const float64_t &beta_old, const float64_t &alpha, SGVector< complex128_t > &alpha_sh)
index_t vlen
Definition: SGVector.h:571
abstract template base for CG based solvers to the solution of shifted linear systems of the form fo...

SHOGUN Machine Learning Toolbox - Documentation