package me.eater.threedom.utils

import me.eater.threedom.dom.IDocument
import me.eater.threedom.dom.INode
import me.eater.threedom.utils.joml.Vector3d
import me.eater.threedom.utils.joml.compareTo
import me.eater.threedom.utils.joml.translation
import org.joml.Vector3dc

/**
 * Spatial index over the nodes of an [IDocument], keyed by each node's absolute translation.
 *
 * Every tree [Node] splits on one vector component of its [Node.vertex] (the component cycles with depth)
 * into three children: `left` (strictly smaller on that component), `middle` (equal) and `right`
 * (strictly greater). Several document node ids may share one vertex.
 *
 * [add], [update] and [remove] only modify the existing tree and never rebalance it; call [rebalance]
 * to rebuild a balanced tree from the currently tracked locations.
 */
class KDTree(private val document: IDocument, private var root: Node = Node(Vector3d(0, 0, 0))) {
    /** Location each tracked node id was last inserted at; used to locate its entry on update/remove. */
    private val nodeLocMap = mutableMapOf<Long, Vector3dc>()

    /**
     * One vertex of the tree together with every document node id located exactly at that vertex.
     *
     * @property vertex the point this tree node represents.
     * @property nodeIds ids of the document nodes located at [vertex]; may be empty after removals.
     * @property depth distance from the root; selects the split [axis].
     * @property branches children in the order `[left, middle, right]`.
     */
    class Node(
        val vertex: Vector3dc = Vector3d(0, 0, 0),
        val nodeIds: MutableSet<Long> = mutableSetOf(),
        val depth: Long = 0,
        val branches: Array<Node?> = Array(3) { null }
    ) {
        /** Vector component index (0, 1 or 2) this node splits on. */
        val axis: Int
            get() = (depth % 3).toInt()

        /** Value of [vertex] on the split [axis]. */
        val median: Double
            get() = vertex[axis]

        /** Subtree with values strictly smaller than [median] on [axis]. */
        val left: Node?
            get() = branches[0]

        /** Subtree with values equal to [median] on [axis]. */
        val middle: Node?
            get() = branches[1]

        /** Subtree with values strictly greater than [median] on [axis]. */
        val right: Node?
            get() = branches[2]

        /**
         * Adds [nodeId] at [translation] in this subtree, creating a new leaf one level below the
         * last visited tree node if no tree node exists at that vertex yet.
         */
        fun add(translation: Vector3dc, nodeId: Long) {
            var current = this
            while (true) {
                if (translation == current.vertex) {
                    current.nodeIds.add(nodeId)
                    return
                }

                val branch = branchIndex(translation[current.axis], current.median)
                val next = current.branches[branch]
                if (next == null) {
                    current.branches[branch] = Node(translation, mutableSetOf(nodeId), current.depth + 1)
                    return
                }

                current = next
            }
        }

        /** Removes [nodeId] from the tree node at [translation]; does nothing if there is no such tree node. */
        fun remove(translation: Vector3dc, nodeId: Long) {
            find(translation)?.nodeIds?.remove(nodeId)
        }

        /**
         * Returns every tree node in this subtree whose vertex lies within Euclidean distance [range]
         * (inclusive) of [origin]. Tree nodes whose [nodeIds] set is empty may be included.
         */
        fun findInRange(origin: Vector3dc, range: Number): Sequence<Node> {
            val rangeD = range.toDouble()
            val pointA = Vector3d(origin.x() - rangeD, origin.y() - rangeD, origin.z() - rangeD)
            val pointB = Vector3d(origin.x() + rangeD, origin.y() + rangeD, origin.z() + rangeD)

            return findInRegion(pointA, pointB).filter { it.vertex.distance(origin) <= rangeD }
        }

        /**
         * Lazily returns every tree node in this subtree whose vertex lies inside the axis-aligned box
         * spanned by [pointA] and [pointB], inclusive on all faces. The corners may be given in any order.
         *
         * Traversal is breadth-first; subtrees that cannot intersect the box are skipped.
         * Tree nodes whose [nodeIds] set is empty may be included.
         */
        fun findInRegion(pointA: Vector3dc, pointB: Vector3dc): Sequence<Node> = sequence {
            // Normalise the corners so `low` is the minimum and `high` the maximum on every component.
            val low = DoubleArray(3) { minOf(pointA[it], pointB[it]) }
            val high = DoubleArray(3) { maxOf(pointA[it], pointB[it]) }

            // Index-based FIFO queue: avoids the O(n) cost of removing from the head of a list.
            val queue = mutableListOf(this@Node)
            var head = 0
            while (head < queue.size) {
                val current = queue[head++]
                val v = current.vertex
                if ((0 until 3).all { low[it] <= v[it] && v[it] <= high[it] }) {
                    yield(current)
                }

                val axis = current.axis
                val median = current.median
                // left holds values < median, middle == median, right > median on this node's axis.
                if (low[axis] < median) current.left?.let(queue::add)
                if (low[axis] <= median && median <= high[axis]) current.middle?.let(queue::add)
                if (high[axis] > median) current.right?.let(queue::add)
            }
        }

        /** Returns the tree node located exactly at [translation] in this subtree, or `null` if there is none. */
        fun find(translation: Vector3dc): Node? {
            var current: Node? = this
            while (current != null) {
                if (translation == current.vertex) {
                    return current
                }

                current = current.branches[branchIndex(translation[current.axis], current.median)]
            }

            return null
        }

        companion object {
            /** Builds a balanced tree from [nodes], grouped by their absolute translation. */
            fun create(nodes: Collection<INode<*>>, depth: Long = 0): Node =
                create(nodes.groupBy<INode<*>, Vector3dc, Long>({ it.absolute.translation }) { it.nodeId }, depth)

            /**
             * Builds a balanced tree from a map of vertex to the node ids located there.
             *
             * An empty map yields a tree node at the origin holding no ids, so there is always a root.
             */
            fun create(nodes: Map<Vector3dc, List<Long>>, depth: Long = 0): Node =
                createOrNull(nodes, depth) ?: Node(depth = depth)

            /** Recursive builder; returns `null` for an empty map so no placeholder children are created. */
            private fun createOrNull(nodes: Map<Vector3dc, List<Long>>, depth: Long): Node? {
                if (nodes.isEmpty()) {
                    return null
                }

                if (nodes.size == 1) {
                    val (loc, onlyNodes) = nodes.entries.first()
                    // Keep the depth so the split axis keeps cycling correctly below this leaf.
                    return Node(loc, onlyNodes.toMutableSet(), depth)
                }

                val axis = (depth % 3).toInt()
                val selected = nodes.keys.sortedBy { it[axis] }[nodes.size / 2]
                val median = selected[axis]

                val branches = Array(3) { mutableMapOf<Vector3dc, List<Long>>() }
                for ((loc, ids) in nodes) {
                    // The median vertex itself is stored in the new node; also putting it in the middle
                    // branch would duplicate its ids and leave stale copies behind after remove().
                    if (loc != selected) {
                        branches[branchIndex(loc[axis], median)][loc] = ids
                    }
                }

                return Node(
                    selected,
                    nodes.getValue(selected).toMutableSet(),
                    depth,
                    Array(3) { createOrNull(branches[it], depth + 1) }
                )
            }

            /**
             * Maps [value] compared against a node's split value [median] to a branch slot:
             * 0 (left) when smaller, 1 (middle) when equal, 2 (right) when greater.
             * Uses [Double.compareTo], the same ordering `sortedBy` uses when building the tree.
             */
            private fun branchIndex(value: Double, median: Double): Int {
                val cmp = value.compareTo(median)
                return when {
                    cmp < 0 -> 0
                    cmp > 0 -> 2
                    else -> 1
                }
            }
        }
    }

