You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

123 lines
4.5 KiB
Kotlin

package me.eater.threedom.utils
import me.eater.threedom.dom.IDocument
import me.eater.threedom.dom.INode
import me.eater.threedom.utils.joml.Vector3d
import me.eater.threedom.utils.joml.compareTo
import me.eater.threedom.utils.joml.getTranslation
import org.joml.Vector3dc
/**
 * Index from node positions (the translation of a node's absolute transform) to node ids,
 * stored as a binary tree whose split axis cycles through the three vector components by depth.
 *
 * @param document used by [find] to resolve stored node ids back to nodes
 * @param root root of the tree; defaults to a single empty node at the origin
 */
class KDTree(private val document: IDocument, private val root: Node = Node(Vector3d(0, 0, 0))) {
    /** Position each node id was last indexed under, so [update] and [remove] can locate the stale entry. */
    private val nodeLocMap = mutableMapOf<Long, Vector3dc>()

    /**
     * One tree node: a position, the ids of all nodes located exactly at that position,
     * and the two subtrees split along [axis].
     *
     * @param vertex position this tree node represents
     * @param nodeIds ids of the nodes located at [vertex]
     * @param depth depth of this tree node; determines [axis]
     * @param left subtree of positions ordered before [vertex] on [axis]
     * @param right subtree of all other positions
     */
    data class Node(
        val vertex: Vector3dc = Vector3d(0, 0, 0),
        val nodeIds: MutableSet<Long> = mutableSetOf(),
        val depth: Long = 0,
        var left: Node? = null,
        var right: Node? = null
    ) {
        /** Component index (0, 1 or 2) this tree node splits on. */
        val axis: Int
            get() = (depth % 3).toInt()

        /**
         * Decides which subtree [translation] belongs to: left when it is smaller on [axis],
         * or equal on [axis] and smaller by the vector ordering; right otherwise.
         */
        private fun goesLeft(translation: Vector3dc): Boolean {
            val own = vertex[axis]
            val other = translation[axis]
            return other < own || (other == own && translation < vertex)
        }

        /**
         * Registers [nodeId] at [translation], descending from this node and creating a new leaf
         * when no tree node with that exact position exists yet.
         */
        fun add(translation: Vector3dc, nodeId: Long) {
            var current = this
            while (true) {
                if (translation == current.vertex) {
                    current.nodeIds.add(nodeId)
                    return
                }
                val goLeft = current.goesLeft(translation)
                val child = if (goLeft) current.left else current.right
                if (child == null) {
                    val leaf = Node(translation, mutableSetOf(nodeId), current.depth + 1)
                    if (goLeft) current.left = leaf else current.right = leaf
                    return
                }
                current = child
            }
        }

        /**
         * Removes [nodeId] from the tree node at [translation], if there is one.
         * The tree node itself stays in place, possibly with an empty id set.
         */
        fun remove(translation: Vector3dc, nodeId: Long) {
            find(translation)?.nodeIds?.remove(nodeId)
        }

        /** Returns the tree node whose position equals [translation], or `null` when none exists. */
        fun find(translation: Vector3dc): Node? {
            var current: Node? = this
            while (current != null) {
                if (translation == current.vertex) {
                    return current
                }
                current = if (current.goesLeft(translation)) current.left else current.right
            }
            return null
        }

        companion object {
            /**
             * Builds a tree from [nodes], grouping node ids by the translation of each node's
             * absolute transform.
             *
             * @param depth depth assigned to the returned root
             */
            @Suppress("UNCHECKED_CAST")
            fun create(nodes: Collection<INode<*>>, depth: Long = 0): Node =
                create(
                    nodes.groupBy({ it.absolute.getTranslation() }) { it.nodeId } as Map<Vector3dc, Collection<Long>>,
                    depth
                )

            /**
             * Builds a tree from a position-to-ids map by recursively splitting on the median
             * position along the axis for the current [depth].
             *
             * @param depth depth assigned to the returned root
             * @return the root of the built tree; an empty node when [nodes] is empty
             */
            fun create(nodes: Map<Vector3dc, Collection<Long>>, depth: Long = 0): Node {
                if (nodes.isEmpty()) {
                    return Node(depth = depth)
                }
                if (nodes.size == 1) {
                    val (loc, onlyNodes) = nodes.entries.first()
                    // Carry the depth along so later add/find calls split on the right axis below this leaf.
                    return Node(loc, onlyNodes.toMutableSet(), depth)
                }
                val axis: Int = (depth % 3).toInt()
                // Order exactly the way add/find route: by the axis component first, ties broken by the
                // vector ordering. Otherwise positions tied with the median on this axis could be stored
                // in a subtree that find never visits.
                // NOTE(review): assumes the imported Vector3dc.compareTo is a consistent ordering.
                val sorted = nodes.keys.sortedWith(Comparator<Vector3dc> { a, b ->
                    val av = a[axis]
                    val bv = b[axis]
                    when {
                        av < bv -> -1
                        av > bv -> 1
                        else -> a.compareTo(b)
                    }
                })
                val median = sorted.size / 2
                val selected = sorted[median]
                // The median becomes this node, so it is excluded from both halves.
                val left = subtree(nodes, sorted.subList(0, median), depth + 1)
                val right = subtree(nodes, sorted.subList(median + 1, sorted.size), depth + 1)
                return Node(selected, nodes[selected]?.toMutableSet() ?: mutableSetOf(), depth, left, right)
            }

            /** Builds the subtree for [keys] (a subset of the keys of [nodes]), or `null` when [keys] is empty. */
            private fun subtree(nodes: Map<Vector3dc, Collection<Long>>, keys: List<Vector3dc>, depth: Long): Node? {
                if (keys.isEmpty()) {
                    return null
                }
                val wanted = keys.toSet()
                return create(nodes.filterKeys(wanted::contains), depth)
            }
        }
    }

    /**
     * Builds the index from an initial collection of [nodes] and records each node's position
     * so that a later [update] or [remove] can find its entry.
     */
    constructor(document: IDocument, nodes: Collection<INode<*>>) : this(document, Node.create(nodes)) {
        nodes.forEach { nodeLocMap[it.nodeId] = it.absolute.getTranslation() }
    }

    /** Indexes [node] under the translation of its current absolute transform. */
    fun add(node: INode<*>) {
        val vec = node.absolute.getTranslation()
        nodeLocMap[node.nodeId] = vec
        root.add(vec, node.nodeId)
    }

    /**
     * Removes [node] from the index, using the position it was last indexed under.
     * Falls back to the node's current translation when no position was recorded for it.
     */
    fun remove(node: INode<*>) {
        val indexedAt: Vector3dc = nodeLocMap.remove(node.nodeId) ?: node.absolute.getTranslation()
        root.remove(indexedAt, node.nodeId)
    }

    /** Returns the nodes indexed exactly at [vertex]; ids the document cannot resolve are skipped. */
    fun find(vertex: Vector3dc) = root.find(vertex)?.nodeIds?.mapNotNull(document::getNodeByNodeId) ?: emptyList()

    /** Re-indexes [node]: drops the entry at its previously recorded position, then adds it at its current one. */
    fun update(node: INode<*>) {
        nodeLocMap[node.nodeId]?.let { root.remove(it, node.nodeId) }
        add(node)
    }
}