You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.9 KiB
Kotlin

package me.eater.threedom.utils
import me.eater.threedom.dom.IDocument
import me.eater.threedom.dom.INode
import me.eater.threedom.utils.joml.Vector3d
import me.eater.threedom.utils.joml.compareTo
import me.eater.threedom.utils.joml.translation
import org.joml.Vector3dc
/**
 * A three-dimensional k-d tree that indexes the nodes of [document] by their absolute translation.
 *
 * Every tree [Node] splits space on axis `depth % 3` and owns up to three children: locations strictly
 * below the split value, locations equal to it on that axis (but different elsewhere), and locations
 * strictly above it. All node ids that share one exact location live in the same tree node, and each
 * id is stored in exactly one tree node.
 *
 * Insertions do not rebalance the tree; call [rebalance] to rebuild it from the currently indexed
 * locations. NOTE(review): nothing here is synchronised — presumably single-threaded use; confirm.
 *
 * @param document used to resolve stored node ids back into [INode] instances for query results.
 * @param root initial tree; ids it already holds are registered so they can later be moved or removed.
 */
class KDTree(private val document: IDocument, private var root: Node = Node(Vector3d(0, 0, 0))) {
    /** Location each node id was last indexed at; needed to find its entry again after the node moved. */
    private val nodeLocMap = mutableMapOf<Long, Vector3dc>()

    init {
        // Register ids already present in the supplied root (e.g. from the bulk constructor); otherwise
        // update/remove could not find them and rebalance would silently drop them.
        for (treeNode in root.walk()) {
            for (id in treeNode.nodeIds) {
                nodeLocMap[id] = treeNode.vertex
            }
        }
    }

    /**
     * One node of the k-d tree.
     *
     * @property vertex the location this node represents; its [axis] component is the split value.
     * @property nodeIds ids of all document nodes located exactly at [vertex].
     * @property depth distance from the root, which selects the split [axis].
     * @property branches children: `[0]` below, `[1]` equal on [axis], `[2]` above the split value.
     */
    class Node(
        val vertex: Vector3dc = Vector3d(0, 0, 0),
        val nodeIds: MutableSet<Long> = mutableSetOf(),
        val depth: Long = 0,
        val branches: Array<Node?> = Array(3) { null }
    ) {
        /** Component index (0..2) of [vertex] this node splits on. */
        val axis: Int
            get() = axisFor(depth)

        /** Split value: the [axis] component of [vertex]. */
        val median: Double
            get() = vertex[axis]

        val left: Node?
            get() = branches[0]

        val middle: Node?
            get() = branches[1]

        val right: Node?
            get() = branches[2]

        /**
         * Inserts [nodeId] at [translation] into the subtree rooted here.
         *
         * When a new tree node is created the vector is stored by reference, so callers should not
         * mutate it afterwards.
         */
        fun add(translation: Vector3dc, nodeId: Long) {
            var current = this
            while (translation != current.vertex) {
                val index = branchIndex(translation[current.axis], current.median)
                val child = current.branches[index]
                if (child == null) {
                    current.branches[index] = Node(translation, mutableSetOf(nodeId), current.depth + 1)
                    return
                }
                current = child
            }
            current.nodeIds.add(nodeId)
        }

        /**
         * Removes [nodeId] from the tree node located exactly at [translation], if there is one.
         * The (possibly now empty) tree node stays in place until the tree is rebuilt.
         */
        fun remove(translation: Vector3dc, nodeId: Long) {
            find(translation)?.nodeIds?.remove(nodeId)
        }

        /**
         * Yields the tree nodes whose [vertex] lies within distance [range] of [origin]: candidates are
         * gathered from the enclosing axis-aligned cube and then filtered by `Vector3dc.distance`.
         */
        fun findInRange(origin: Vector3dc, range: Number): Sequence<Node> {
            val rangeD = range.toDouble()
            val pointA = Vector3d(origin.x() - rangeD, origin.y() - rangeD, origin.z() - rangeD)
            val pointB = Vector3d(origin.x() + rangeD, origin.y() + rangeD, origin.z() + rangeD)
            return findInRegion(pointA, pointB).filter { it.vertex.distance(origin) <= rangeD }
        }

        /**
         * Lazily yields, breadth-first, every tree node whose [vertex] lies inside the axis-aligned box
         * spanned by [pointA] and [pointB] (bounds inclusive). The two points may be any pair of opposite
         * corners; subtrees that cannot intersect the box are skipped.
         */
        fun findInRegion(pointA: Vector3dc, pointB: Vector3dc) = sequence<Node> {
            val minCorner = DoubleArray(3) { minOf(pointA[it], pointB[it]) }
            val maxCorner = DoubleArray(3) { maxOf(pointA[it], pointB[it]) }
            // Index-based FIFO: appending is O(1) and nothing is ever shifted, unlike removeAt(0).
            val queue = mutableListOf(this@Node)
            var head = 0
            while (head < queue.size) {
                val current = queue[head++]
                if ((0..2).all { minCorner[it] <= current.vertex[it] && current.vertex[it] <= maxCorner[it] }) {
                    yield(current)
                }
                val split = current.median
                val lower = minCorner[current.axis]
                val upper = maxCorner[current.axis]
                // Each child only holds points below / equal to / above the split on this axis.
                if (lower < split) current.left?.let(queue::add)
                if (lower <= split && split <= upper) current.middle?.let(queue::add)
                if (upper > split) current.right?.let(queue::add)
            }
        }

        /** Returns the tree node located exactly at [translation] in this subtree, or `null`. */
        fun find(translation: Vector3dc): Node? {
            var current: Node? = this
            while (current != null && translation != current.vertex) {
                current = current.branches[branchIndex(translation[current.axis], current.median)]
            }
            return current
        }

        /** Lazily yields every tree node of this subtree, breadth-first, starting with this one. */
        fun walk(): Sequence<Node> = sequence {
            val queue = mutableListOf(this@Node)
            var head = 0
            while (head < queue.size) {
                val current = queue[head++]
                yield(current)
                current.branches.filterNotNullTo(queue)
            }
        }

        companion object {
            /** Split axis used at [depth]; axes cycle 0, 1, 2, 0, ... down the tree. */
            private fun axisFor(depth: Long): Int = (depth % 3).toInt()

            /**
             * Child slot for coordinate [value] relative to split value [pivot]: 0 below, 1 equal, 2 above.
             * Only the sign of [Double.compareTo] is used, so the result is always a valid index.
             */
            private fun branchIndex(value: Double, pivot: Double): Int {
                val order = value.compareTo(pivot)
                return when {
                    order < 0 -> 0
                    order > 0 -> 2
                    else -> 1
                }
            }

            /** Builds a balanced tree over [nodes], grouping their ids by absolute translation. */
            fun create(nodes: Collection<INode<*>>, depth: Long = 0): Node =
                create(nodes.groupBy<INode<*>, Vector3dc, Long>({ it.absolute.translation }, { it.nodeId }), depth)

            /**
             * Builds a balanced tree from a location → ids map, rooted at [depth].
             * An empty map yields an empty tree node at the origin.
             */
            fun create(nodes: Map<Vector3dc, Collection<Long>>, depth: Long = 0): Node =
                createOrNull(nodes, depth) ?: Node(depth = depth)

            /** Recursive builder; returns `null` for an empty map so no placeholder nodes are created. */
            private fun createOrNull(nodes: Map<Vector3dc, Collection<Long>>, depth: Long): Node? {
                if (nodes.isEmpty()) {
                    return null
                }
                val axis = axisFor(depth)
                val sorted = nodes.keys.sortedBy { it[axis] }
                val selected = sorted[sorted.size / 2]
                val partitions = Array(3) { mutableMapOf<Vector3dc, Collection<Long>>() }
                for (location in sorted) {
                    // The median itself lives in the new node; putting it into a child as well would
                    // store its ids twice (duplicate query results, incomplete removal).
                    if (location != selected) {
                        partitions[branchIndex(location[axis], selected[axis])][location] = nodes.getValue(location)
                    }
                }
                return Node(
                    selected,
                    nodes.getValue(selected).toMutableSet(),
                    depth,
                    Array(3) { createOrNull(partitions[it], depth + 1) }
                )
            }
        }
    }

