/*
 * Copyright (c) 2008 Filip Niksic
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 *
 **/

// 
// File:   matrix.hpp
// Author: filip
//
// Created on 2008. ožujak 30, 23:05
//

#ifndef _MATRIX_HPP
#define	_MATRIX_HPP

#include <mpi.h>
#include <cstddef>
#include <algorithm>
#include <iostream>
#include <cmath>
extern "C" {
    #include "cblas.h"
}

namespace trilu {

    template <typename T, size_t N> class Matrix;
    
    template <typename T, size_t N>
    std::ostream& operator<<(std::ostream& out, const Matrix<T, N>& X) {
        for (size_t i = 0; i < N; ++i)
            for (size_t j = 0; j < N; ++j)
                out << X.M[i][j] << (j + 1 == N ? "\n" : " ");
        return out;
    }
    
    template <typename T, size_t N>
    class Matrix {
        typedef T Row[N];
        T M[N][N];
    public:
        Matrix() {}
        Matrix(const T (&)[N][N]);
        // A.leftMul(X) makes A = XA
        Matrix& leftMul(const Matrix&);
        // A.rightMul(X) makes A = AX
        Matrix& rightMul(const Matrix&);
        Matrix& operator+=(const Matrix&);
        Row& operator[](const size_t n) { return M[n]; }
        const Row& operator[](const size_t n) const { return M[n]; }
        friend std::ostream& operator<< <T, N> (std::ostream&, const Matrix&);
    };

    template <typename T, size_t N>
    inline Matrix<T, N>::Matrix(const T (&A)[N][N]) {
        std::copy(&A[0][0], &A[0][0] + N*N, &M[0][0]);
    }

    template <typename T, size_t N>
    inline Matrix<T, N>& Matrix<T, N>::leftMul(const Matrix& X) {
        // Implementacija opcenitog mnozenja A = XA
        return *this;
    }
    
    template <typename T, size_t N>
    inline Matrix<T, N>& Matrix<T, N>::rightMul(const Matrix& X) {
        // Implementacija opcenitog mnozenja A = AX
        return *this;
    }
    
    template <typename T, size_t N>
    inline Matrix<T, N>& Matrix<T, N>::operator+=(const Matrix& rhs) {
        // Implementacija opcenitog zbrajanja
        for (size_t i = 0; i < N; ++i)
            for (size_t j = 0; j < N; ++j)
                M[i][j] += rhs.M[i][j];
        return *this;
    }