    /** Builds a balanced tree from [nodes] and starts tracking their current locations. */
    constructor(document: IDocument, nodes: Collection<INode<*>>) : this(document, Node.create(nodes)) {
        // Record the build-time locations so update(), remove() and rebalance() know about these nodes.
        for (node in nodes) {
            nodeLocMap[node.nodeId] = node.absolute.translation
        }
    }

    /** Inserts [node] at its current absolute translation and records that location. */
    fun add(node: INode<*>) {
        val vec = node.absolute.translation
        nodeLocMap[node.nodeId] = vec
        root.add(vec, node.nodeId)
    }

    /**
     * Removes [node] from the tree and stops tracking it. The location recorded when it was last added
     * is used, so this also works after the node has moved; untracked nodes fall back to their current
     * absolute translation.
     */
    fun remove(node: INode<*>) {
        val location = nodeLocMap.remove(node.nodeId) ?: node.absolute.translation
        root.remove(location, node.nodeId)
    }

    /** Returns the document nodes located exactly at [vertex]; ids the document no longer resolves are skipped. */
    fun find(vertex: Vector3dc) =
        root.find(vertex)?.nodeIds?.mapNotNull(document::getNodeByNodeId) ?: emptyList()

    /** Lazily returns the document nodes within Euclidean distance [range] (inclusive) of [origin]. */
    fun findInRange(origin: Vector3dc, range: Number) =
        root.findInRange(origin, range)
            .flatMap { it.nodeIds.asSequence() }
            .mapNotNull(document::getNodeByNodeId)

    /** Lazily returns the document nodes inside the axis-aligned box spanned by [pointA] and [pointB]. */
    fun findInRegion(pointA: Vector3dc, pointB: Vector3dc) =
        root.findInRegion(pointA, pointB)
            .flatMap { it.nodeIds.asSequence() }
            .mapNotNull(document::getNodeByNodeId)

    /** Moves [node] from its recorded location (if any) to its current absolute translation. */
    fun update(node: INode<*>) {
        nodeLocMap[node.nodeId]?.let { root.remove(it, node.nodeId) }
        add(node)
    }

    /**
     * Rebuilds a balanced tree from the recorded locations. Only nodes tracked in the location map
     * survive the rebuild.
     */
    fun rebalance() {
        root = Node.create(nodeLocMap.entries.groupBy({ it.value }) { it.key })
    }
}