#ifndef CGPOINTWRANGLER_H
#define CGPOINTWRANGLER_H

// stl
#include <vector>
#include <queue>
#include <deque>
#include <limits>
#include <string>
#include <algorithm>
#include <iostream>
#include <numeric>
#include <utility>
#include <functional>
#include <tuple>
#include <thread>

// glm & eigen
#include <glm/glm.hpp>
#include <glm/gtx/norm.hpp>
#include <glm/gtc/matrix_transform.hpp>

// this
#include "CgAABB.h"
#include "CgPointMath.h"
#include "CgUtils/Timer.h"

/*
* more complex point cloud algorithms & data structures:
* kd-tree construction & queries, PCA-based normal estimation,
* spanning-tree normal alignment and point cloud simplification
*/
class CgPointWrangler {
    public:
    
        // rearrange the elements between begin and end (exclusive) in place so they form an implicit kd-tree
        static void buildKdTree(glm::vec3* begin, glm::vec3* end, unsigned int depth);
        static void buildKdTree(PCA* begin, PCA* end, unsigned int depth);

        // traverse a kd-tree between begin and end (exclusive) in breadth-first order (BFO)
        // and call fn on each node with a pointer to the node's point and the node's depth
        static void traverseBFO(
            glm::vec3* begin, glm::vec3* end, 
            std::function<void(glm::vec3*, const unsigned int)> fn
        );

        // get the index of the point closest to the ray in the kd-tree between
        // begin and end (exclusive); returns -1 if the range is empty
        static int getClosestPointToRay(
            const glm::vec3* begin, const glm::vec3* end,
            CgAABB aabb,
            glm::vec3 origin, glm::vec3 direction
        );
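
        // A minimal usage sketch for ray picking (assumes a std::vector<glm::vec3>
        // `points` filled elsewhere and a camera ray `origin`/`direction`):
        //
        //   CgPointWrangler::buildKdTree(points.data(), points.data() + points.size(), 0);
        //   CgAABB aabb = CgPointMath::calculateAABB(points);
        //   int hit = CgPointWrangler::getClosestPointToRay(
        //       points.data(), points.data() + points.size(),
        //       aabb, origin, direction
        //   );
        //   // hit is an index into points, or -1 if the range was empty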

        // naive normal estimation: k-neighbourhoods -> point clusters -> PCA per cluster
        // -> spanning tree -> consistently oriented normals
        static std::vector<std::vector<unsigned int>> getKNeighbourhoods(unsigned int neighbourhood_size, const std::vector<glm::vec3>& vertices);
        static std::vector<std::vector<glm::vec3>> kNeighbourhoodsToClusters(const std::vector<glm::vec3>& vertices, const std::vector<std::vector<unsigned int>>& neighbourhoods);
        static std::vector<std::pair<unsigned int, unsigned int>> makeNaiveSpanningTree(const std::vector<std::vector<unsigned int>>& neighbourhoods, const std::vector<PCA>& pca);
        static std::vector<PCA> performPcaOnClusters(std::vector<std::vector<glm::vec3>> clusters); 
        static std::vector<glm::vec3> alignNormals(
            glm::vec3 ref_normal, // reference the first normal will be aligned to 
            const std::vector<std::pair<unsigned int, unsigned int>>& spanning_tree, // list of edges (from, to)
            const std::vector<PCA>& pca // contains the unaligned normals
        );
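
        // Sketch of the full normal-estimation pipeline (the vertex vector must
        // already be in kd-tree order, see buildKdTree; the neighbourhood size 15
        // and the up-pointing reference normal are just illustrative choices):
        //
        //   auto nhs      = CgPointWrangler::getKNeighbourhoods(15, points);
        //   auto clusters = CgPointWrangler::kNeighbourhoodsToClusters(points, nhs);
        //   auto pca      = CgPointWrangler::performPcaOnClusters(clusters);
        //   auto tree     = CgPointWrangler::makeNaiveSpanningTree(nhs, pca);
        //   auto normals  = CgPointWrangler::alignNormals(glm::vec3(0.0f, 1.0f, 0.0f), tree, pca);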

        static std::vector<PCA> performHierarchicalSimplification(
            std::vector<glm::vec3> original_points,
            float max_cluster_size,
            float max_cluster_variance,
            float max_cluster_radius
        );

        static std::vector<PCA> performIncrementalSimplification(
            std::vector<glm::vec3> original_points,
            float max_cluster_size,
            float max_cluster_variance,
            float max_cluster_radius
        );
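
        // Both simplification variants return one PCA record (centroid, eigenvectors,
        // radius2) per output cluster. A sketch with purely illustrative parameters:
        //
        //   std::vector<PCA> simplified = CgPointWrangler::performHierarchicalSimplification(
        //       points, // copied internally
        //       50.0f,  // max points per cluster
        //       0.01f,  // max cluster variance
        //       0.05f   // max cluster radius (compared against PCA::radius2)
        //   );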

        static std::vector<unsigned int> getNearestNeighborsSlow(
            unsigned int current_point, 
            unsigned int k,
            const std::vector<glm::vec3>& vertices
        );

        static std::vector<unsigned int> getNearestNeighborsFast(
            unsigned int current_point, 
            unsigned int k, 
            const std::vector<glm::vec3>& vertices,
            CgAABB aabb
        );

        static std::vector<unsigned int> getNearestNeighborsFast(
            glm::vec3 query, 
            unsigned int k, 
            const std::vector<glm::vec3>& vertices,
            CgAABB aabb
        );
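
        // k-nearest-neighbour query sketch (the vertices must already be arranged
        // as a kd-tree via buildKdTree; the bounding box comes from CgPointMath):
        //
        //   CgAABB aabb = CgPointMath::calculateAABB(points);
        //   std::vector<unsigned int> nn =
        //       CgPointWrangler::getNearestNeighborsFast(query_point, 8, points, aabb);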

    private:
        // offset from begin to the midpoint of the range [begin, end)
        // requires end >= begin
        template<typename T>
        static unsigned int halfSize(const T* begin, const T* end);
        
        CgPointWrangler(){}
};

template<typename T>
inline unsigned int CgPointWrangler::halfSize(const T* begin, const T* end) {
    return (end - begin) / 2;
}

// rearrange the vec3 between begin and end (exclusive) so they form a kd-tree
// requires begin <= end
inline void CgPointWrangler::buildKdTree(glm::vec3* begin, glm::vec3* end, unsigned int depth) {
    unsigned int half_size = halfSize(begin, end);

    // none or one element
    if(half_size == 0) return;

    // partially sort and ensure median element 
    // is in the middle (runs in O(n))
    // about 4x faster on tyra.obj than std::sort
    std::nth_element(begin, begin + half_size, end, [=](const glm::vec3& a, const glm::vec3& b){
        return a[depth % 3] < b[depth % 3];
    });

    // split array in two (excluding median) and recurse
    buildKdTree(begin, begin + half_size, depth + 1);
    buildKdTree(begin + half_size + 1, end, depth + 1);
}
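
// The result is an implicit kd-tree stored in the array itself: the median of the
// current range sits at begin + half_size, [begin, begin + half_size) holds the
// lower subtree, [begin + half_size + 1, end) the upper subtree, and the split
// axis cycles x -> y -> z with depth. For 7 points, for example, index 3 is the
// root, indices 0-2 form its lower subtree and indices 4-6 its upper subtree.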

// see the version for glm::vec3
inline void CgPointWrangler::buildKdTree(PCA* begin, PCA* end, unsigned int depth) {
    unsigned int half_size = halfSize(begin, end);
    if(half_size == 0) return;
    std::nth_element(begin, begin + half_size, end, [=](const PCA& a, const PCA& b){
        return a.centroid[depth % 3] < b.centroid[depth % 3];
    });
    buildKdTree(begin, begin + half_size, depth + 1);
    buildKdTree(begin + half_size + 1, end, depth + 1);
}

inline void CgPointWrangler::traverseBFO(
    glm::vec3* begin, 
    glm::vec3* end, 
    std::function<void(glm::vec3*, const unsigned int)> fn
) {
    // queue of pairs of vec3 pointers
    // (range of points in the current node and its children)
    std::queue<std::pair<glm::vec3*, glm::vec3*>> q;
    if(begin == end) return;
    unsigned int depth = 0;
    q.push({begin, end});
    while(!q.empty()) {
        int level_count = q.size();
        for(int i = 0; i < level_count; i++) {
            auto curr = q.front();
            q.pop();
            glm::vec3* begin = curr.first;
            glm::vec3* end = curr.second;
            unsigned int half_size = halfSize(begin, end);
            glm::vec3* curr_pointer = begin + half_size;
            fn(curr_pointer, depth);
            if(end - begin <= 1) continue;
            // push the non-empty left & right halves (excluding the median) to reconstruct the splits
            if(half_size > 0) q.push({begin, begin + half_size});
            if(begin + half_size + 1 < end) q.push({begin + half_size + 1, end});
        }
        depth += 1;
    }
}
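
// Example: counting the nodes on each level of the implicit kd-tree
// (a sketch; `points` must already be in kd-tree order):
//
//   std::vector<unsigned int> per_level;
//   CgPointWrangler::traverseBFO(
//       points.data(), points.data() + points.size(),
//       [&](glm::vec3* node, const unsigned int depth) {
//           if(per_level.size() <= depth) per_level.resize(depth + 1, 0);
//           per_level[depth] += 1;
//       }
//   );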

inline int CgPointWrangler::getClosestPointToRay(
    const glm::vec3* begin, 
    const glm::vec3* end,
    CgAABB aabb,
    glm::vec3 origin, 
    glm::vec3 direction
) {
  if(begin == end) return -1;

  glm::vec3 inv_direction = glm::vec3(1.0 / direction.x, 1.0 / direction.y, 1.0 / direction.z);
  // best squared distance found so far; start at infinity so any point beats it
  float erg_sqr_dist = std::numeric_limits<float>::infinity();
  int erg = -1;

  // stack of traversed nodes (start index, end index, aabb, depth)
  std::vector<std::tuple<const glm::vec3*, const glm::vec3*, CgAABB, unsigned int>> nodes;
  nodes.push_back({begin, end, aabb, 0});
  
  // traverse the tree as long as there's something to do.
  while(!nodes.empty()) {
    // unpack current node & pop it
    const glm::vec3* cur_begin;
    const glm::vec3* cur_end;
    CgAABB cur_aabb;
    unsigned int cur_depth;
    std::tie(cur_begin, cur_end, cur_aabb, cur_depth) = nodes.back();
    auto half_size = halfSize(cur_begin, cur_end);
    nodes.pop_back();

    glm::vec3 cur_point = *(cur_begin + half_size);
    // remember the current point if it's closer to the ray than the best so far
    float cand_dist = CgPointMath::sqrPointRayDist(cur_point, origin, direction);
    if(cand_dist < erg_sqr_dist) {
      erg_sqr_dist = cand_dist;
      erg = cur_begin + half_size - begin;
    }

    // the aabb of the sub-nodes
    unsigned int axis = cur_depth % 3;
    CgAABB lower_aabb;
    CgAABB higher_aabb;
    std::tie(lower_aabb, higher_aabb) = cur_aabb.split(axis, cur_point[axis]);
    // only push children if their aabb is closer to ray than 
    // the closest point we found until now and there are
    // points in them
    if (
        half_size != 0
        && CgPointMath::sqrBoxRayDist(lower_aabb, origin, direction, inv_direction) < erg_sqr_dist
    ) {
        nodes.push_back({cur_begin, cur_begin + half_size, lower_aabb, cur_depth + 1});
    }
    if (
        cur_begin + half_size < cur_end
        && CgPointMath::sqrBoxRayDist(higher_aabb, origin, direction, inv_direction) < erg_sqr_dist
    ) {
        nodes.push_back({cur_begin + half_size + 1, cur_end, higher_aabb, cur_depth + 1});
    }
  }

  return erg;
}

inline std::vector<std::vector<unsigned int>> CgPointWrangler::getKNeighbourhoods(unsigned int neighbourhood_size, const std::vector<glm::vec3>& vertices) {
    // all k-neighbourhoods
    std::vector<std::vector<unsigned int>> neighbourhoods;
    neighbourhoods.resize(vertices.size());
    CgAABB aabb = CgPointMath::calculateAABB(vertices);

    run_threaded(vertices.size(), [&](
        const unsigned int start_index, 
        const unsigned int count
    ) {
        for(unsigned int i = start_index; i < start_index + count; i++) {
            neighbourhoods[i] = getNearestNeighborsFast(i, neighbourhood_size, vertices, aabb);    
        }
    });

    return neighbourhoods;
}

inline std::vector<std::vector<glm::vec3>> CgPointWrangler::kNeighbourhoodsToClusters(
    const std::vector<glm::vec3>& vertices,
    const std::vector<std::vector<unsigned int>>& neighbourhoods
) {
    std::vector<std::vector<glm::vec3>> clusters;
    clusters.resize(neighbourhoods.size());

    run_threaded(neighbourhoods.size(), [&](
        const unsigned int start_index, 
        const unsigned int count
    ) {
        for(unsigned int i = start_index; i < start_index + count; i++) {
            const std::vector<unsigned int>& nh_indices = neighbourhoods[i];
            clusters[i].resize(nh_indices.size());
            for(unsigned int j = 0; j < nh_indices.size(); j++) {
                clusters[i][j] = vertices[nh_indices[j]];
            }
        }
    });

    return clusters;
}

inline std::vector<PCA> CgPointWrangler::performPcaOnClusters(std::vector<std::vector<glm::vec3>> clusters) {
    std::vector<PCA> pca;
    pca.resize(clusters.size());

    // worker for doing the PCAs
    run_threaded(pca.size(), [&](
        const unsigned int start_index, 
        const unsigned int count
    ) {
        for(unsigned int i = start_index; i < start_index + count; i++) {
            std::vector<glm::vec3> nh = clusters[i];
            pca[i] = CgPointMath::calculatePCA(nh);
        }
    });

    return pca;
}

inline std::vector<std::pair<unsigned int, unsigned int>> CgPointWrangler::makeNaiveSpanningTree(
    const std::vector<std::vector<unsigned int>>& neighbourhoods,
    const std::vector<PCA>& pca
) {
    std::cout << pca.size() << " pca elements for spanning tree" << std::endl;
    std::vector<std::pair<unsigned int, unsigned int>> spanning_tree;
    spanning_tree.reserve(pca.size() + 1);

    // find point with highest y-value
    std::vector<bool> visited;
    visited.resize(pca.size(), false);
    float max_y = -std::numeric_limits<float>::infinity();
    unsigned int max_y_index = 0;
    for(unsigned int i = 0; i < pca.size(); i++) {
        float curr = pca[i].centroid.y;
        if(max_y < curr) {
            max_y = curr;
            max_y_index = i;
        }
    }
    visited[max_y_index] = true;

    // record the "best" parent found for each vertex
    std::vector<std::pair<unsigned int, double>> scores;
    scores.resize(pca.size());

    // heap of point indices to process next, ordered so that
    // the index with the lowest score is popped first
    std::vector<unsigned int> queue;

    auto cmp = [&](unsigned int left, unsigned int right) {
        return scores[left].second > scores[right].second;
    };

    auto visit = [&](unsigned int curr) {
        glm::vec3 normal = pca[curr].evec0;
        glm::vec3 point = pca[curr].centroid;
        double rad = pca[curr].radius2;

        // retrieve neighbourhood indices
        auto curr_nh_ind = neighbourhoods[curr];

        // find out which of the neighbours we didn't process yet
        // and score them to find out in which order to process
        for(unsigned int i = 0; i < curr_nh_ind.size(); i++) {
            unsigned int index = curr_nh_ind[i];
            // score = distance to the current centroid (scaled by the cluster's radius2)
            //       + misalignment (1 - |cos| of the angle between the normals);
            // lower scores are processed first
            double alignment = 1.0 - std::abs(glm::dot(pca[index].evec0, normal));
            double distance = glm::distance(pca[index].centroid, point) / rad;
            double score = distance + alignment;
            if(!visited[index]) {
                scores[index].first = curr;
                scores[index].second = score;
                queue.push_back(index);
                std::push_heap(queue.begin(), queue.end(), cmp);
                visited[index] = true;
            } else if(scores[index].second > score) {
                // update the neighbour's parent to the current node, since it yields a lower score
                scores[index].first = curr;
                scores[index].second = score;
                // re-sort the heap with the new score
                std::make_heap(queue.begin(), queue.end(), cmp);
            }
        }
    };

    // seed the spanning tree at the highest point; its "parent" is itself,
    // so alignNormals will orient its normal against the reference normal
    spanning_tree.push_back({max_y_index, max_y_index});
    visit(max_y_index); 

    // start processing
    while(!queue.empty()) {
        // pop the index with the smallest score
        std::pop_heap(queue.begin(), queue.end(), cmp);
        unsigned int curr = queue.back();
        unsigned int parent = scores[curr].first;
        queue.pop_back();
        spanning_tree.push_back({parent, curr});
        visit(curr);
    }

    return spanning_tree;
}
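
// The loop above is roughly Prim's algorithm on the k-neighbourhood graph:
// starting from the highest point, it always attaches the pending neighbour
// with the lowest (distance + misalignment) score next, so normal orientation
// tends to propagate along smooth, well-aligned regions first.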

inline std::vector<glm::vec3> CgPointWrangler::alignNormals(
    glm::vec3 ref_normal, // reference the first normal will be aligned to
    const std::vector<std::pair<unsigned int, unsigned int>>& spanning_tree, // list of edges (from, to)
    const std::vector<PCA>& pca // contains the unaligned normals
) {
    std::cout << spanning_tree.size() << " spanning tree edges" << std::endl;
    std::vector<glm::vec3> normals;
    normals.resize(pca.size());
    for(unsigned int i = 0; i < pca.size(); i++) normals[i] = pca[i].evec0;
    // align first normal with ref
    unsigned int start_index = spanning_tree[0].first;
    normals[start_index] = glm::dot(normals[start_index], ref_normal) > 0.0 
        ? normals[start_index] 
        : -(normals[start_index]);

    // align the rest 
    // (spanning tree should ensure that we don't align to a normal that's not yet aligned itself)
    for(unsigned int i = 1; i < spanning_tree.size(); i++) {
        unsigned int parent_index = spanning_tree[i].first;
        unsigned int own_index = spanning_tree[i].second;
        normals[own_index] = glm::dot(normals[own_index], normals[parent_index]) > 0.0 
            ? normals[own_index] 
            : -(normals[own_index]);
    }

    return normals;
}

inline std::vector<PCA> CgPointWrangler::performHierarchicalSimplification(
    std::vector<glm::vec3> original_points,
    float max_cluster_size,
    float max_cluster_variance,
    float max_cluster_radius
){
    unsigned int MAX_POINTS = std::min(100.0f, max_cluster_size);
    unsigned int MIN_POINTS = 3;

    std::vector<unsigned int> hist;
    hist.resize(MAX_POINTS + 1, 0);

    // used by the threads to define the clusters
    std::vector<unsigned int> indices;
    indices.resize(original_points.size());
    for(unsigned int i = 0; i < original_points.size(); i++) indices[i] = i;

    // used to define the work portions
    std::vector<std::pair<unsigned int, unsigned int>> jobs; // list of ranges of indices
    jobs.push_back({0, indices.size()}); // initially, use entire point cloud
    std::vector<PCA> pca; // finished clusters

    while(!jobs.empty()) {
        std::pair<unsigned int, unsigned int> job = jobs.back();
        jobs.pop_back();

        // get cluster and check if it fits constraints
        std::vector<glm::vec3> cluster;
        for(unsigned int i = job.first; i < job.second; i++) cluster.push_back(original_points[indices[i]]);
        PCA current_pca = CgPointMath::calculatePCA(cluster);
        double variance = CgPointMath::varianceFromPCA(current_pca);
        if(cluster.size() <= MIN_POINTS  || (variance < max_cluster_variance && current_pca.radius2 < max_cluster_radius && cluster.size() <= MAX_POINTS)) {
            // push cluster as-is into cluster vector
            hist[cluster.size()] += 1;
            pca.push_back(current_pca);
        } else {
            // split it at the centroid plane & recurse;
            // the indices vector is not behind a mutex because the
            // jobs work on mutually exclusive ranges of it
            unsigned int lower = job.first;
            unsigned int upper = job.second;
            while(lower != upper) {
                // separate indices depending on which side of the plane through
                // the centroid (orthogonal to evec2) their points are
                if(glm::dot(original_points[indices[lower]] - current_pca.centroid, current_pca.evec2) > 0) {
                    lower += 1;
                } else {
                    upper -= 1;
                    std::swap(indices[lower], indices[upper]);
                }
            }
            
            jobs.push_back({job.first, lower});
            jobs.push_back({lower, job.second});
        }
    }

    // clustering is done 
    std::cout << pca.size() << " clusters generated, histogram (cluster_size:count):" << std::endl;
    for(unsigned int i = 0; i < hist.size(); i++) std::cout << " |" << i << ":" << hist[i];
    std::cout << std::endl;
    return pca;
}
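
// Hierarchical simplification is a top-down scheme: starting with the whole cloud,
// any cluster that violates the size/variance/radius constraints is cut in two by
// the plane through its centroid orthogonal to evec2, and the halves are processed
// again until every remaining cluster fits the constraints or has shrunk to MIN_POINTS.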

inline std::vector<PCA> CgPointWrangler::performIncrementalSimplification(
    std::vector<glm::vec3> original_points,
    float max_cluster_size,
    float max_cluster_variance,
    float max_cluster_radius
){
    unsigned int MAX_POINTS = std::min(100.0f, max_cluster_size);
    std::vector<glm::vec3> remaining_points;
    std::vector<unsigned int> current_nh;
    std::vector<bool> used;
    used.resize(original_points.size(), false);
    std::vector<PCA> pca;
    if(original_points.empty()) return pca;
    glm::vec3 next_seed_point = original_points[0];
    std::vector<unsigned int> hist;
    // a cluster may grow to MAX_POINTS + 1 points before it breaks a constraint,
    // so the histogram needs MAX_POINTS + 2 bins
    hist.resize(MAX_POINTS + 2, 0);
    while(original_points.size() > 0) { // while there are still points to consider...
        unsigned int k;
        CgAABB aabb = CgPointMath::calculateAABB(original_points);
        PCA current_pca;
        current_pca.centroid = next_seed_point;
        // walk up to the largest cluster that fits the constraints
        // for each size:
        for(k = 1; k <= MAX_POINTS + 1; k++) {
            // get neighbourhood indices
            std::vector<unsigned int> nh = CgPointWrangler::getNearestNeighborsFast(
                current_pca.centroid, k,
                original_points, aabb
            );

            // get neighbourhood points from indices
            std::vector<glm::vec3> cluster(nh.size());
            for(unsigned int i = 0; i < nh.size(); i++) cluster[i] = original_points[nh[i]];

            // do pca and check if the cluster still fits the constraints
            PCA cluster_pca = CgPointMath::calculatePCA(cluster);
            float variance = CgPointMath::varianceFromPCA(cluster_pca);
            if(variance > max_cluster_variance || cluster_pca.radius2 > max_cluster_radius) {
                // the last point in the cluster is furthest away from the seed;
                // use it as the next seed
                next_seed_point = cluster[cluster.size() - 1];
                hist[nh.size()] += 1;
                break;
            }
            current_nh = nh;
            current_pca = cluster_pca;
            // stop if there are not enough points left to ever break the constraints
            if(k > nh.size()) break;
        }

        // push new point information
        pca.push_back(current_pca);

        // if we have all the rest of the points in the current cluster, we're done
        if(current_nh.size() == original_points.size()) break;

        // copy unused points over
        used.clear();
        used.resize(original_points.size(), false);
        remaining_points.clear();
        remaining_points.reserve(original_points.size());
        for(unsigned int i = 0; i < current_nh.size(); i++) used[current_nh[i]] = true;
        for(unsigned int i = 0; i < original_points.size(); i++) {
            if(!used[i]) remaining_points.push_back(original_points[i]);
        }
        // build new kd-tree in remaining points.
        original_points = remaining_points;
        CgPointWrangler::buildKdTree(original_points.data(), original_points.data() + original_points.size(), 0);
    }

    // clustering is done 
    std::cout << pca.size() << " clusters generated, histogram (cluster_size:count):" << std::endl;
    for(unsigned int i = 0; i < hist.size(); i++) std::cout << " |" << i << ":" << hist[i];
    std::cout << std::endl;

    return pca;
}

inline std::vector<unsigned int> CgPointWrangler::getNearestNeighborsSlow(
    unsigned int current_point, 
    unsigned int k, 
    const std::vector<glm::vec3>& vertices
) {
    glm::vec3 q = vertices[current_point];

    std::vector<std::pair<double,int>> distances;

    for(unsigned int i = 0; i < vertices.size(); i++) {
        double dist = glm::distance(vertices[i],q);
        distances.push_back(std::make_pair(dist,i));
    }
    std::sort(distances.begin(), distances.end());
    std::vector<unsigned int> erg;

    for(unsigned int i = 0; i < k && i < distances.size(); i++) erg.push_back(distances[i].second);

    return erg;
}
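
// Note: this brute-force variant sorts all points for every query (O(n log n)
// per query) and is mainly useful as a reference against which to validate the
// kd-tree based variants below.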

inline std::vector<unsigned int> CgPointWrangler::getNearestNeighborsFast(
    unsigned int current_point, 
    unsigned int k,
    const std::vector<glm::vec3>& vertices,
    CgAABB aabb
){
    glm::vec3 query = vertices[current_point];
    return getNearestNeighborsFast(
        query, k, vertices, aabb
    );
}

inline std::vector<unsigned int> CgPointWrangler::getNearestNeighborsFast(
    glm::vec3 query, 
    unsigned int k,
    const std::vector<glm::vec3>& vertices,
    CgAABB aabb
) {
    std::vector<unsigned int> erg(k);
    if(k == 0) return erg;
    
    // pair of distance + index into backing vec3 array 
    using Pair = std::pair<double, unsigned int>;

    const glm::vec3* first = vertices.data();
    const glm::vec3* last = vertices.data() + vertices.size();
    // comp function for max heap
    auto cmp = [&](const Pair &left, const Pair &right) { return (left.first < right.first);};

    // bounded max-heap of the k best candidates found so far; the furthest
    // candidate sits at the front and is evicted as soon as a closer one is found
    std::vector<Pair> candidates;
    candidates.reserve(k + 1);

    // stack of traversed nodes (start index, end index, aabb, depth)
    std::vector<std::tuple<const glm::vec3*, const glm::vec3*, CgAABB, unsigned int>> nodes;
    nodes.push_back({first, last, aabb, 0});
    
    // traverse the tree as long as there's something to do.
    while(!nodes.empty()) {
        // unpack current node & pop it
        const glm::vec3* begin;
        const glm::vec3* end;
        CgAABB aabb;
        unsigned int depth;
        std::tie(begin, end, aabb, depth) = nodes.back();
        nodes.pop_back();
        unsigned int half_size = halfSize(begin, end);
        glm::vec3 curr_point = *(begin + half_size);
        // insert current point distance & index into heap
        candidates.push_back({glm::distance(query, curr_point), begin + half_size - first});
        std::push_heap(candidates.begin(), candidates.end(), cmp);
        if(candidates.size() > k) {
            // pop the candidate with the largest distance
            std::pop_heap(candidates.begin(), candidates.end(), cmp);
            candidates.pop_back();
        }

        // the aabb of the sub-nodes
        unsigned int axis = depth % 3;
        CgAABB lower_aabb;
        CgAABB higher_aabb;
        std::tie(lower_aabb, higher_aabb) = aabb.split(axis, curr_point[axis]);
        // only push children if its aabb intersects with the 
        // neighbourhood bounding sphere 
        // (candidates is a max-heap with the first element being furthest away from query)
        double neighbourhood_radius = candidates[0].first;
        if(
            begin < begin + half_size
            && (lower_aabb.position[axis] + lower_aabb.extent[axis] >= query[axis] - neighbourhood_radius || k > candidates.size())
        ) {
            nodes.push_back({begin, begin + half_size, lower_aabb, depth + 1});
        }
        if(
            begin + half_size + 1 < end
            && (higher_aabb.position[axis] - higher_aabb.extent[axis] <= query[axis] + neighbourhood_radius || k > candidates.size())
        ) {
            nodes.push_back({begin + half_size + 1, end, higher_aabb, depth + 1});
        }
    }

    // sort the heap so the returned indices are ordered by ascending distance
    // (the last index is the furthest neighbour) and shrink the result in case
    // fewer than k points were available
    std::sort_heap(candidates.begin(), candidates.end(), cmp);
    erg.resize(candidates.size());
    for(unsigned int i = 0; i < candidates.size(); i++) erg[i] = candidates[i].second;
    return erg;
}

#endif // CGPOINTWRANGLER_H