# SCM Repository

[matrix] Diff of /pkg/src/dgeMatrix.c
 [matrix] / pkg / src / dgeMatrix.c

# Diff of /pkg/src/dgeMatrix.c

revision 460, Sat Jan 29 14:08:32 2005 UTC revision 461, Sat Jan 29 14:09:38 2005 UTC
# Line 322  Line 322
322      UNPROTECT(1);      UNPROTECT(1);
323      return val;      return val;
324  }  }
325
326    static double padec [] =   /*  Constants for matrix exponential calculation. */
327    {
328      5.0000000000000000e-1,
329      1.1666666666666667e-1,
330      1.6666666666666667e-2,
331      1.6025641025641026e-3,
332      1.0683760683760684e-4,
333      4.8562548562548563e-6,
334      1.3875013875013875e-7,
335      1.9270852604185938e-9,
336    };
337
338    /**
339     * Matrix exponential - based on the code for Octave's expm function.
340     *
341     * @param x real square matrix to exponentiate
342     *
343     * @return matrix exponential of x
344     */
345    SEXP geMatrix_exp(SEXP x)
346    {
347        SEXP val = PROTECT(duplicate(x));
348        int *Dims = INTEGER(GET_SLOT(x, Matrix_DimSym));
349        int i, ilo, ilos, ihi, ihis, j, nc = Dims[1], sqpow;
350        int ncp1 = Dims[1] + 1, ncsqr = nc * nc;
351        int *pivot = Calloc(nc, int);
352        int *iperm = Calloc(nc, int);
353        double *dpp = Calloc(ncsqr, double), /* denominator power Pade' */
354            *npp = Calloc(ncsqr, double), /* numerator power Pade' */
355            *perm = Calloc(nc, double),
356            *scale = Calloc(nc, double),
357            *v = REAL(GET_SLOT(val, Matrix_xSym)),
358            *work = Calloc(ncsqr, double), inf_norm, m1_j, /* (-1)^j */
359            one = 1., trshift, zero = 0.;
360
361        if (nc < 1 || Dims[0] != nc)
362            error("Matrix exponential requires square, non-null matrix");
363
364        /* FIXME: Add special treatment for nc == 1 */
365
366        /* Preconditioning 1.  Shift diagonal by average diagonal if positive. */
367        trshift = 0;                /* determine average diagonal element */
368        for (i = 0; i < nc; i++) trshift += v[i * ncp1];
369        trshift /= nc;
370        if (trshift > 0.) {         /* shift diagonal by -trshift */
371            for (i = 0; i < nc; i++) v[i * ncp1] -= trshift;
372        }
373
374        /* Preconditioning 2. Balancing with dgebal. */
375        F77_CALL(dgebal)("P", &nc, v, &nc, &ilo, &ihi, perm, &j);
376        if (j) error("geMatrix_exp: LAPACK routine dgebal returned %d", j);
377        F77_CALL(dgebal)("S", &nc, v, &nc, &ilos, &ihis, scale, &j);
378        if (j) error("geMatrix_exp: LAPACK routine dgebal returned %d", j);
379
380        /* Preconditioning 3. Scaling according to infinity norm */
381        inf_norm = F77_CALL(dlange)("I", &nc, &nc, v, &nc, work);
382        sqpow = (inf_norm > 0) ? (int) (1 + log(inf_norm)/log(2.)) : 0;
383        if (sqpow < 0) sqpow = 0;
384        if (sqpow > 0) {
385            double scale_factor = 1.0;
386            for (i = 0; i < sqpow; i++) scale_factor *= 2.;
387            for (i = 0; i < ncsqr; i++) v[i] /= scale_factor;
388        }
389
390        /* Pade' approximation. Powers v^8, v^7, ..., v^1 */
391        AZERO(npp, ncsqr);
392        AZERO(dpp, ncsqr);
393        m1_j = -1;
394        for (j = 7; j >=0; j--) {
395            double mult = padec[j];
396            /* npp = m * npp + padec[j] *m */
397            F77_CALL(dgemm)("N", "N", &nc, &nc, &nc, &one, v, &nc, npp, &nc,
398                            &zero, work, &nc);
399            for (i = 0; i < ncsqr; i++) npp[i] = work[i] + mult * v[i];
400            /* dpp = m * dpp * (m1_j * padec[j]) * m */
401            mult *= m1_j;
402            F77_CALL(dgemm)("N", "N", &nc, &nc, &nc, &one, v, &nc, dpp, &nc,
403                            &zero, work, &nc);
404            for (i = 0; i < ncsqr; i++) dpp[i] = work[i] + mult * v[i];
405            m1_j *= -1;
406        }
407        /* Zero power */
408        for (i = 0; i < ncsqr; i++) dpp[i] *= -1.;
409        for (j = 0; j < nc; j++) {
410            npp[j * ncp1] += 1.;
411            dpp[j * ncp1] += 1.;
412        }
413
414        /* Pade' approximation is solve(dpp, npp) */
415        F77_CALL(dgetrf)(&nc, &nc, dpp, &nc, pivot, &j);
416        if (j) error("geMatrix_exp: dgetrf returned error code %d", j);
417        F77_CALL(dgetrs)("N", &nc, &nc, dpp, &nc, pivot, npp, &nc, &j);
418        if (j) error("geMatrix_exp: dgetrs returned error code %d", j);
419        Memcpy(v, npp, ncsqr);
420
421        /* Now undo all of the preconditioning */
422        /* Preconditioning 3: square the result for every power of 2 */
423        while (sqpow--) {
424            F77_CALL(dgemm)("N", "N", &nc, &nc, &nc, &one, v, &nc, v, &nc,
425                            &zero, work, &nc);
426            Memcpy(v, work, ncsqr);
427        }
428        /* Preconditioning 2: apply inverse scaling */
429        for (j = 0; j < nc; j++)
430            for (i = 0; i < nc; i++)
431                v[i + j * nc] *= scale[i]/scale[j];
432        /* Construct balancing permutation vector */
433        for (i = 0; i < nc; i++) iperm[i] = i; /* identity permutation */
434        /* Leading permutations applied in forward order */
435        for (i = 0; i < (ilo - 1); i++) {
436            int swapidx = (int) (perm[i]) - 1;
437            int tmp = iperm[i];
438            iperm[i] = iperm[swapidx];
439            iperm[swapidx] = tmp;
440        }
441        /* Trailing permutations applied in reverse order */
442        for (i = nc - 1; i >= ihi; i--) {
443            int swapidx = (int) (perm[i]) - 1;
444            int tmp = iperm[i];
445            iperm[i] = iperm[swapidx];
446            iperm[swapidx] = tmp;
447        }
448        /* Construct inverse balancing permutation vector */
449        Memcpy(pivot, iperm, nc);
450        for (i = 0; i < nc; i++) iperm[pivot[i]] = i;
451        /* Apply inverse permutation */
452        Memcpy(work, v, ncsqr);
453        for (j = 0; j < nc; j++)
454            for (i = 0; i < nc; i++)
455                v[i + j * nc] = work[iperm[i] + iperm[j] * nc];
456
457        /* Preconditioning 1: Trace normalization */
458        if (trshift > 0.) {
459            double mult = exp(trshift);
460            for (i = 0; i < ncsqr; i++) v[i] *= mult;
461        }
462
463        /* Clean up */
464        Free(dpp); Free(npp); Free(perm); Free(iperm); Free(pivot); Free(scale); Free(work);
465        UNPROTECT(1);
466        return val;
467    }

Legend:
 Removed from v.460 changed lines Added in v.461

 root@r-forge.r-project.org ViewVC Help Powered by ViewVC 1.0.0
Thanks to: