Topic: Kd Tree

gile

  • Water Moccasin
  • Posts: 2233
  • Marseille, France
Re: Kd Tree
« Reply #30 on: April 19, 2013, 11:32:50 AM »
I made some changes to the code above.
I replaced the (global) fields with method parameters so that the public methods can be run in parallel.
I replaced the DistanceTo/GetDistanceTo calls with squared distances.
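A minimal sketch of both changes, runnable without AutoCAD. The record `P3` and the brute-force `nearest` function are hypothetical stand-ins for Point3d and the tree search, not code from this thread. The best candidate travels through the recursion as parameters, so parallel calls share no mutable state, and candidates are ranked by squared distance, which gives the same ordering as the true distance because the square root is increasing.

```fsharp
// Hypothetical stand-in for AutoCAD's Point3d, so the sketch runs outside AutoCAD.
type P3 = { X: float; Y: float; Z: float }

// Squared distance: no square root, and the same ordering as the true distance.
let sqrDist (a: P3) (b: P3) =
    let dx, dy, dz = a.X - b.X, a.Y - b.Y, a.Z - b.Z
    dx * dx + dy * dy + dz * dz

// Hypothetical brute-force search standing in for the tree search.
// The best point and its squared distance are passed along as parameters
// (no mutable field), so concurrent calls cannot interfere.
// Assumes 'points' is not empty.
let nearest (points: P3[]) (target: P3) =
    let rec loop i best bestSqr =
        if i = points.Length then best
        else
            let d = sqrDist points.[i] target
            if d < bestSqr then loop (i + 1) points.[i] d
            else loop (i + 1) best bestSqr
    loop 1 points.[0] (sqrDist points.[0] target)

let pts = [| for i in 0 .. 99 -> { X = float i; Y = float (i % 7); Z = 0.0 } |]
let targets = [| { X = 10.2; Y = 3.0; Z = 0.0 }; { X = 55.9; Y = 6.0; Z = 1.0 } |]

// Each call owns its own state, so the queries can safely run in parallel.
let results = targets |> Array.Parallel.map (nearest pts)
```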
Speaking English as a French Frog

WILL HATCH

  • Bull Frog
  • Posts: 448
Re: Kd Tree
« Reply #31 on: April 19, 2013, 12:16:44 PM »
Great work gile!

gile
Re: Kd Tree
« Reply #32 on: April 26, 2013, 02:47:20 PM »
Here's the F# version.
Its main advantage is that it can build the tree in parallel while targeting .NET Framework 2.0 (the Task Parallel Library only arrived with .NET 4.0), so it can be used with AutoCAD versions 2007 to 2011.
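A simplified sketch of that parallel build, on hypothetical 2D tuples instead of Point3d. As in the posted `create` function, both halves are built concurrently with `Async.Parallel` until the depth reaches floor(log2 ProcessorCount), then sequentially. Unlike `create`, it finds the median with a full `Array.sortBy` rather than the quickselect `Array.median`, only to keep the sketch short.

```fsharp
// Hypothetical simplified 2D k-d tree (tuples instead of Point3d).
type Tree =
    | Leaf
    | Node of (float * float) * Tree * Tree

// floor(log2 n), computed like 'shift' in the posted code.
let rec log2Floor n d = if n >>> d > 1 then log2Floor n (d + 1) else d

// Above this depth both halves are built concurrently (2^parallelDepth <= ProcessorCount).
let parallelDepth = log2Floor System.Environment.ProcessorCount 0

let rec build depth (pts: (float * float)[]) =
    if pts.Length = 0 then Leaf
    else
        // Split on x at even depths and on y at odd depths.
        let key : float * float -> float = if depth % 2 = 0 then fst else snd
        // A full sort stands in for the quickselect median of the posted code.
        let sorted = Array.sortBy key pts
        let k = sorted.Length / 2
        let left, right = sorted.[.. k - 1], sorted.[k + 1 ..]
        let children =
            if depth < parallelDepth then
                [ async { return build (depth + 1) left };
                  async { return build (depth + 1) right } ]
                |> Async.Parallel
                |> Async.RunSynchronously
            else
                [| build (depth + 1) left; build (depth + 1) right |]
        Node(sorted.[k], children.[0], children.[1])

let rec count tree =
    match tree with
    | Leaf -> 0
    | Node(_, left, right) -> 1 + count left + count right

// 1000 points whose x coordinates are a permutation of 0 .. 999.
let tree = build 0 [| for i in 0 .. 999 -> (float (i * 7919 % 1000), float i) |]
printfn "%d" (count tree)   // prints 1000
```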

The F#-specific constructs and types are only used internally.
The public methods (the same as in the C# version) are written to be used from other .NET languages (C# or VB): they return Point3d arrays rather than Point3dCollections. A 'Pair' type (class) replaces the Tuple class, which is missing from .NET Framework versions prior to 4.0.

Attached is FsPoint3dTree.dll, which can be referenced in C# or VB projects (the F# runtime may need to be installed).

Note: I can't understand why, but building the tree runs about ten times slower in AutoCAD 2012 and later than in earlier versions, so with A2012 or later, use the C# version instead.

Code - F#:
```fsharp
namespace Gile.Point3dTree

open System
open Autodesk.AutoCAD.Geometry

module private Array =
    /// Quickselect: partially sorts 'items' in place and returns the items on the
    /// left of the median, the median itself and the items on its right.
    let median (compare: 'a -> 'a -> int) (items: 'a[]) =
        if items = null || items.Length = 0 then
            invalidArg "items" "The array must not be null or empty."
        let l = items.Length
        let k = l / 2

        // Narrows the [f, t] range around the median index k.
        let rec loop f t =
            if f < t then swap f t items.[(f + t) / 2] f t
            else items.[.. k - 1], items.[k], items.[k + 1 ..]
        // Partitions the [f, t] range around the pivot value c.
        and swap a b c f t =
            if a < b then
                if compare items.[a] c > -1 then
                    let tmp = items.[b]
                    items.[b] <- items.[a]
                    items.[a] <- tmp
                    swap a (b - 1) c f t
                else
                    swap (a + 1) b c f t
            else
                let n = if compare items.[a] c > 0 then a - 1 else a
                if k <= n
                then loop f n
                else loop (n + 1) t

        loop 0 (l - 1)

/// Tree node: depth, parent point (None for the root), point, left and right sub-trees.
type private TreeNode =
    | Empty
    | Node of int * Point3d option * Point3d * TreeNode * TreeNode

/// <summary>
/// Defines a pair (two-item tuple) to be used with versions of the .NET Framework prior to 4.0.
/// </summary>
/// <typeparam name="T1">Type of the first item.</typeparam>
/// <typeparam name="T2">Type of the second item.</typeparam>
/// <param name="item1">First item.</param>
/// <param name="item2">Second item.</param>
type Pair<'T1, 'T2>(item1: 'T1, item2: 'T2) =

    /// <summary>
    /// Gets the first item of the pair.
    /// </summary>
    member this.Item1 with get() = item1

    /// <summary>
    /// Gets the second item of the pair.
    /// </summary>
    member this.Item2 with get() = item2

    /// <summary>
    /// Creates a new instance of Pair.
    /// </summary>
    /// <param name="item1">First item of the pair.</param>
    /// <param name="item2">Second item of the pair.</param>
    /// <returns>A new Pair containing the items.</returns>
    static member Create(item1: 'T1, item2: 'T2) = Pair<'T1, 'T2>(item1, item2)


/// <summary>
/// Creates a new instance of Point3dTree.
/// </summary>
/// <param name="points">The Point3d collection used to fill the tree.</param>
/// <param name="ignoreZ">A value indicating whether the Z coordinate of the points is ignored
/// (as if all points were projected onto the XY plane).</param>
type Point3dTree (points: Point3d seq, ignoreZ: bool) =
    do if points = null then raise (System.ArgumentNullException("points"))

    let dimension = if ignoreZ then 2 else 3

    // Squared distance: no square root, same ordering as the true distance.
    let sqrDist (p1: Point3d) (p2: Point3d) =
        if ignoreZ
        then (p1.X - p2.X) * (p1.X - p2.X) + (p1.Y - p2.Y) * (p1.Y - p2.Y)
        else (p1.X - p2.X) * (p1.X - p2.X) + (p1.Y - p2.Y) * (p1.Y - p2.Y) + (p1.Z - p2.Z) * (p1.Z - p2.Z)

    // floor(log2 n): sub-trees above this depth are built in parallel.
    let rec shift n d =
        if n >>> d > 1 then shift n (d + 1) else d
    let pDepth = shift System.Environment.ProcessorCount 0

    // Builds the tree, splitting each range on the median along axis depth % dimension.
    let create pts =
        let rec loop depth parent = function
            | [||] -> Empty
            | pts ->
                let d = depth % dimension
                let left, median, right =
                    pts |> Array.median (fun (p1: Point3d) p2 -> compare p1.[d] p2.[d])
                let children =
                    if depth < pDepth then
                        [ async { return loop (depth + 1) (Some median) left };
                          async { return loop (depth + 1) (Some median) right } ]
                        |> Async.Parallel
                        |> Async.RunSynchronously
                    else
                        [| loop (depth + 1) (Some median) left;
                           loop (depth + 1) (Some median) right |]
                Node(depth, parent, median, children.[0], children.[1])
        loop 0 None pts

    let root = points |> Seq.distinct |> Seq.toArray |> create

    // Nearest neighbour: (current, bestDist) is the best (point, squared distance) found so far.
    let rec findNeighbour location node (current, bestDist) =
        match node with
        | Empty -> (current, bestDist)
        | Node(depth, _, point, left, right) ->
            let dist = sqrDist point location
            let d = depth % dimension
            let bestPair =
                if dist < bestDist
                then point, dist
                else current, bestDist
            // If the splitting plane is farther than the best distance, only one side can hold a closer point.
            if snd bestPair < (location.[d] - point.[d]) * (location.[d] - point.[d]) then
                findNeighbour location (if location.[d] < point.[d] then left else right) bestPair
            else
                findNeighbour location left bestPair
                |> findNeighbour location right

    // Collects the points whose squared distance from center is <= radius (already squared).
    let rec getNeighbours center radius node acc =
        match node with
        | Empty -> acc
        | Node(depth, _, point, left, right) ->
            let acc = if sqrDist center point <= radius then point :: acc else acc
            let d = depth % dimension
            let coordCen, coordPt = center.[d], point.[d]
            if (coordCen - coordPt) * (coordCen - coordPt) > radius then
                getNeighbours center radius (if coordCen < coordPt then left else right) acc
            else
                getNeighbours center radius left acc
                |> getNeighbours center radius right

    // k nearest neighbours: 'pairs' is a (squared distance, point) list of at most
    // 'number' items whose head is always the farthest one.
    let rec getKNeighbours center number node (pairs: (float * Point3d) list) =
        match node with
        | Empty -> pairs
        | Node(depth, _, point, left, right) ->
            let dist = sqrDist center point
            let pairs =
                match pairs.Length with
                | 0 -> [ (dist, point) ]
                | l when l < number ->
                    if dist > fst pairs.Head
                    then (dist, point) :: pairs
                    else pairs.Head :: (dist, point) :: pairs.Tail
                | _ ->
                    if dist < fst pairs.Head
                    then ((dist, point) :: pairs.Tail) |> List.sortBy (fun p -> -fst p)
                    else pairs
            let d = depth % dimension
            let coordCen, coordCur = center.[d], point.[d]
            if (coordCen - coordCur) * (coordCen - coordCur) > fst pairs.Head then
                getKNeighbours center number (if coordCen < coordCur then left else right) pairs
            else
                getKNeighbours center number left pairs
                |> getKNeighbours center number right

    // Collects the points inside the box defined by lowerLeft and upperRight.
    let rec findRange (lowerLeft: Point3d) (upperRight: Point3d) node (acc: Point3d list) =
        match node with
        | Empty -> acc
        | Node(depth, _, point, left, right) ->
            let acc =
                if ignoreZ then
                    if point.X >= lowerLeft.X && point.X <= upperRight.X &&
                       point.Y >= lowerLeft.Y && point.Y <= upperRight.Y
                    then point :: acc else acc
                else
                    if point.X >= lowerLeft.X && point.X <= upperRight.X &&
                       point.Y >= lowerLeft.Y && point.Y <= upperRight.Y &&
                       point.Z >= lowerLeft.Z && point.Z <= upperRight.Z
                    then point :: acc else acc
            let d = depth % dimension
            if upperRight.[d] < point.[d]
            then findRange lowerLeft upperRight left acc
            elif lowerLeft.[d] > point.[d]
            then findRange lowerLeft upperRight right acc
            else findRange lowerLeft upperRight left acc
                 |> findRange lowerLeft upperRight right

    // Collects all the pairs of points within 'radius' (squared) of each other.
    // 'nodes' holds the right sub-trees of the ancestors whose left sub-tree is being visited.
    let rec getConnexions node radius (nodes, pairs) =
        match node with
        | Empty -> (nodes, pairs)
        | Node(depth, parent, point, left, right) ->
            let pairs =
                nodes
                |> List.fold (fun acc (n: TreeNode) ->
                    match n with
                    | Empty -> acc
                    | Node(dep, par, _, _, _) ->
                        let pt = par.Value
                        // Splitting axis of n's parent, which sits at depth dep - 1.
                        let d = (dep - 1) % dimension
                        if (point.[d] - pt.[d]) * (point.[d] - pt.[d]) <= radius
                        then getNeighbours point radius n acc
                        else acc)
                    (getNeighbours point radius left []
                     |> getNeighbours point radius right)
                |> List.map (fun p -> Pair<Point3d, Point3d>(point, p))
                |> List.fold (fun acc p -> p :: acc) pairs
            let nodes =
                match nodes with
                | [] -> if right = Empty then [] else [right]
                | _ :: t ->
                    if right = Empty
                    then if left = Empty then t else nodes
                    else right :: nodes
            getConnexions left radius (nodes, pairs)
            |> getConnexions right radius

    /// <summary>
    /// Creates a new instance of Point3dTree with ignoreZ = false (the default).
    /// </summary>
    /// <param name="points">The Point3d collection used to fill the tree.</param>
    new (points: Point3d seq) = Point3dTree(points, false)

    /// <summary>
    /// Gets the nearest neighbour.
    /// </summary>
    /// <param name="location">The point from which to search for the nearest neighbour.</param>
    /// <returns>The point of the collection nearest to the specified one.</returns>
    member this.NearestNeighbour(location: Point3d) =
        match root with
        | Empty -> raise (System.ArgumentNullException("root"))
        | Node(_, _, point, _, _) ->
            findNeighbour location root (point, Double.MaxValue)
            |> fst

    /// <summary>
    /// Gets the neighbours within the specified distance.
    /// </summary>
    /// <param name="center">The point from which to search for the nearest neighbours.</param>
    /// <param name="radius">The distance within which the neighbours are collected.</param>
    /// <returns>The points whose distance from the specified point is less than or equal to the specified distance.</returns>
    member this.NearestNeighbours(center: Point3d, radius: float) =
        getNeighbours center (radius * radius) root []
        |> List.toArray

    /// <summary>
    /// Gets the specified number of nearest neighbours.
    /// </summary>
    /// <param name="center">The point from which to search for the nearest neighbours.</param>
    /// <param name="number">The number of points to collect.</param>
    /// <returns>The n nearest neighbours of the specified point (not necessarily sorted by distance).</returns>
    member this.NearestNeighbours(center: Point3d, number: int) =
        getKNeighbours center number root []
        |> List.map (fun p -> snd p)
        |> List.toArray

    /// <summary>
    /// Gets the points in a range.
    /// </summary>
    /// <param name="pt1">The first corner of the range.</param>
    /// <param name="pt2">The opposite corner of the range.</param>
    /// <returns>All the points within the box.</returns>
    member this.BoxedRange(pt1: Point3d, pt2: Point3d) =
        let lowerLeft = Point3d(min pt1.X pt2.X, min pt1.Y pt2.Y, min pt1.Z pt2.Z)
        let upperRight = Point3d(max pt1.X pt2.X, max pt1.Y pt2.Y, max pt1.Z pt2.Z)
        findRange lowerLeft upperRight root []
        |> List.toArray

    /// <summary>
    /// Gets all the pairs of points whose distance is less than or equal to the specified distance.
    /// </summary>
    /// <param name="radius">The maximum distance between two points.</param>
    /// <returns>The pairs of points whose distance is less than or equal to the specified distance.</returns>
    member this.ConnectAll(radius: float) =
        getConnexions root (radius * radius) ([], [])
        |> snd
        |> List.toArray
```
« Last Edit: April 27, 2013, 02:48:17 AM by gile »
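Because the module above needs AutoCAD's Point3d, here is a self-contained sketch of the pruning rule behind `findNeighbour`, on hypothetical 2D tuples (the names `KdTree`, `build`, `nearest` and `findNearest` are illustrative only, not part of FsPoint3dTree.dll). The far side of a splitting plane is skipped when the squared distance to the plane already exceeds the best squared distance found, since every point over there is at least that far away. This sketch also searches the side containing the target first, a common refinement that the posted code does not use.

```fsharp
// Hypothetical 2D k-d tree: each node stores its depth, its point and its two sub-trees.
type KdTree =
    | Empty
    | Node of int * (float * float) * KdTree * KdTree

// Coordinate of a point along axis d (0 = x, 1 = y).
let coord d (x: float, y: float) = if d = 0 then x else y

let sqrDist (x1: float, y1: float) (x2: float, y2: float) =
    (x1 - x2) * (x1 - x2) + (y1 - y2) * (y1 - y2)

// Median split on alternating axes (a full sort stands in for Array.median).
let rec build depth (pts: (float * float)[]) =
    if pts.Length = 0 then Empty
    else
        let sorted = Array.sortBy (coord (depth % 2)) pts
        let k = sorted.Length / 2
        let left = build (depth + 1) sorted.[.. k - 1]
        let right = build (depth + 1) sorted.[k + 1 ..]
        Node(depth, sorted.[k], left, right)

// The best (point, squared distance) pair is threaded through as an accumulator.
let rec nearest target node (best, bestSqr) =
    match node with
    | Empty -> (best, bestSqr)
    | Node(depth, p, left, right) ->
        let dist = sqrDist p target
        let acc = if dist < bestSqr then (p, dist) else (best, bestSqr)
        let d = depth % 2
        let delta = coord d target - coord d p
        // Search the side of the splitting plane that contains the target first.
        let near, far = if delta < 0.0 then left, right else right, left
        let acc = nearest target near acc
        // Every point on the far side is at least delta^2 away (squared).
        if delta * delta > snd acc then acc else nearest target far acc

let findNearest target tree =
    match tree with
    | Empty -> invalidArg "tree" "The tree is empty."
    | Node(_, p, _, _) -> fst (nearest target tree (p, sqrDist p target))

let tree = build 0 [| (2.0, 3.0); (5.0, 4.0); (9.0, 6.0); (4.0, 7.0); (8.0, 1.0); (7.0, 2.0) |]
printfn "%A" (findNearest (9.0, 2.0) tree)   // prints (8.0, 1.0)
```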

kaefer

  • Swamp Rat
  • Posts: 572
Re: Kd Tree
« Reply #33 on: April 27, 2013, 07:32:23 AM »
Quote from gile: "Here's the F# version."

Wow, that's an impressive piece of work.

I'm in no way qualified to comment on its inner workings, but what irks me, from a theoretical perspective at least, is that your tree-handling functions aren't tail recursive. In this case they don't need to be, since you're guaranteed to encounter only well-balanced trees, and such a tree could never get deep enough to exceed the stack limit without first exhausting the available memory.
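For a sense of scale, assuming the median split of the posted `create` function: a node holding n points passes ⌊n/2⌋ points to its left child and ⌈n/2⌉ - 1 to its right child, so the height h(n) of the tree follows the recurrence below, and even a billion points give a recursion depth of only 30.

```latex
h(n) = 1 + h\left(\lfloor n/2 \rfloor\right), \qquad h(0) = 0
\quad\Longrightarrow\quad
h(n) = \lfloor \log_2 n \rfloor + 1 = \lceil \log_2 (n+1) \rceil \quad (n \ge 1),
\qquad h(10^9) = 30.
```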

Just in case, here's a link to a classic blog post demonstrating continuation-passing style, and here is an example of the continuation monad:

Code - F#:
```fsharp
// a tree
type 'a T = E | N of 'a T * 'a * 'a T

type ContinuationBuilder() =
    member b.Bind(x, f) = fun k -> x (fun x -> f x k)
    member b.Return x = fun k -> k x

let cont = ContinuationBuilder()

module Tree =
    let foldTree nodeFunc z0 t =
        let rec aux t = cont {
            match t with
            | E -> return z0
            | N(left, z, right) ->
                let! leftRes = aux left
                let! rightRes = aux right
                return nodeFunc leftRes z rightRes }
        aux t id

// example
N(N(E, 2, E), 1, N(E, 3, N(E, 4, E)))
|> Tree.foldTree (fun l x r -> l + x + r) 0 // returns 10
```


gile
Re: Kd Tree
« Reply #34 on: April 27, 2013, 01:59:35 PM »
Hi kaefer,

It seems to me that all the tree-handling functions are tail recursive thanks to the accumulator (decompiling them to C# with ILSpy or Reflector shows a 'while' statement in each of them).

Anyway, thanks for the interesting links (even though all the tests I did with continuations showed a significant loss of performance).

kaefer
Re: Kd Tree
« Reply #35 on: April 27, 2013, 02:17:41 PM »
Quote from gile: "It seems to me that all the tree-handling functions are tail recursive thanks to the accumulator (decompiling them to C# with ILSpy or Reflector shows a 'while' statement in each of them)."

Yes. And no, if I'm not mistaken. Can we agree that they are both? For example:

Code - C#:
```csharp
internal Tuple<FSharpList<TreeNode>, FSharpList<Pair<Point3d, Point3d>>> getConnexions(TreeNode node, double radius, FSharpList<TreeNode> nodes, FSharpList<Pair<Point3d, Point3d>> pairs)
{
    while (node is TreeNode.Node)
    {
        ...
        Tuple<FSharpList<TreeNode>, FSharpList<Pair<Point3d, Point3d>>> connexions = this.getConnexions(item3, radius, nodes2, pairs2);
        ...
```

The call on the right sub-tree gets converted to a while loop, but the one on the left does not, which could potentially blow the stack. Anyway, I think the point is moot when dealing with balanced trees, and the performance cost of continuations would rule out the theoretically sound approach.
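To make the distinction concrete, here is a small sketch on the same kind of tree as above (the names `sumAcc`, `sumStack` and `deepLeft` are hypothetical). `sumAcc` follows the accumulator pattern of the posted functions: the call on the right sub-tree is in tail position, the call on the left is not. `sumStack` instead keeps an explicit stack of pending sub-trees, a common alternative to continuations that was not discussed in this thread, so its only recursive call is a self tail call, which the F# compiler turns into a loop.

```fsharp
type 'a T = E | N of 'a T * 'a * 'a T

// Accumulator style: the right call is a tail call (compiled to a loop),
// the left call is not, so stack use grows with left-leaning paths.
let rec sumAcc t acc =
    match t with
    | E -> acc
    | N(left, x, right) -> sumAcc right (sumAcc left (acc + x))

// Explicit stack of pending sub-trees: constant stack use whatever the tree's shape.
let rec sumStack acc stack =
    match stack with
    | [] -> acc
    | E :: rest -> sumStack acc rest
    | N(left, x, right) :: rest -> sumStack (acc + x) (left :: right :: rest)

// Degenerate tree: n nodes chained through their left sub-trees.
let deepLeft n =
    let mutable t = E
    for _ in 1 .. n do t <- N(t, 1, E)
    t

let sample = N(N(E, 2, E), 1, N(E, 3, N(E, 4, E)))
printfn "%d %d" (sumAcc sample 0) (sumStack 0 [ sample ])   // prints "10 10"
// sumAcc (deepLeft 1000000) 0 would likely overflow the stack; sumStack does not.
```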