[Top][All Lists]
[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]
[Toon-members] TooN SVD.h test/svd_test.cc
From: Edward Rosten
Subject: [Toon-members] TooN SVD.h test/svd_test.cc
Date: Wed, 22 Apr 2009 15:42:27 +0000
CVSROOT: /cvsroot/toon
Module name: TooN
Changes by: Edward Rosten <edrosten> 09/04/22 15:42:27
Modified files:
. : SVD.h
Added files:
test : svd_test.cc
Log message:
Fix SVD::get_U() and SVD::get_VT() for non-square matrices.
Add a simple test for the SVD of square matrices.
CVSWeb URLs:
http://cvs.savannah.gnu.org/viewcvs/TooN/test/svd_test.cc?cvsroot=toon&rev=1.1
http://cvs.savannah.gnu.org/viewcvs/TooN/SVD.h?cvsroot=toon&r1=1.16&r2=1.17
Patches:
Index: SVD.h
===================================================================
RCS file: /cvsroot/toon/TooN/SVD.h,v
retrieving revision 1.16
retrieving revision 1.17
diff -u -b -r1.16 -r1.17
--- SVD.h 20 Apr 2009 21:02:07 -0000 1.16
+++ SVD.h 22 Apr 2009 15:42:27 -0000 1.17
@@ -41,7 +41,64 @@
+namespace Internal{
+	/// Chooses the types of the U and V^T factors of an SVD of a Rows x Cols
+	/// matrix, and which of the SVD's two stored matrices holds each one.
+	///
+	/// The SVD stores a Rows x Cols "copy" and a Min x Min "square" matrix.
+	/// For a tall input (Rows >= Cols), U is the copy and V^T is the square;
+	/// for a wide input the two roles are swapped.
+	///
+	/// The first flag is deliberately named IsDynamic rather than Dynamic:
+	/// a template parameter called Dynamic hides TooN::Dynamic inside the
+	/// template, so Matrix<Dynamic> would silently become Matrix<true>,
+	/// i.e. a fixed 1x1 matrix, instead of a dynamically sized one.
+	///
+	/// NOTE(review): the matrix types below use TooN's default Precision and
+	/// layout, not the SVD's Precision parameter -- confirm this is intended
+	/// for non-double SVDs.
+	template<int Rows, int Cols, bool IsDynamic = (Rows == -1 || Cols == -1),
+	         bool IsVertical = (Rows >= Cols)> struct UVT
+	{
+		typedef Matrix<Dynamic, Dynamic> U_type;
+		typedef Matrix<Dynamic, Dynamic> VT_type;
+
+		/// At least one dimension is only known at run time, so the shape
+		/// test happens here: U is the copy when the input is tall.
+		static const Matrix<Dynamic, Dynamic>& get_U(const Matrix<Dynamic, Dynamic>& copy, const Matrix<Dynamic, Dynamic>& square)
+		{
+			if(copy.num_rows() >= copy.num_cols())
+				return copy;
+			else
+				return square;
+		}
+
+		/// V^T is the square matrix when the input is tall, else the copy.
+		static const Matrix<Dynamic, Dynamic>& get_VT(const Matrix<Dynamic, Dynamic>& copy, const Matrix<Dynamic, Dynamic>& square)
+		{
+			if(copy.num_rows() >= copy.num_cols())
+				return square;
+			else
+				return copy;
+		}
+	};
+
+	/// Statically sized, tall (Rows >= Cols): U is the Rows x Cols copy,
+	/// V^T is the Min x Min square matrix.
+	template<int Rows, int Cols> struct UVT<Rows, Cols, false, true>
+	{
+		static const int Min = Rows<Cols?Rows:Cols;
+
+		typedef Matrix<Rows,Cols> U_type;
+		typedef Matrix<Min,Min> VT_type;
+
+		static const Matrix<Rows, Cols>& get_U(const Matrix<Rows,Cols>& copy, const Matrix<Min, Min>& /*square*/)
+		{
+			return copy;
+		}
+
+		static const Matrix<Min, Min>& get_VT(const Matrix<Rows, Cols>& /*copy*/, const Matrix<Min, Min>& square)
+		{
+			return square;
+		}
+	};
+
+	/// Statically sized, wide (Rows < Cols): U is the Min x Min square
+	/// matrix, V^T is the Rows x Cols copy.
+	template<int Rows, int Cols> struct UVT<Rows, Cols, false, false>
+	{
+		static const int Min = Rows<Cols?Rows:Cols;
+
+		typedef Matrix<Min,Min> U_type;
+		typedef Matrix<Rows,Cols> VT_type;
+
+		static const Matrix<Min, Min>& get_U(const Matrix<Rows, Cols>& /*copy*/, const Matrix<Min, Min>& square)
+		{
+			return square;
+		}
+
+		static const Matrix<Rows, Cols>& get_VT(const Matrix<Rows,Cols>& copy, const Matrix<Min, Min>& /*square*/)
+		{
+			return copy;
+		}
+	};
+}
/**
@@ -229,29 +286,19 @@
/// Return the U matrix from the decomposition
/// The size of this depends on the shape of the original matrix
/// it is square if the original matrix is wide or tall if the original
matrix is tall
- Matrix<Rows,Min_Dim,Precision,RowMajor>& get_U(){
- if(is_vertical()){
- return my_copy;
- } else {
- return my_square;
+ typename Internal::UVT<Rows, Cols>::U_type get_U()
+ {
+ return Internal::UVT<Rows, Cols>::get_U(my_copy, my_square);
}
+
+ typename Internal::UVT<Rows, Cols>::VT_type get_VT()
+ {
+ return Internal::UVT<Rows, Cols>::get_VT(my_copy, my_square);
}
/// Return the singular values as a vector
Vector<Min_Dim,Precision>& get_diagonal(){ return my_diagonal; }
- /// Return the VT matrix from the decomposition
- /// The size of this depends on the shape of the original matrix
- /// it is square if the original matrix is tall or wide if the original
matrix is wide
- Matrix<Min_Dim,Cols,Precision,RowMajor>& get_VT(){
- if(is_vertical()){
- return my_square;
- } else {
- return my_copy;
- }
- }
-
-
void get_inv_diag(Vector<Min_Dim>& inv_diag, const Precision condition){
for(int i=0; i<min_dim(); i++){
if(my_diagonal[i] * condition <= my_diagonal[0]){
Index: test/svd_test.cc
===================================================================
RCS file: test/svd_test.cc
diff -N test/svd_test.cc
--- /dev/null 1 Jan 1970 00:00:00 -0000
+++ test/svd_test.cc 22 Apr 2009 15:42:27 -0000 1.1
@@ -0,0 +1,22 @@
+#include <TooN/SVD.h>
+#include <TooN/helpers.h>
+using namespace TooN;
+using namespace std;
+
+int main()
+{
+	Matrix<4, 4> m = Zero;
+	m[0] = makeVector(1, 2, 3, 4);
+	m[1] = makeVector(1, 1, 1, 1);
+	SVD<4, 4> svdm(m); // m has rank 2, so rows 2 and 3 of V^T span its null space
+	cout << svdm.get_VT().num_rows() << endl;
+	cout << svdm.get_VT().num_cols() << endl;
+	bool ok = svdm.get_VT().num_rows() == 4 && svdm.get_VT().num_cols() == 4;
+	for(int r = 0; r < 2; r++)       // every non-zero row of m ...
+		for(int c = 2; c < 4; c++){  // ... must be orthogonal to each null-space row of V^T
+			double d = m[r] * svdm.get_VT()[c];
+			cout << d << endl;
+			ok = ok && d * d < 1e-20; // |d| < 1e-10 counts as numerically zero
+		}
+	return ok ? 0 : 1; // non-zero exit status so a test runner can detect failure
+}
[Prev in Thread] |
Current Thread |
[Next in Thread] |
- [Toon-members] TooN SVD.h test/svd_test.cc,
Edward Rosten <=