    // Specijalizacija klase Matrix za double
    template <size_t N>
    class Matrix<double, N> {
        typedef double Row[N];
        double M[N][N];
    public:
        Matrix() {}
        Matrix(const double (&A)[N][N]) {
            cblas_dcopy(N*N, &A[0][0], 1, &M[0][0], 1);
        }
        // leftMul i rightMul osim mnozenja obave i normiranje matrice u normi 2
        Matrix& leftMul(const Matrix& X) {
            double *tmpM = new double[N*N]();
            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, &X.M[0][0], N, &M[0][0], N, 1.0, tmpM, N);
            // normiranje
            cblas_dscal(N*N, 1.0 / cblas_dnrm2(N*N, tmpM, 1), tmpM, 1);
            cblas_dcopy(N*N, tmpM, 1, &M[0][0], 1);
            delete [] tmpM;
            return *this;
        }
        Matrix& rightMul(const Matrix& X) {
            double *tmpM = new double[N*N]();
            cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, N, N, N, 1.0, &M[0][0], N, &X.M[0][0], N, 1.0, tmpM, N);
            // normiranje
            cblas_dscal(N*N, 1.0 / cblas_dnrm2(N*N, tmpM, 1), tmpM, 1);
            cblas_dcopy(N*N, tmpM, 1, &M[0][0], 1);
            delete [] tmpM;
            return *this;
        }
        Matrix& operator+=(const Matrix& rhs) {
            cblas_daxpy(N*N, 1.0, &rhs.M[0][0], 1, &M[0][0], 1);
            return *this;
        }
        Row& operator[](const size_t n) { return M[n]; }
        const Row& operator[](const size_t n) const { return M[n]; }
        friend std::ostream& operator<< <double, N> (std::ostream&, const Matrix&);
    };
    
    // Specijalizacija klase Matrix za <double, 2>
    template <>
    class Matrix<double, 2> {
        typedef double Row[2];
        double M[2][2];
    public:
        Matrix() {}
        Matrix(const double (&A)[2][2]) {
            M[0][0] = A[0][0]; M[0][1] = A[0][1];
            M[1][0] = A[1][0]; M[1][1] = A[1][1];
        }
        // leftMul i rightMul osim mnozenja obave i normiranje matrice u normi oo
        Matrix& leftMul(const Matrix& X) {
            double tmpM[2][2] = {
                X.M[0][0] * M[0][0] + X.M[0][1] * M[1][0],
                X.M[0][0] * M[0][1] + X.M[0][1] * M[1][1],
                X.M[1][0] * M[0][0] + X.M[1][1] * M[1][0],
                X.M[1][0] * M[0][1] + X.M[1][1] * M[1][1]
            };
            double max(0.0);
            if (std::fabs(tmpM[0][0]) > max)
                max = std::fabs(tmpM[0][0]);
            if (std::fabs(tmpM[0][1]) > max)
                max = std::fabs(tmpM[0][1]);
            if (std::fabs(tmpM[1][0]) > max)
                max = std::fabs(tmpM[1][0]);
            if (std::fabs(tmpM[1][1]) > max)
                max = std::fabs(tmpM[1][1]);
//            max = std::fabs(tmpM[0][0]) + std::fabs(tmpM[0][1])
//                    + std::fabs(tmpM[1][0]) + std::fabs(tmpM[1][1]);
            M[0][0] = tmpM[0][0] / max; M[0][1] = tmpM[0][1] / max;
            M[1][0] = tmpM[1][0] / max; M[1][1] = tmpM[1][1] / max;
            return *this;
        }
        Matrix& rightMul(const Matrix& X) {
            double tmpM[2][2] = {
                M[0][0] * X.M[0][0] + M[0][1] * X.M[1][0],
                M[0][0] * X.M[0][1] + M[0][1] * X.M[1][1],
                M[1][0] * X.M[0][0] + M[1][1] * X.M[1][0],
                M[1][0] * X.M[0][1] + M[1][1] * X.M[1][1]
            };
            double max(0.0);
            if (std::fabs(tmpM[0][0]) > max)
                max = std::fabs(tmpM[0][0]);
            if (std::fabs(tmpM[0][1]) > max)
                max = std::fabs(tmpM[0][1]);
            if (std::fabs(tmpM[1][0]) > max)
                max = std::fabs(tmpM[1][0]);
            if (std::fabs(tmpM[1][1]) > max)
                max = std::fabs(tmpM[1][1]);
//            max = std::fabs(tmpM[0][0]) + std::fabs(tmpM[0][1])
//                    + std::fabs(tmpM[1][0]) + std::fabs(tmpM[1][1]);
            M[0][0] = tmpM[0][0] / max; M[0][1] = tmpM[0][1] / max;
            M[1][0] = tmpM[1][0] / max; M[1][1] = tmpM[1][1] / max;
            return *this;
        }
        Matrix& operator+=(const Matrix& rhs) {
            M[0][0] += rhs.M[0][0]; M[0][1] += rhs.M[0][1];
            M[1][0] += rhs.M[1][0]; M[1][1] += rhs.M[1][1];
            return *this;
        }
        Row& operator[](const size_t n) { return M[n]; }
        const Row& operator[](const size_t n) const { return M[n]; }
        friend std::ostream& operator<< <double, 2> (std::ostream&, const Matrix&);
    };

    // leftMul(A, X) returns product XA
    template <typename T, size_t N>
    inline Matrix<T, N> leftMul(const Matrix<T, N>& A, const Matrix<T, N>& X) {
        return Matrix<T, N>(A).leftMul(X);
    }
    
    // rightMul(A, X) returns product AX
    template <typename T, size_t N>
    inline Matrix<T, N> rightMul(const Matrix<T, N>& A, const Matrix<T, N>& X) {
        return Matrix<T, N>(A).rightMul(X);
    }
    
    template <typename T, size_t N>
    inline Matrix<T, N> operator+(const Matrix<T, N>& lhs, const Matrix<T, N>& rhs) {
        return Matrix<T, N>(lhs) += rhs;
    }

    template <typename T, size_t N>
    void MPIRightMul(const void* invec, void* inoutvec, int len, const MPI::Datatype& datatype) {
        const Matrix<T, N>* in = static_cast< const Matrix<T, N>* >(invec);
        Matrix<T, N>* inout = static_cast< Matrix<T, N>* >(inoutvec);
        for (int i = 0; i < len; ++i)
            // Opcenito, funkcija treba obaviti inout[i] = in[i] o inout[i]
            // inout[i] = inout[i] in[i]
            inout[i].rightMul(in[i]);
    }

    template <typename T, size_t N>
    void MPILeftMul(const void* invec, void* inoutvec, int len, const MPI::Datatype& datatype) {
        const Matrix<T, N>* in = static_cast< const Matrix<T, N>* >(invec);
        Matrix<T, N>* inout = static_cast< Matrix<T, N>* >(inoutvec);
        for (int i = 0; i < len; ++i)
            // Opcenito, funkcija treba obaviti inout[i] = in[i] o inout[i]
            // inout[i] = in[i] inout[i]
            inout[i].leftMul(in[i]);
    }

}

#endif	/* _MATRIX_HPP */

