octave-maintainers
[Top][All Lists]
Advanced

[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

Re: some notes about mex support in Octave


From: David Bateman
Subject: Re: some notes about mex support in Octave
Date: Wed, 26 Jul 2006 03:23:21 +0200
User-agent: Mozilla Thunderbird 1.0.6-7.6.20060mdk (X11/20050322)

David Bateman wrote:
> John W. Eaton wrote:
> 
> 
>>The code for sparse matrices is also not quite finished.  I could use
>>some help writing the following functions:
>>
>>  mex.cc:             mxArray_sparse::as_octave_value
>>
>>  ov-bool-sparse.cc:  octave_sparse_bool_matrix::as_mxArray
>>  ov-cx-sparse.cc:    octave_sparse_complex_matrix::as_mxArray
>>  ov-re-sparse.cc:    octave_sparse_matrix::as_mxArray
> 
> 
> I'd like to help out, but you might have noticed I've been a bit quiet
> lately. Too many other commitments, and Octave had to take second place...
> 
> What is involved in writing these functions? I suppose by the hybrid
> approach you mean you are trying to avoid copying where possible? What
> does this imply? If it's not too much work, I'll try to find time for it.
> 
> Regards
> David
> 

OK, then what about the attached patch? The test code should probably be
cleaned up. Also, I note that valgrind is not clean against the mex
types, though that is independent of the sparse type.

D.
Index: src/mex.cc
===================================================================
RCS file: /cvs/octave/src/mex.cc,v
retrieving revision 1.5
diff -c -r1.5 mex.cc
*** src/mex.cc  25 Jul 2006 19:56:00 -0000      1.5
--- src/mex.cc  26 Jul 2006 01:21:56 -0000
***************
*** 372,377 ****
--- 372,384 ----
        id = mxCHAR_CLASS;
      else if (cn == "double")
        id = mxDOUBLE_CLASS;
+     else if (cn == "sparse")
+       {
+       if (val.is_bool_type())
+         id = mxLOGICAL_CLASS;
+       else
+         id = mxDOUBLE_CLASS;
+       }
      else if (cn == "single")
        id = mxSINGLE_CLASS;
      else if (cn == "int8")
***************
*** 1312,1346 ****
  
  // Matlab-style sparse arrays.
  
! class mxArray_sparse : public mxArray_number
  {
  public:
  
    mxArray_sparse (mxClassID id_arg, int m, int n, int nzmax_arg,
                  mxComplexity flag = mxREAL)
!     : mxArray_number (id_arg, m, n, flag), nzmax (nzmax_arg)
    {
      ir = static_cast<int *> (calloc (nzmax, sizeof (int)));
!     jc = static_cast<int *> (calloc (nzmax, sizeof (int)));
    }
  
    mxArray_sparse *clone (void) const { return new mxArray_sparse (*this); }
  
    ~mxArray_sparse (void)
    {
      mxFree (ir);
      mxFree (jc);
    }
  
    octave_value as_octave_value (void) const
    {
!     // FIXME
!     abort ();
!     return octave_value ();
    }
  
    int is_sparse (void) const { return 1; }
  
    int *get_ir (void) const { return ir; }
  
    int *get_jc (void) const { return jc; }
--- 1319,1438 ----
  
  // Matlab-style sparse arrays.
  
! class mxArray_sparse : public mxArray_matlab
  {
  public:
  
    mxArray_sparse (mxClassID id_arg, int m, int n, int nzmax_arg,
                  mxComplexity flag = mxREAL)
!     : mxArray_matlab (id_arg, m, n), nzmax (nzmax_arg)
    {
+     pr = (calloc (nzmax, get_element_size ()));
+     pi = (flag == mxCOMPLEX ? calloc (nzmax, get_element_size ()) : 0);
      ir = static_cast<int *> (calloc (nzmax, sizeof (int)));
!     jc = static_cast<int *> (calloc (n + 1, sizeof (int)));
    }
  
    mxArray_sparse *clone (void) const { return new mxArray_sparse (*this); }
  
    ~mxArray_sparse (void)
    {
+     mxFree (pr);
+     mxFree (pi);
      mxFree (ir);
      mxFree (jc);
    }
  
    octave_value as_octave_value (void) const
    {
!     octave_value retval;
! 
!     // Build the Octave sparse type matching the class id (and, for doubles,
!     // the presence of imaginary data) from the pr/pi, ir and jc arrays.
!     switch (get_class_id ())
!       {
!       case mxLOGICAL_CLASS:
!       {
!         bool *ppr = static_cast<bool *> (pr);
! 
!         SparseBoolMatrix val (get_m(), get_n(), nzmax);
! 
!         for (int i = 0; i < nzmax; i++)
!           {
!             val.xdata(i) = ppr[i];
!             val.xridx(i) = ir[i];
!           }
! 
!         for (int i = 0; i < get_n() + 1; i++)
!           val.xcidx(i) = jc[i];
! 
!         retval = val;
!       }
!       break;
! 
!       case mxSINGLE_CLASS:
!       error ("single precision data type not supported");
!       break;
! 
!       case mxDOUBLE_CLASS:
!       {
!         if (pi)
!           {
!             double *ppr = static_cast<double *> (pr);
!             double *ppi = static_cast<double *> (pi);
! 
!             SparseComplexMatrix val (get_m(), get_n(), nzmax);
! 
!             for (int i = 0; i < nzmax; i++)
!               {
!                 val.xdata(i) = Complex (ppr[i], ppi[i]);
!                 val.xridx(i) = ir[i];
!               }
! 
!             for (int i = 0; i < get_n() + 1; i++)
!               val.xcidx(i) = jc[i];
! 
!             retval = val;
!           }
!         else
!           {
!             double *ppr = static_cast<double *> (pr);
! 
!             SparseMatrix val (get_m(), get_n(), nzmax);
! 
!             for (int i = 0; i < nzmax; i++)
!               {
!                 val.xdata(i) = ppr[i];
!                 val.xridx(i) = ir[i];
!               }
! 
!             for (int i = 0; i < get_n() + 1; i++)
!               val.xcidx(i) = jc[i];
! 
!             retval = val;
!           }
!       }
!       break;
! 
!       default:
!       panic_impossible ();
!       }
! 
!     return retval;
    }
  
+   int is_complex (void) const { return pi != 0; }
+ 
    int is_sparse (void) const { return 1; }
  
+   void *get_data (void) const { return pr; }
+ 
+   void *get_imag_data (void) const { return pi; }
+ 
+   void set_data (void *pr_arg) { pr = pr_arg; }
+ 
+   void set_imag_data (void *pi_arg) { pi = pi_arg; }
+ 
    int *get_ir (void) const { return ir; }
  
    int *get_jc (void) const { return jc; }
***************
*** 1357,1375 ****
  
    int nzmax;
  
    int *ir;
    int *jc;
  
    mxArray_sparse (const mxArray_sparse& val)
!     : mxArray_number (val), nzmax (val.nzmax),
!       ir (static_cast<int *> (malloc (nzmax * sizeof (int)))),
!       jc (static_cast<int *> (malloc (nzmax * sizeof (int))))
    {
!     for (int i = 0; i < nzmax; i++)
!       {
!       ir[i] = val.ir[i];
!       jc[i] = val.jc[i];
!       }
    }
  };
  
--- 1449,1475 ----
  
    int nzmax;
  
+   void *pr;
+   void *pi;
    int *ir;
    int *jc;
  
    mxArray_sparse (const mxArray_sparse& val)
!     : mxArray_matlab (val), nzmax (val.nzmax),
!       pr (malloc (nzmax * get_element_size ())),
!       pi (val.pi ? malloc (nzmax * get_element_size ()) : 0),
!       ir (static_cast<int *> (malloc (nzmax * sizeof (int)))),
!       jc (static_cast<int *> (malloc ((val.get_n () + 1) * sizeof (int))))
    {
!     // Deep-copy every buffer into the storage allocated above, which is
!     // sized from VAL (nzmax values/row indices, n + 1 column pointers).
!     int ntot = nzmax * get_element_size ();
! 
!     memcpy (pr, val.pr, ntot);
!     memcpy (ir, val.ir, nzmax * sizeof (int));
!     memcpy (jc, val.jc, (val.get_n () + 1) * sizeof (int));
!     if (pi)
!       memcpy (pi, val.pi, ntot);
    }
  };
  
Index: src/ov-bool-sparse.cc
===================================================================
RCS file: /cvs/octave/src/ov-bool-sparse.cc,v
retrieving revision 1.15
diff -c -r1.15 ov-bool-sparse.cc
*** src/ov-bool-sparse.cc       22 Jul 2006 08:31:17 -0000      1.15
--- src/ov-bool-sparse.cc       26 Jul 2006 01:21:56 -0000
***************
*** 693,700 ****
  mxArray *
  octave_sparse_bool_matrix::as_mxArray (void) const
  {
!   // FIXME
!   return 0;
  }
  
  /*
--- 693,718 ----
  mxArray *
  octave_sparse_bool_matrix::as_mxArray (void) const
  {
!   // Copy this sparse bool matrix into a new Matlab-style sparse mxArray.
!   const int nz = nzmax ();
!   const int nc = columns ();
!   mxArray *retval = new mxArray (mxLOGICAL_CLASS, rows (), nc, nz, mxREAL);
! 
!   bool *pr = static_cast<bool *> (retval->get_data ());
!   int *ir = retval->get_ir ();
!   int *jc = retval->get_jc ();
! 
!   // Stored values and their row indices.
!   for (int k = 0; k < nz; k++)
!     pr[k] = matrix.data (k);
!   for (int k = 0; k < nz; k++)
!     ir[k] = matrix.ridx (k);
! 
!   // Column pointers (nc + 1 entries).
!   for (int j = 0; j <= nc; j++)
!     jc[j] = matrix.cidx (j);
! 
!   return retval;
  }
  
  /*
Index: src/ov-cx-sparse.cc
===================================================================
RCS file: /cvs/octave/src/ov-cx-sparse.cc,v
retrieving revision 1.14
diff -c -r1.14 ov-cx-sparse.cc
*** src/ov-cx-sparse.cc 22 Jul 2006 08:31:17 -0000      1.14
--- src/ov-cx-sparse.cc 26 Jul 2006 01:21:57 -0000
***************
*** 762,769 ****
  mxArray *
  octave_sparse_complex_matrix::as_mxArray (void) const
  {
!   // FIXME
!   return 0;
  }
  
  /*
--- 762,792 ----
  mxArray *
  octave_sparse_complex_matrix::as_mxArray (void) const
  {
!   // Copy this sparse complex matrix into a new Matlab-style sparse
!   // mxArray, splitting each value into separate real/imaginary arrays.
!   const int nz = nzmax ();
!   const int nc = columns ();
!   mxArray *retval
!     = new mxArray (mxDOUBLE_CLASS, rows (), nc, nz, mxCOMPLEX);
! 
!   double *re = static_cast<double *> (retval->get_data ());
!   double *im = static_cast<double *> (retval->get_imag_data ());
!   int *ir = retval->get_ir ();
!   int *jc = retval->get_jc ();
! 
!   for (int k = 0; k < nz; k++)
!     {
!       const Complex z = matrix.data (k);
!       re[k] = z.real ();
!       im[k] = z.imag ();
!       ir[k] = matrix.ridx (k);
!     }
! 
!   // Column pointers (nc + 1 entries).
!   for (int j = 0; j <= nc; j++)
!     jc[j] = matrix.cidx (j);
! 
!   return retval;
  }
  
  /*
Index: src/ov-re-sparse.cc
===================================================================
RCS file: /cvs/octave/src/ov-re-sparse.cc,v
retrieving revision 1.16
diff -c -r1.16 ov-re-sparse.cc
*** src/ov-re-sparse.cc 22 Jul 2006 08:31:17 -0000      1.16
--- src/ov-re-sparse.cc 26 Jul 2006 01:21:57 -0000
***************
*** 788,795 ****
  mxArray *
  octave_sparse_matrix::as_mxArray (void) const
  {
!   // FIXME
!   return 0;
  }
  
  /*
--- 788,813 ----
  mxArray *
  octave_sparse_matrix::as_mxArray (void) const
  {
!   // Copy this real sparse matrix into a new Matlab-style sparse mxArray.
!   const int nz = nzmax ();
!   const int nc = columns ();
!   mxArray *retval = new mxArray (mxDOUBLE_CLASS, rows (), nc, nz, mxREAL);
! 
!   double *pr = static_cast<double *> (retval->get_data ());
!   int *ir = retval->get_ir ();
!   int *jc = retval->get_jc ();
! 
!   // Stored values and their row indices.
!   for (int k = 0; k < nz; k++)
!     pr[k] = matrix.data (k);
!   for (int k = 0; k < nz; k++)
!     ir[k] = matrix.ridx (k);
! 
!   // Column pointers (nc + 1 entries).
!   for (int j = 0; j <= nc; j++)
!     jc[j] = matrix.cidx (j);
! 
!   return retval;
  }
  
  /*
*** examples/mysparse.c.orig    2006-07-26 03:20:15.478674992 +0200
--- examples/mysparse.c 2006-07-26 03:16:57.302802304 +0200
***************
*** 0 ****
--- 1,116 ----
+ #include "mex.h"
+ 
+ void
+ mexFunction (int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
+ {
+   int n, m, nz;
+   mxArray *v;
+   int i;
+   double *pr, *pi;
+   double *pr2, *pi2;
+   int *ir, *jc;
+   int *ir2, *jc2;
+   
+   if (nrhs != 1 || ! mxIsSparse (prhs[0]))
+     mexErrMsgTxt ("expects sparse matrix");
+ 
+   m = mxGetM (prhs [0]);
+   n = mxGetN (prhs [0]);
+   nz = mxGetNzmax (prhs [0]);
+   
+   if (mxIsComplex(prhs[0]))
+     {
+       mexPrintf("Matrix is %d-by-%d complex sparse matrix with %d elements\n",
+               m, n, nz);
+ 
+       pr = mxGetPr(prhs[0]);
+       pi = mxGetPi(prhs[0]);
+       ir = mxGetIr(prhs[0]);
+       jc = mxGetJc(prhs[0]);
+ 
+       i = n;
+       while (jc[i] == jc[i-1] && i != 0) i--;
+       mexPrintf("last non-zero element (%d, %d) = (%g, %g)\n", ir[nz-1]+ 1, 
+               i, pr[nz-1], pi[nz-1]);
+ 
+       v = mxCreateSparse (m, n, nz, mxCOMPLEX);
+       pr2 = mxGetPr(v);
+       pi2 = mxGetPi(v);
+       ir2 = mxGetIr(v);
+       jc2 = mxGetJc(v);
+       
+       for (i = 0; i < nz; i++)
+       {
+         pr2[i] = 2 * pr[i];
+         pi2[i] = 2 * pi[i];
+         ir2[i] = ir[i];
+       }
+       for (i = 0; i < n + 1; i++)
+       jc2[i] = jc[i];
+ 
+       if (nlhs > 0)
+       plhs[0] = v;
+     }
+   else if (mxIsLogical(prhs[0]))
+     {
+       bool *pbr, *pbr2;
+       mexPrintf("Matrix is %d-by-%d logical sparse matrix with %d elements\n",
+               m, n, nz);
+ 
+       pbr = mxGetLogicals(prhs[0]);
+       ir = mxGetIr(prhs[0]);
+       jc = mxGetJc(prhs[0]);
+ 
+       i = n;
+       while (jc[i] == jc[i-1] && i != 0) i--;
+       mexPrintf("last non-zero element (%d, %d) = %d\n", ir[nz-1]+ 1, 
+               i, pbr[nz-1]);
+ 
+       v = mxCreateSparseLogicalMatrix (m, n, nz);
+       pbr2 = mxGetLogicals(v);
+       ir2 = mxGetIr(v);
+       jc2 = mxGetJc(v);
+       
+       for (i = 0; i < nz; i++)
+       {
+         pbr2[i] = pbr[i];
+         ir2[i] = ir[i];
+       }
+       for (i = 0; i < n + 1; i++)
+       jc2[i] = jc[i];
+ 
+       if (nlhs > 0)
+       plhs[0] = v;
+     }
+   else
+     {
+       
+       mexPrintf("Matrix is %d-by-%d real sparse matrix with %d elements\n",
+               m, n, nz);
+ 
+       pr = mxGetPr(prhs[0]);
+       ir = mxGetIr(prhs[0]);
+       jc = mxGetJc(prhs[0]);
+ 
+       i = n;
+       while (jc[i] == jc[i-1] && i != 0) i--;
+       mexPrintf("last non-zero element (%d, %d) = %g\n", ir[nz-1]+ 1, 
+               i, pr[nz-1]);
+ 
+       v = mxCreateSparse (m, n, nz, mxREAL);
+       pr2 = mxGetPr(v);
+       ir2 = mxGetIr(v);
+       jc2 = mxGetJc(v);
+       
+       for (i = 0; i < nz; i++)
+       {
+         pr2[i] = 2 * pr[i];
+         ir2[i] = ir[i];
+       }
+       for (i = 0; i < n + 1; i++)
+       jc2[i] = jc[i];
+ 
+       if (nlhs > 0)
+       plhs[0] = v;
+     }
+ }

reply via email to

[Prev in Thread] Current Thread [Next in Thread]