toon-members
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[Toon-members] TooN/internal operators.hh


From: Tom Drummond
Subject: [Toon-members] TooN/internal operators.hh
Date: Fri, 17 Apr 2009 03:54:21 +0000

CVSROOT:        /cvsroot/toon
Module name:    TooN
Changes by:     Tom Drummond <twd20>    09/04/17 03:54:21

Modified files:
        internal       : operators.hh 

Log message:
        added diagmultiply for vector*matrix and matrix*vector

CVSWeb URLs:
http://cvs.savannah.gnu.org/viewcvs/TooN/internal/operators.hh?cvsroot=toon&r1=1.33&r2=1.34

Patches:
Index: operators.hh
===================================================================
RCS file: /cvsroot/toon/TooN/internal/operators.hh,v
retrieving revision 1.33
retrieving revision 1.34
diff -u -b -r1.33 -r1.34
--- operators.hh        12 Apr 2009 09:48:56 -0000      1.33
+++ operators.hh        17 Apr 2009 03:54:21 -0000      1.34
@@ -323,6 +323,15 @@
        // this is distinct to cater for non commuting precision types
        template<int Size, typename P1, typename B1, int R, int C, typename P2, typename B2>
        struct VectorMatrixMultiply;
+
+       // dummy struct for Matrix * Vector diagonal multiplication, i.e. Matrix * diag(Vector)
+       template<int R, int C, typename P1, typename B1, int Size, typename P2, typename B2>
+       struct MatrixVectorDiagMultiply;
+
+       // dummy struct for Vector * Matrix diagonal multiplication, i.e. diag(Vector) * Matrix; distinct to cater for non commuting precision types
+       template<int Size, typename P1, typename B1, int R, int C, typename P2, typename B2>
+       struct VectorMatrixDiagMultiply;
+
 };
 
 // Matrix Vector multiplication Matrix * Vector
@@ -377,6 +386,64 @@
 }
 
 
+// Matrix Vector diagonal multiplication Matrix * Vector, i.e. Matrix * diag(Vector): result(r,c) = lhs(r,c) * rhs[c]
+template<int R, int C, typename P1, typename B1, int Size, typename P2, typename B2>
+struct Operator<Internal::MatrixVectorDiagMultiply<R,C,P1,B1,Size,P2,B2> > {
+       const Matrix<R,C,P1,B1>& lhs;
+       const Vector<Size,P2,B2>& rhs;
+
+       Operator(const Matrix<R,C,P1,B1>& lhs_in, const Vector<Size,P2,B2>& rhs_in) : lhs(lhs_in), rhs(rhs_in) {}
+
+       int num_rows() const {return lhs.num_rows();} // result has the shape of the matrix operand
+       int num_cols() const {return lhs.num_cols();}
+
+       template<int Rout, int Cout, typename Pout, typename Bout>
+       void eval(Matrix<Rout, Cout, Pout, Bout>& res) const {
+               for(int c=0; c < res.num_cols(); ++c) {
+                       const P2 temp = rhs[c]; // one vector element scales the whole column c
+                       for(int r=0; r < res.num_rows(); ++r) {
+                               res(r,c) = lhs(r,c)*temp;
+                       }
+               }
+       }
+};
+
+template<int R, int C, int Size, typename P1, typename P2, typename B1, typename B2>
+Matrix<R, C, typename Internal::MultiplyType<P1,P2>::type> diagmult(const Matrix<R, C, P1, B1>& m, const Vector<Size, P2, B2>& v)
+{
+       SizeMismatch<C,Size>::test(m.num_cols(), v.size()); // v scales the columns of m: one element per column
+       return Operator<Internal::MatrixVectorDiagMultiply<R,C,P1,B1,Size,P2,B2> >(m,v);
+}
+
+// Vector Matrix diagonal multiplication Vector * Matrix, i.e. diag(Vector) * Matrix: result(r,c) = lhs[r] * rhs(r,c)
+template<int R, int C, typename P1, typename B1, int Size, typename P2, typename B2>
+struct Operator<Internal::VectorMatrixDiagMultiply<Size,P1,B1,R,C,P2,B2> > {
+       const Vector<Size,P1,B1>& lhs;
+       const Matrix<R,C,P2,B2>& rhs;
+
+       Operator(const Vector<Size,P1,B1>& lhs_in, const Matrix<R,C,P2,B2>& rhs_in) : lhs(lhs_in), rhs(rhs_in) {}
+
+       int num_rows() const {return rhs.num_rows();} // result has the shape of the matrix operand
+       int num_cols() const {return rhs.num_cols();}
+
+       template<int Rout, int Cout, typename Pout, typename Bout>
+       void eval(Matrix<Rout, Cout, Pout, Bout>& res) const {
+               for(int r=0; r < res.num_rows(); ++r){
+                       const P1 temp = lhs[r]; // one vector element scales the whole row r (kept on the left for non commuting types)
+                       for(int c=0; c<res.num_cols(); ++c){
+                               res(r,c) = temp * rhs(r,c);
+                       }
+               }
+       }
+};
+
+template<int R, int C, typename P1, typename B1, int Size, typename P2, typename B2>
+Matrix<R, C, typename Internal::MultiplyType<P1,P2>::type> diagmult(const Vector<Size,P1,B1>& v,
+                                                                    const Matrix<R,C,P2,B2>& m)
+{
+       SizeMismatch<R,Size>::test(m.num_rows(), v.size()); // v scales the rows of m, so the static check must use R, not C
+       return Operator<Internal::VectorMatrixDiagMultiply<Size,P1,B1,R,C,P2,B2> >(v,m);
+}
 
 
 
////////////////////////////////////////////////////////////////////////////////




reply via email to

[Prev in Thread] Current Thread [Next in Thread]