freepooma-devel
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[PATCH] Support reductions over where()


From: Richard Guenther
Subject: [PATCH] Support reductions over where()
Date: Thu, 20 Nov 2003 22:22:38 +0100 (CET)

Hi!

This patch adds support for reductions over two- and three-argument where()
expressions, as in

  double star_mass = sum(where(norm(positions(rh).read(I)) <= star_radius,
                               (rh * Pooma::cellVolumes(rh))(I)));

which integrates mass over a sphere with radius star_radius inside the
computational domain.

Other cases, such as a where() whose value argument is a scalar,

  int cnt = sum(where(rh.read(I) != 0.0, 1));

are not supported yet; they are less important for my work at the
moment.

Tested by re-running the existing reduction and where() tests; the new
test cases pass.

Ok?

Richard.


2003Nov20  Richard Guenther <address@hidden>

        * src/Evaluator/WhereProxy.h: Add Element_t typedef and
        hasRelations enum.
        * src/Evaluator/OpMask.h: Add Unwrap<> and ReductionTraits<>
        specialization for OpMask<> operators.
        * src/Evaluator/Reduction.h: Handle WhereProxy<> in main
        reduction evaluator by unwrapping the expression.  Unwrap op
        for final reduction over patch results.
        * src/Engine/RemoteEngine.h: Unwrap op for final reduction over
        patch results.
        * src/Field/tests/WhereTest.cpp: Add tests for reduction over
        two- and three-arg where.

diff -Nru a/r2/src/Engine/RemoteEngine.h b/r2/src/Engine/RemoteEngine.h
--- a/r2/src/Engine/RemoteEngine.h      Thu Nov 20 22:03:32 2003
+++ b/r2/src/Engine/RemoteEngine.h      Thu Nov 20 22:03:32 2003
@@ -2069,7 +2069,7 @@
       {
        ret = vals[0];
        for (j = 1; j < n; j++)
-         op(ret, vals[j]);
+         Unwrap<Op>::unwrap(op)(ret, vals[j]);  // combine per-patch results with the op stripped of OpMask<> wrappers; NOTE(review): Unwrap<> is defined in Evaluator/OpMask.h and this hunk adds no #include -- confirm it is visible here
       }

     delete [] vals;
diff -Nru a/r2/src/Evaluator/OpMask.h b/r2/src/Evaluator/OpMask.h
--- a/r2/src/Evaluator/OpMask.h Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/OpMask.h Thu Nov 20 22:03:32 2003
@@ -169,6 +169,36 @@
   typedef T1 &Type_t;
 };

+/// Unwrap<Op>::unwrap(op) returns the plain operator hidden behind any
+/// number of OpMask<> wrappers, so that code combining per-patch
+/// reduction results can apply the underlying operator directly.
+/// The primary template is the identity.
+template <class Op>
+struct Unwrap {
+  typedef Op Op_t;
+  static inline const Op_t& unwrap(const Op &op) { return op; }
+};
+
+/// Strip one OpMask<> layer (its op_m member) and recurse.
+template <class Op>
+struct Unwrap<OpMask<Op> > {
+  typedef typename Unwrap<Op>::Op_t Op_t;
+  static inline const Op_t& unwrap(const OpMask<Op> &op)
+  { return Unwrap<Op>::unwrap(op.op_m); }
+};
+
+template <class Op, class T>
+struct ReductionTraits;
+
+/// A masked reduction uses the identity element of the underlying
+/// (unmasked) reduction operator.
+template <class Op, class T>
+struct ReductionTraits<OpMask<Op>, T>
+{
+  static T identity() { return ReductionTraits<Op, T>::identity(); }
+};
+
+
 //-----------------------------------------------------------------------------
 //
 //-----------------------------------------------------------------------------
diff -Nru a/r2/src/Evaluator/Reduction.h b/r2/src/Evaluator/Reduction.h
--- a/r2/src/Evaluator/Reduction.h      Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/Reduction.h      Thu Nov 20 22:03:32 2003
@@ -53,6 +53,7 @@
 #include "Engine/IntersectEngine.h"
 #include "Evaluator/ReductionKernel.h"
 #include "Evaluator/EvaluatorTags.h"
