public class ConjugateGradientSolver extends Object
Implementation of a conjugate gradient iterative solver for linear systems. Implements both standard conjugate gradient and pre-conditioned conjugate gradient.
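The standard and preconditioned variants differ only in how each residual is transformed before it updates the search direction. The sketch below works on a dense `double[][]` and takes the preconditioner as a function r ↦ M⁻¹r, so passing the identity gives plain conjugate gradient. It assumes a symmetric positive definite matrix. The class name `CgSketch`, its `Result` holder and the function-based preconditioner are hypothetical illustrations, not this class's `VectorIterable`/`Preconditioner` API.

```java
import java.util.function.UnaryOperator;

/** Hypothetical sketch of (preconditioned) conjugate gradient on a dense SPD matrix. */
final class CgSketch {
    /** Solution plus the quantities this page exposes via getIterations() / getResidualNorm(). */
    static final class Result {
        final double[] x;
        final int iterations;
        final double residualNorm;

        Result(double[] x, int iterations, double residualNorm) {
            this.x = x;
            this.iterations = iterations;
            this.residualNorm = residualNorm;
        }
    }

    static double dot(double[] u, double[] v) {
        double s = 0;
        for (int i = 0; i < u.length; i++) s += u[i] * v[i];
        return s;
    }

    static double[] times(double[][] a, double[] v) {
        double[] out = new double[a.length];
        for (int i = 0; i < a.length; i++) out[i] = dot(a[i], v);
        return out;
    }

    /**
     * precond maps a residual r to M^{-1} r; UnaryOperator.identity() gives plain CG.
     * Stops after maxIterations steps or once ||r|| <= maxError.
     */
    static Result solve(double[][] a, double[] b, UnaryOperator<double[]> precond,
                        int maxIterations, double maxError) {
        int n = b.length;
        double[] x = new double[n];        // initial guess x0 = 0
        double[] r = b.clone();            // r0 = b - A x0 = b
        double[] z = precond.apply(r);     // preconditioned residual
        double[] p = z.clone();            // first search direction
        double rz = dot(r, z);
        double resNorm = Math.sqrt(dot(r, r));
        int k = 0;
        while (k < maxIterations && resNorm > maxError) {
            double[] ap = times(a, p);
            double alpha = rz / dot(p, ap); // exact line search along p
            for (int i = 0; i < n; i++) {
                x[i] += alpha * p[i];
                r[i] -= alpha * ap[i];
            }
            resNorm = Math.sqrt(dot(r, r));
            k++;
            z = precond.apply(r);
            double rzNew = dot(r, z);
            double beta = rzNew / rz;       // keeps p A-conjugate to earlier directions (exact arithmetic)
            for (int i = 0; i < n; i++) p[i] = z[i] + beta * p[i];
            rz = rzNew;
        }
        return new Result(x, k, resNorm);
    }
}
```

Calling `solve(a, b, UnaryOperator.identity(), a.length, 1e-9)` mirrors the default termination criteria described on this page. A Jacobi preconditioner, one common choice, is simply `r -> ` a copy of `r` with each `r[i]` divided by `a[i][i]`.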
Conjugate gradient requires the matrix A in the linear system Ax = b to be symmetric and positive definite. This implementation could be extended relatively easily to accept a non-symmetric input matrix, in which case the system A'Ax = b would be solved. Because each product A'Av can be computed in a single pass through A, this would be faster than explicitly computing A'A and then passing the result to the solver.
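The single-pass claim rests on the identity A'Av = Σᵢ aᵢ (aᵢ · v), where aᵢ is the i-th row of A: each row is read once, its dot product with v is taken, and the row is scaled and accumulated, so A'A is never formed. A minimal sketch over a dense row-major `double[][]`; the class and method names are hypothetical, not part of this API.

```java
/** Hypothetical helper: computes A'(A v) with one pass over the rows of A. */
final class NormalEquationsOperator {
    static double[] normalTimes(double[][] rows, double[] v) {
        double[] out = new double[v.length];          // result has numCols entries
        for (double[] row : rows) {
            double s = 0;
            for (int j = 0; j < v.length; j++) {
                s += row[j] * v[j];                   // (A v)_i = row . v
            }
            for (int j = 0; j < v.length; j++) {
                out[j] += s * row[j];                 // accumulate row * (row . v)
            }
        }
        return out;
    }
}
```

For a tall 3×2 matrix this touches each of the three rows exactly once and never allocates the 2×2 product A'A.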
For inputs that may be ill-conditioned (often the case for highly sparse input), this solver also accepts a parameter, lambda, which adds a scaled identity to the matrix A, so that the system (A + lambda*I)x = b is solved instead. This changes the solution, but when A is symmetric positive semidefinite and lambda > 0, the matrix A + lambda*I is positive definite, so the system is guaranteed to be solvable. Ridge regression is a common use of this feature.
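The shift does not require modifying or copying A: every matrix-vector product the solver needs simply becomes Av + lambda*v. A sketch for a dense square matrix; the helper name and its `lambda` parameter are illustrative assumptions, not this class's API.

```java
/** Hypothetical helper: applies (A + lambda*I) to v without forming A + lambda*I. */
final class RidgeShift {
    static double[] shiftedTimes(double[][] a, double lambda, double[] v) {
        double[] out = new double[v.length];
        for (int i = 0; i < a.length; i++) {
            double s = 0;
            for (int j = 0; j < v.length; j++) {
                s += a[i][j] * v[j];              // (A v)_i
            }
            out[i] = s + lambda * v[i];           // add the scaled-identity term on the fly
        }
        return out;
    }
}
```

With the singular matrix A = [[1, 1], [1, 1]] and v = (1, -1), Av is zero, but (A + 0.5*I)v = (0.5, -0.5), so v is no longer in the null space of the operator the solver sees.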
If only an approximate solution is required, the maximum number of iterations or the error threshold can be specified to end the algorithm early at the expense of accuracy. In exact arithmetic conjugate gradient converges in at most A.numCols() iterations, but rounding errors gradually erode the conjugacy of the search directions, so when A is ill-conditioned it may be necessary to raise the maximum number of iterations above the default of A.numCols().
By default, the solver runs A.numCols() iterations, stopping earlier if the residual norm falls below 1E-9.
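Why early termination trades accuracy for time, and why ill-conditioned inputs need more iterations, is captured by a standard error bound for conjugate gradient on a symmetric positive definite A. Here x* is the exact solution, x_k the k-th iterate, ‖e‖_A = √(eᵀAe) the energy norm, and κ the 2-norm condition number of A:

```latex
\|x_k - x^\ast\|_A \;\le\; 2\left(\frac{\sqrt{\kappa}-1}{\sqrt{\kappa}+1}\right)^{k}\|x_0 - x^\ast\|_A,
\qquad
\kappa = \kappa_2(A) = \frac{\mu_{\max}(A)}{\mu_{\min}(A)}.
```

For the shifted system (A + lambda*I)x = b every eigenvalue μ of A moves to μ + lambda, so κ becomes (μ_max + lambda)/(μ_min + lambda), which is smaller than μ_max/μ_min for lambda > 0. The ridge term therefore also speeds convergence.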
For more information on the conjugate gradient algorithm, see Golub & van Loan, "Matrix Computations", sections 10.2 and 10.3, or the Wikipedia article on the conjugate gradient method.
Modifier and Type | Field and Description |
---|---|
static double | DEFAULT_MAX_ERROR |
Constructor and Description |
---|
ConjugateGradientSolver() |
Modifier and Type | Method | Description |
---|---|---|
int | getIterations() | Returns the number of iterations run once the solver is complete. |
double | getResidualNorm() | Returns the norm of the residual at the completion of the solver. |
Vector | solve(VectorIterable a, Vector b) | Solves the system Ax = b with default termination criteria. |
Vector | solve(VectorIterable a, Vector b, Preconditioner precond) | Solves the system Ax = b with default termination criteria using the specified preconditioner. |
Vector | solve(VectorIterable a, Vector b, Preconditioner preconditioner, int maxIterations, double maxError) | Solves the system Ax = b, where A is a linear operator and b is a vector. |
public static final double DEFAULT_MAX_ERROR
public Vector solve(VectorIterable a, Vector b)
Parameters:
a - The linear operator A.
b - The vector b.
Throws:
IllegalArgumentException - if a is not square or if the size of b is not equal to the number of columns of a.
public Vector solve(VectorIterable a, Vector b, Preconditioner precond)
Parameters:
a - The linear operator A.
b - The vector b.
precond - A preconditioner to use on A during the solution process.
Throws:
IllegalArgumentException - if a is not square or if the size of b is not equal to the number of columns of a.
public Vector solve(VectorIterable a, Vector b, Preconditioner preconditioner, int maxIterations, double maxError)
Parameters:
a - The matrix A.
b - The vector b.
preconditioner - The preconditioner to apply.
maxIterations - The maximum number of iterations to run.
maxError - The maximum amount of residual error to tolerate. The algorithm will run until the residual falls below this value or until maxIterations are completed.
Throws:
IllegalArgumentException - if the matrix is not square, if the size of b is not equal to the number of columns of A, if maxError is less than zero, or if maxIterations is not positive.
public int getIterations()
public double getResidualNorm()
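The IllegalArgumentException conditions documented for the five-argument solve can be sketched as a standalone check. This is an illustrative assumption about how such validation might look: it takes plain dimensions instead of the VectorIterable and Vector types, and the class name and messages are hypothetical.

```java
/** Hypothetical sketch of the argument checks documented for solve(...). */
final class SolveArgs {
    static void check(int numRows, int numCols, int bSize, int maxIterations, double maxError) {
        if (numRows != numCols) {
            throw new IllegalArgumentException("Matrix must be square, got " + numRows + "x" + numCols);
        }
        if (bSize != numCols) {
            throw new IllegalArgumentException("b has size " + bSize + " but A has " + numCols + " columns");
        }
        if (maxError < 0) {
            throw new IllegalArgumentException("maxError must be >= 0, got " + maxError);
        }
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be positive, got " + maxIterations);
        }
    }
}
```

The order of the checks here is a design choice of the sketch; the documentation lists the conditions but does not say which is reported first when several fail.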
Copyright © 2008–2017 The Apache Software Foundation. All rights reserved.