    /** Builds a balanced tree over [nodes]; every node is registered for later moves and removals. */
    constructor(document: IDocument, nodes: Collection<INode<*>>) : this(document, Node.create(nodes))

    /**
     * Indexes [node] at its current absolute translation. If its id was already indexed (e.g. at an older
     * location) that previous entry is removed first, so each id is indexed at most once.
     */
    fun add(node: INode<*>) {
        val location = node.absolute.translation
        nodeLocMap.put(node.nodeId, location)?.let { previous -> root.remove(previous, node.nodeId) }
        root.add(location, node.nodeId)
    }

    /**
     * Removes [node] from the index. The location recorded when it was last indexed is used, so this
     * works even if the node has moved since; for a never-indexed node its current translation is tried.
     */
    fun remove(node: INode<*>) {
        val location = nodeLocMap.remove(node.nodeId) ?: node.absolute.translation
        root.remove(location, node.nodeId)
    }

    /** Document nodes indexed exactly at [vertex]; ids the document no longer resolves are skipped. */
    fun find(vertex: Vector3dc) = root.find(vertex)?.nodeIds?.mapNotNull(document::getNodeByNodeId) ?: emptyList()

    /** Lazily resolved document nodes indexed within distance [range] of [origin]. */
    fun findInRange(origin: Vector3dc, range: Number) =
        root.findInRange(origin, range).flatMap { it.nodeIds.asSequence() }.mapNotNull(document::getNodeByNodeId)

    /** Lazily resolved document nodes indexed inside the axis-aligned box spanned by [pointA] and [pointB]. */
    fun findInRegion(pointA: Vector3dc, pointB: Vector3dc) =
        root.findInRegion(pointA, pointB).flatMap { it.nodeIds.asSequence() }.mapNotNull(document::getNodeByNodeId)

    /** Re-indexes [node] at its current absolute translation, dropping its previously indexed entry. */
    fun update(node: INode<*>) {
        add(node)
    }

    /**
     * Rebuilds the whole tree as a balanced tree over every currently indexed location, which also
     * discards tree nodes left empty by removals.
     */
    fun rebalance() {
        root = Node.create(nodeLocMap.entries.groupBy({ it.value }) { it.key })
    }
}