+#include "Evaluator/WhereProxy.h"
 #include "Threads/PoomaCSem.h"

 #include <vector>
@@ -109,6 +110,14 @@
     return e.centeringSize() == 1 && e.numMaterials() == 1;
   }

+  /// Reduce a where() expression: replace op by w.opMask(op) and pass w.whereMask() on to the generic evaluate().
+
+  template<class T, class Op, class Cond, class Expr>
+  void evaluate(T &ret, const Op &op, const WhereProxy<Cond, Expr> &w) const
+  {
+    evaluate(ret, w.opMask(op), w.whereMask());
+  }
+
   /// Input an expression and cause it to be reduced.
   /// We just pass the buck to a special reduction after updating
   /// the expression leafs and checking its validity (we can handle
@@ -249,7 +258,7 @@

     ret = vals[0];
     for (j = 1; j < n; j++)
-      op(ret, vals[j]);
+      Unwrap<Op>::unwrap(op)(ret, vals[j]);  // combine per-patch results with the op stripped of any OpMask<> wrapper (see Unwrap<> in Evaluator/OpMask.h)
     delete [] vals;
   }
 };
diff -Nru a/r2/src/Evaluator/WhereProxy.h b/r2/src/Evaluator/WhereProxy.h
--- a/r2/src/Evaluator/WhereProxy.h     Thu Nov 20 22:03:32 2003
+++ b/r2/src/Evaluator/WhereProxy.h     Thu Nov 20 22:03:32 2003
@@ -85,6 +85,10 @@
   typedef typename ConvertWhereProxy<ETrait_t,Tree_t>::Make_t MakeFromTree_t;
   typedef typename MakeFromTree_t::Expression_t               WhereMask_t;

+  typedef typename B::Element_t Element_t;  // re-exported from B; presumably so reduction code can query the element type of a where() proxy -- confirm
+
+  enum { hasRelations = B::hasRelations };  // re-exported unchanged from B, like Element_t
+
   inline WhereMask_t
   whereMask() const
   {
diff -Nru a/r2/src/Field/tests/WhereTest.cpp b/r2/src/Field/tests/WhereTest.cpp
--- a/r2/src/Field/tests/WhereTest.cpp  Thu Nov 20 22:03:32 2003
+++ b/r2/src/Field/tests/WhereTest.cpp  Thu Nov 20 22:03:32 2003
@@ -86,6 +86,7 @@
   // Now, we can declare a field.

   Centering<2> allFace = canonicalCentering<2>(FaceType, Continuous);
+  Centering<2> allCell = canonicalCentering<2>(CellType, Continuous);  // centering of the reduction test fields d, e, f

   typedef UniformRectilinearMesh<2> Geometry_t;

@@ -103,6 +104,9 @@
   Field_t a(allFace, layout, origin, spacings);
   Field_t b(allFace, layout, origin, spacings);
   Field_t c(allFace, layout, origin, spacings);
+  Field_t d(allCell, layout, origin, spacings);
+  Field_t e(allCell, layout, origin, spacings);
+  Field_t f(allCell, layout, origin, spacings);

   PositionsTraits<Geometry_t>::Type_t x = positions(a);

@@ -154,6 +158,21 @@
   tester.check("twoarg where result dirtied part, centering one",
                all(where(dot(x.subField(0, 1), line) > 8.0,
                    b.subField(0, 1), c.subField(0, 1)) == a.subField(0, 1)));
+
+  // 2-arg where reduction: sum d (all 1.0) over cells whose first position component is < 4.0
+
+  d = 1.0;
+  e = positions(e).read(e.physicalDomain()).comp(0);
+  tester.check("reduction over twoarg where",
+              sum(where(e(e.physicalDomain()) < 4.0, d)) == 4.0*9.0);
+
+  // 3-arg where reduction: same mask, f (all 0.0) fills the rest, so the sum is unchanged
+
+  d = 1.0;
+  f = 0.0;
+  e = positions(e).read(e.physicalDomain()).comp(0);
+  tester.check("reduction over threearg where",
+              sum(where(e(e.physicalDomain()) < 4.0, d, f)) == 4.0*9.0);

   int ret = tester.results("WhereTest");
   Pooma::finalize();

reply via email to

[Prev in Thread] Current Thread [Next in Thread]