TheSwamp

Code Red => .NET => Topic started by: gile on April 01, 2013, 05:08:56 PM

Title: Kd Tree
Post by: gile on April 01, 2013, 05:08:56 PM
Hi,

I'm trying to implement a .NET Kd Tree for AutoCAD Point3d with a NearestNeighbours method.
I tried my class with this (old but successful) challenge (http://www.theswamp.org/index.php?topic=32874).
Using my Point3dTree is not much faster than the quite naive algorithm I posted in reply #31 (http://www.theswamp.org/index.php?topic=32874.msg383652#msg383652) with 10000 points
(although it is 3 times faster with 100000 points).
Is this due to a bad implementation of the Kd Tree or to the fact this kind of data structure is not so good to solve this type of challenge ?

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     public class KdTreeNode<T>
  9.     {
  10.         public T Value { get; internal set; }
  11.         public KdTreeNode<T> LeftChild { get; internal set; }
  12.         public KdTreeNode<T> RightChild { get; internal set; }
  13.         public int Depth { get; internal set; }
  14.         public bool Processed { get; set; }
  15.  
  16.         public KdTreeNode(T value) { this.Value = value; }
  17.     }
  18.  
  19.     public class Point3dTree
  20.     {
  21.         private int dimension;
  22.         public KdTreeNode<Point3d> Root { get; private set; }
  23.  
  24.         public Point3dTree(IEnumerable<Point3d> points) : this(points, false) { }
  25.  
  26.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ)
  27.         {
  28.             if (points == null)
  29.                 throw new ArgumentNullException("points");
  30.             this.dimension = ignoreZ ? 2 : 3;
  31.             Point3d[] pts = points.Distinct().ToArray();
  32.             this.Root = Construct(pts, 0);
  33.         }
  34.  
  35.         public List<Point3d> NearestNeighbours(KdTreeNode<Point3d> node, double radius)
  36.         {
  37.             if (node == null)
  38.                 throw new ArgumentNullException("node");
  39.             List<Point3d> result = new List<Point3d>();
  40.             GetNeighbours(node.Value, radius, this.Root.LeftChild, ref result);
  41.             GetNeighbours(node.Value, radius, this.Root.RightChild, ref result);
  42.             return result;
  43.         }
  44.  
  45.         private void GetNeighbours(Point3d center, double radius, KdTreeNode<Point3d> node, ref List<Point3d> result)
  46.         {
  47.             if (node == null) return;
  48.             Point3d pt = node.Value;
  49.             if (!node.Processed && center.DistanceTo(pt) <= radius)
  50.             {
  51.                 result.Add(pt);
  52.             }
  53.             int d = node.Depth % this.dimension;
  54.             double coordCen = center[d];
  55.             double coordPt = pt[d];
  56.             if (Math.Abs(coordCen - coordPt) > radius)
  57.             {
  58.                 if (coordCen < coordPt)
  59.                 {
  60.                     GetNeighbours(center, radius, node.LeftChild, ref result);
  61.                 }
  62.                 else
  63.                 {
  64.                     GetNeighbours(center, radius, node.RightChild, ref result);
  65.                 }
  66.             }
  67.             else
  68.             {
  69.                 GetNeighbours(center, radius, node.LeftChild, ref result);
  70.                 GetNeighbours(center, radius, node.RightChild, ref result);
  71.             }
  72.         }
  73.  
  74.         private KdTreeNode<Point3d> Construct(Point3d[] points, int depth)
  75.         {
  76.             int length = points.Length;
  77.             if (length == 0) return null;
  78.             int d = depth % this.dimension;
  79.             Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  80.             int med = length / 2;
  81.             KdTreeNode<Point3d> node = new KdTreeNode<Point3d>(points[med]);
  82.             node.Depth = depth;
  83.             node.LeftChild = Construct(points.Take(med).ToArray(), depth + 1);
  84.             node.RightChild = Construct(points.Skip(med + 1).ToArray(), depth + 1);
  85.             return node;
  86.         }
  87.     }
  88. }

I also improved the speed of drawing the lines by using the ObjectId.Open() method instead of a transaction.
Code - C#: [Select]
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using Autodesk.AutoCAD.ApplicationServices;
  4. using Autodesk.AutoCAD.DatabaseServices;
  5. using Autodesk.AutoCAD.EditorInput;
  6. using Autodesk.AutoCAD.Geometry;
  7. using Autodesk.AutoCAD.Runtime;
  8. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  9.  
  10. [assembly: CommandClass(typeof(PointKdTree.CommandMethods))]
  11.  
  12. namespace PointKdTree
  13. {
  14.     public class CommandMethods
  15.     {
  16.         struct Point3dPair
  17.         {
  18.             public readonly Point3d Start;
  19.             public readonly Point3d End;
  20.  
  21.             public Point3dPair(Point3d start, Point3d end)
  22.             {
  23.                 this.Start = start;
  24.                 this.End = end;
  25.             }
  26.         }
  27.  
  28.         [CommandMethod("Connect")]
  29.         public void Connect()
  30.         {
  31.             System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  32.             sw.Start();
  33.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  34.             Database db = doc.Database;
  35.             Editor ed = doc.Editor;
  36.             RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  37.             using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  38.             {
  39.                 var pts = btr.Cast<ObjectId>()
  40.                     .Where(id => id.ObjectClass == rxc)
  41.                     .Select(id =>
  42.                     {
  43.                         using (DBPoint pt = (DBPoint)id.Open(OpenMode.ForRead))
  44.                         { return pt.Position; }
  45.                     });
  46.                 Point3dTree tree = new Point3dTree(pts, true);
  47.                 List<Point3dPair> pairs = new List<Point3dPair>();
  48.                 ConnectPoints(tree, tree.Root, 70.0, ref pairs);
  49.                 foreach (Point3dPair pair in pairs)
  50.                 {
  51.                     using (Line line = new Line(pair.Start, pair.End))
  52.                     {
  53.                         btr.AppendEntity(line);
  54.                     }
  55.                 }
  56.             }
  57.             sw.Stop();
  58.             ed.WriteMessage("\nElapsed milliseconds: {0}", sw.ElapsedMilliseconds);
  59.         }
  60.  
  61.         private void ConnectPoints(Point3dTree tree, KdTreeNode<Point3d> node, double dist, ref List<Point3dPair> pointPairs)
  62.         {
  63.             if (node == null) return;
  64.             node.Processed = true;
  65.             Point3d center = node.Value;
  66.             foreach (Point3d pt in tree.NearestNeighbours(node, dist))
  67.             {
  68.                 pointPairs.Add(new Point3dPair(center, pt));
  69.             }
  70.             ConnectPoints(tree, node.LeftChild, dist, ref pointPairs);
  71.             ConnectPoints(tree, node.RightChild, dist, ref pointPairs);
  72.         }
  73.     }
  74. }
Title: Re: Kd Tree
Post by: TheMaster on April 01, 2013, 05:21:03 PM
Hi Gile. You could try using some execution profiling tools to identify bottlenecks, and I can't tell you much about the overall performance characteristics without some analysis, but just skimming your code, this sticks out:

Code - C#: [Select]
  1.  
  2.     node.LeftChild = Construct(points.Take(med).ToArray(), depth + 1);
  3.     node.RightChild = Construct(points.Skip(med + 1).ToArray(), depth + 1);
  4.  
  5.  

In this case, I would probably not use Linq, because Take/Skip(...).ToArray() can't produce a result in a single allocation, and will have to do it incrementally (starting with 4 elements, and doubling it each time more capacity is needed), which depending on how many elements there are, can be a major performance hit.

I would use Array.Copy() or Array.ConstrainedCopy() instead.

And one more possible optimization:

Code - C#: [Select]
  1.  
  2.     Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  3.  
  4.  

The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.

Title: Re: Kd Tree
Post by: gile on April 01, 2013, 06:07:29 PM
Thanks Tony,

Quote
I would use Array.Copy() or Array.ConstrainedCopy() instead.
That makes sense.

Quote
The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.
I do not agree. At each call 'd' is changing according to the depth in tree and the tree dimension (2d or 3d):
first call: d = 0, the points are sorted by X,
second call: d = 1 => the points are sorted by Y
third call: d = 0 if dimension = 2 (ignore Z) => the points are sorted by X or d = 2 if dimension = 3 => the points are sorted by Z
Title: Re: Kd Tree
Post by: TheMaster on April 01, 2013, 08:03:21 PM
Thanks Tony,

Quote
I would use Array.Copy() or Array.ConstrainedCopy() instead.
That makes sense.

Quote
The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.
I do not agree. At each call 'd' is changing according to the depth in tree and the tree dimension (2d or 3d):
first call: d = 0, the points are sorted by X,
second call: d = 1 => the points are sorted by Y
third call: d = 0 if dimension = 2 (ignore Z) => the points are sorted by X or d = 2 if dimension = 3 => the points are sorted by Z

Hi Gile. Yes, I didn't look close enough, and was thinking 'b-tree', not 'kd-tree' :laugh:

Title: Re: Kd Tree
Post by: pkohut on April 02, 2013, 01:05:34 AM
Hi Gile,

Profiling will probably reveal the bottleneck to be in the Construct method, and the call to sort being a big contributor.
Code: [Select]
Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));
Change it to a lighter weight partial sort, like C++'s std::nth_element (http://stackoverflow.com/questions/2540602/does-c-sharp-have-a-stdnth-element-equivalent)


Here is a code snippet from http://www.theswamp.org/index.php?topic=32874.msg401826#msg401826, which is similar to your Construct method.
Code: [Select]
    void OptimizeKdTree(typename std::vector<NODE_POINT<T, DIM>>::iterator & i1,
        typename std::vector<NODE_POINT<T, DIM>>::iterator & i2,
        const size_t dir)
    {
        if(i1 == i2)
            return;

        // Create instance of the compare function and set with the KDTree direction
        NodeCompare<NODE_POINT<T, DIM>, DIM> compare(dir % DIM);

        // Find the halfway mark between the 2 iterators
        std::vector<NODE_POINT<T, DIM>>::iterator iMid = i1 + (i2 - i1) / 2;

        /*
        * rearrange so elements left of iMid are less than iMid.pos
        * and elements right of iMid are greater than iMid.pos
        */
        std::nth_element(i1, iMid, i2, compare);

        // Insert the iMid into the KDTree
        Insert(iMid->Data(), 0);

        // Continue until no more elements to insert
        if(iMid != i1)
            OptimizeKdTree(i1, iMid, dir + 1);
        if(++iMid != i2)
            OptimizeKdTree(iMid, i2, dir + 1);
    }
Title: Re: Kd Tree
Post by: gile on April 02, 2013, 01:32:40 AM
Thanks Paul.

That makes sense, the issue now is: am I able to build a quick select routine faster than Array.Sort() ?
Title: Re: Kd Tree
Post by: TheMaster on April 02, 2013, 06:35:00 AM
Thanks Paul.

That makes sense, the issue now is: am I able to build a quick select routine faster than Array.Sort() ?

Purely out of curiosity, I tried to parallelize your Construct() method, but only saw a minor improvement of around 20-30%.

The search can also be parallelized, but it requires major surgery (it needs to return results rather than add them to a single list, or you have to use a collection from System.Collections.Concurrent).

Also, you don't need to use 'ref' parameters to pass and use reference types, like List<T>. Use 'ref' only when you're going to actually change the value of the variable in the caller, rather than operate on an instance of a passed reference type. See my note in the code.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.ApplicationServices;
  6. using Autodesk.AutoCAD.DatabaseServices;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using Autodesk.AutoCAD.Geometry;
  9. using Autodesk.AutoCAD.Runtime;
  10. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  11. using System.Diagnostics;
  12.  
  13. namespace PointKdTree
  14. {
  15.    public class KdTreeNode<T>
  16.    {
  17.       public T Value
  18.       {
  19.          get;
  20.          internal set;
  21.       }
  22.       public KdTreeNode<T> LeftChild
  23.       {
  24.          get;
  25.          internal set;
  26.       }
  27.       public KdTreeNode<T> RightChild
  28.       {
  29.          get;
  30.          internal set;
  31.       }
  32.       public int Depth
  33.       {
  34.          get;
  35.          internal set;
  36.       }
  37.       public bool Processed
  38.       {
  39.          get;
  40.          set;
  41.       }
  42.  
  43.       public KdTreeNode( T value )
  44.       {
  45.          this.Value = value;
  46.       }
  47.    }
  48.  
  49.    public class Point3dNode : KdTreeNode<Point3d>
  50.    {
  51.       public Point3dNode( Point3d value ) : base(value)
  52.       {
  53.       }
  54.  
  55.       public double this[int dim]
  56.       {
  57.          get
  58.          {
  59.             return this.Value[this.Depth % dim];
  60.          }
  61.       }
  62.    }
  63.  
  64.    public class Point3dTree
  65.    {
  66.       private int dimension;
  67.       public KdTreeNode<Point3d> Root
  68.       {
  69.          get;
  70.          private set;
  71.       }
  72.  
  73.       public Point3dTree( IEnumerable<Point3d> points ) : this( points, false )
  74.       {
  75.       }
  76.  
  77.       public Point3dTree( IEnumerable<Point3d> points, bool ignoreZ )
  78.       {
  79.          if( points == null )
  80.             throw new ArgumentNullException( "points" );
  81.          this.dimension = ignoreZ ? 2 : 3;
  82.          Point3d[] pts = points.Distinct().ToArray();
  83.          this.Root = Construct( pts, 0 );
  84.       }
  85.  
  86.       public List<Point3d> NearestNeighbours( KdTreeNode<Point3d> node, double radius )
  87.       {
  88.          if( node == null )
  89.             throw new ArgumentNullException( "node" );
  90.          List<Point3d> result = new List<Point3d>();
  91.          GetNeighbours( node.Value, radius, this.Root.LeftChild, result );
  92.          GetNeighbours( node.Value, radius, this.Root.RightChild, result );
  93.          return result;
  94.       }
  95.  
  96.       private void GetNeighbours( Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result )
  97.       {
  98.          if( node == null ) return;
  99.          Point3d pt = node.Value;
  100.          if( !node.Processed && center.DistanceTo( pt ) <= radius )
  101.          {
  102.             result.Add( pt );
  103.          }
  104.          int d = node.Depth % this.dimension;
  105.          double coordCen = center[d];
  106.          double coordPt = pt[d];
  107.          if( Math.Abs( coordCen - coordPt ) > radius )
  108.          {
  109.             if( coordCen < coordPt )
  110.             {
  111.                GetNeighbours( center, radius, node.LeftChild, result );
  112.             }
  113.             else
  114.             {
  115.                GetNeighbours( center, radius, node.RightChild, result );
  116.             }
  117.          }
  118.          else
  119.          {
  120.             GetNeighbours( center, radius, node.LeftChild, result );
  121.             GetNeighbours( center, radius, node.RightChild, result );
  122.          }
  123.       }
  124.  
  125.       internal static bool useParallel = false;
  126.  
  127.       private KdTreeNode<Point3d> Construct( Point3d[] points, int depth )
  128.       {
  129.          int length = points.Length;
  130.          if( length == 0 ) return null;
  131.          int d = depth % this.dimension;
  132.          Array.Sort( points, ( p1, p2 ) => p1[d].CompareTo( p2[d] ) );
  133.          int med = length / 2;
  134.          KdTreeNode<Point3d> node = new KdTreeNode<Point3d>( points[med] );
  135.          node.Depth = depth;
  136.          var left = points.Take( med ).ToArray();
  137.          var right = points.Skip( med + 1 ).ToArray();\
  138.          
  139.          /////////////////////////////////////////////////////////////
  140.          /// Do construction in parallel up to 8 levels (it
  141.          /// makes little sense to parallelize any more than
  142.          /// that, unless you got dozens of CPUs):
  143.          
  144.          if( useParallel && depth < 9 )
  145.          {
  146.             System.Threading.Tasks.Parallel.Invoke(
  147.                () => node.LeftChild = Construct( left, depth + 1 ),
  148.                () => node.RightChild = Construct( right, depth + 1 )
  149.             );
  150.          }
  151.          else
  152.          {
  153.             node.LeftChild = Construct( left, depth + 1 );
  154.             node.RightChild = Construct( right, depth + 1 );
  155.          }
  156.          return node;
  157.       }
  158.  
  159.    }
  160.  
  161.    public class CommandMethods
  162.    {
  163.       struct Point3dPair
  164.       {
  165.          public readonly Point3d Start;
  166.          public readonly Point3d End;
  167.  
  168.          public Point3dPair( Point3d start, Point3d end )
  169.          {
  170.             this.Start = start;
  171.             this.End = end;
  172.          }
  173.       }
  174.  
  175.       /// Toggle parallel execution on/off:
  176.      
  177.       [CommandMethod( "CONNECT_PARALLEL" )]
  178.       public static void ConnectParallelSwitch()
  179.       {
  180.          Point3dTree.useParallel ^= true;
  181.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  182.             "\nConnect using Parallel = {0}", Point3dTree.useParallel );
  183.       }
  184.  
  185.       [CommandMethod( "Connect" )]
  186.       public static void Connect()
  187.       {
  188.          Stopwatch sw = new System.Diagnostics.Stopwatch();
  189.          Stopwatch sw2 = null;
  190.          sw.Start();
  191.          Document doc = AcAp.DocumentManager.MdiActiveDocument;
  192.          Database db = doc.Database;
  193.          Editor ed = doc.Editor;
  194.          if( Point3dTree.useParallel )
  195.             ed.WriteMessage( "\n(Using Parallel Execution)\n" );
  196.          else
  197.             ed.WriteMessage( "\n(Using Sequential Execution)\n" );
  198.          RXClass rxc = RXClass.GetClass( typeof( DBPoint ) );
  199.          using( BlockTableRecord btr = (BlockTableRecord) db.CurrentSpaceId.Open( OpenMode.ForWrite ) )
  200.          {
  201.             var pts = btr.Cast<ObjectId>()
  202.                 .Where( id => id.ObjectClass == rxc )
  203.                 .Select( id =>
  204.                 {
  205.                    using( DBPoint pt = (DBPoint) id.Open( OpenMode.ForRead ) )
  206.                    {
  207.                       return pt.Position;
  208.                    }
  209.                 } );
  210.  
  211.             sw2 = Stopwatch.StartNew();
  212.             Point3dTree tree = new Point3dTree( pts, true );
  213.             List<Point3dPair> pairs = new List<Point3dPair>();
  214.             ConnectPoints( tree, tree.Root, 70.0, pairs );
  215.             sw2.Stop();
  216.             foreach( Point3dPair pair in pairs )
  217.             {
  218.                using( Line line = new Line( pair.Start, pair.End ) )
  219.                {
  220.                   btr.AppendEntity( line );
  221.                }
  222.             }
  223.          }
  224.          sw.Stop();
  225.          ed.WriteMessage( "\nElapsed milliseconds: {0}", sw.ElapsedMilliseconds );
  226.          ed.WriteMessage( "\nNearestNeighbor time: {0}", sw2.ElapsedMilliseconds );
  227.       }
  228.  
  229.  
  230.       /// Don't use 'ref' parameters for reference types, unless
  231.       /// your intention is to modify the value of the variable
  232.       /// in the caller. If you are only accessing the reference
  233.       /// type (e.g., adding items to a List<T>), you don't need
  234.       /// 'ref', which has some performance overhead.
  235.      
  236.       private static void ConnectPoints( Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs )
  237.       {
  238.          if( node == null ) return;
  239.          node.Processed = true;
  240.          Point3d center = node.Value;
  241.          foreach( Point3d pt in tree.NearestNeighbours( node, dist ) )
  242.          {
  243.             pointPairs.Add( new Point3dPair( center, pt ) );
  244.          }
  245.          ConnectPoints( tree, node.LeftChild, dist, pointPairs );
  246.          ConnectPoints( tree, node.RightChild, dist, pointPairs );
  247.       }
  248.    }
  249. }
  250.  
  251.  

Quick test with 50K points:

Code - Text: [Select]
  1.  
  2. Command: CONNECT
  3. (Using Sequential Execution)
  4. Elapsed milliseconds: 875
  5. NearestNeighbor time: 858
  6.  
  7. Command: CONNECT_PARALLEL
  8. Connect using Parallel = True
  9.  
  10. Command: CONNECT
  11. (Using Parallel Execution)
  12. Elapsed milliseconds: 604
  13. NearestNeighbor time: 602
  14.  
  15.  
Title: Re: Kd Tree
Post by: pkohut on April 02, 2013, 10:00:17 AM
Nice speed bump and coding Tony.

FWIW, Array.Sort is O(n log n), while nth_element is O(n).
Title: Re: Kd Tree
Post by: TheMaster on April 02, 2013, 06:29:50 PM
Nice speed bump and coding Tony.

FWIW, Array.Sort is O(n log n), while nth_element is O(n).

Hi Paul, and thanks.

I posted what I thought was a notable improvement to Gile's original code, but as it turned out, it was Linq's deferred execution that was corrupting the timings, so I retracted it.

I tried using the QuickSelect algorithm in place of Array.Sort(), but that didn't seem to help much, and I think that's mainly because the majority of the execution time of Gile's code is consumed by getting the coordinate data from the DBPoints, drawing the Lines, and most-notably, by a call to Enumerable.Distinct() to remove all duplicates. Only a small fraction of it is consumed by building the tree, and I don't know if that's as important as searching it.

Below is Gile's code with a few ideas implemented on how he might improve overall performance. The tweaks include avoiding the use of ToArray() when the quantity is known, avoiding the use of Enumerable.Distinct() (duplicates could be dealt with in the process of building the tree), and storing the entire dataset of points in an array and operating on int[] arrays of indices into that array, rather than directly on the Point3ds themselves.  So, the Construct() method was refactored to pass int[] arrays that contain indices of elements in the Point3d[] array, which is never duplicated.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.ApplicationServices;
  6. using Autodesk.AutoCAD.DatabaseServices;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using Autodesk.AutoCAD.Geometry;
  9. using Autodesk.AutoCAD.Runtime;
  10. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  11. using System.Diagnostics;
  12.  
  13. namespace PointKdTree
  14. {
  15.    public class KdTreeNode<T>
  16.    {
  17.       public T Value
  18.       {
  19.          get;
  20.          internal set;
  21.       }
  22.       public KdTreeNode<T> LeftChild
  23.       {
  24.          get;
  25.          internal set;
  26.       }
  27.       public KdTreeNode<T> RightChild
  28.       {
  29.          get;
  30.          internal set;
  31.       }
  32.       public int Depth
  33.       {
  34.          get;
  35.          internal set;
  36.       }
  37.       public bool Processed
  38.       {
  39.          get;
  40.          set;
  41.       }
  42.  
  43.       public KdTreeNode( T value )
  44.       {
  45.          this.Value = value;
  46.       }
  47.    }
  48.  
  49.    public class Point3dNode : KdTreeNode<Point3d>
  50.    {
  51.       public Point3dNode( Point3d value ) : base(value)
  52.       {
  53.       }
  54.  
  55.       public double this[int dim]
  56.       {
  57.          get
  58.          {
  59.             return this.Value[this.Depth % dim];
  60.          }
  61.       }
  62.    }
  63.  
  64.    public class Point3dTree
  65.    {
  66.       private int dimension;
  67.  
  68.       public KdTreeNode<Point3d> Root
  69.       {
  70.          get;
  71.          private set;
  72.       }
  73.  
  74.       public Point3dTree( IEnumerable<Point3d> points )
  75.          : this( points, false )
  76.       {
  77.       }
  78.  
  79.       /// <summary>
  80.       /// Eliminated call to Enumerable.Distinct() which
  81.       /// was consuming most of the time required here:
  82.       /// </summary>
  83.       public Point3dTree( IEnumerable<Point3d> points, bool ignoreZ )
  84.       {
  85.          if( points == null )
  86.             throw new ArgumentNullException( "points" );
  87.          this.dimension = ignoreZ ? 2 : 3;
  88.          var data = points as Point3d[] ?? points.ToArray(); //
  89.          int[] indices = new int[data.Length];
  90.          for( int i = 0; i < data.Length; i++ )
  91.             indices[i] = i;
  92.          this.Root = Construct( data, indices, 0 );
  93.       }
  94.  
  95.       public List<Point3d> NearestNeighbours( KdTreeNode<Point3d> node, double radius )
  96.       {
  97.          if( node == null )
  98.             throw new ArgumentNullException( "node" );
  99.          List<Point3d> result = new List<Point3d>();
  100.          GetNeighbours( node.Value, radius, this.Root.LeftChild, result );
  101.          GetNeighbours( node.Value, radius, this.Root.RightChild, result );
  102.          return result;
  103.       }
  104.  
  105.       private void GetNeighbours( Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result )
  106.       {
  107.          if( node == null ) return;
  108.          Point3d pt = node.Value;
  109.          if( !node.Processed && center.DistanceTo( pt ) <= radius )
  110.          {
  111.             result.Add( pt );
  112.          }
  113.          int d = node.Depth % this.dimension;
  114.          double coordCen = center[d];
  115.          double coordPt = pt[d];
  116.          if( Math.Abs( coordCen - coordPt ) > radius )
  117.          {
  118.             if( coordCen < coordPt )
  119.             {
  120.                GetNeighbours( center, radius, node.LeftChild, result );
  121.             }
  122.             else
  123.             {
  124.                GetNeighbours( center, radius, node.RightChild, result );
  125.             }
  126.          }
  127.          else
  128.          {
  129.             GetNeighbours( center, radius, node.LeftChild, result );
  130.             GetNeighbours( center, radius, node.RightChild, result );
  131.          }
  132.       }
  133.  
  134.       internal static bool useParallel = false;
  135.  
  136.       /// <summary>
  137.       /// Refactored to pass entire dataset of Point3d in data,
  138.       /// and to manipulate int[] arrays of indices.
  139.       /// </summary>
  140.      
  141.       private KdTreeNode<Point3d> Construct( Point3d[] data, int[] points, int depth )
  142.       {
  143.          int length = points.Length;
  144.          if( length == 0 ) return null;
  145.          int d = depth % this.dimension;
  146.          int l1 = length / 2;
  147.          int l2 = length - l1 - 1;
  148.          var items = new KeyValuePair<double, int>[length];
  149.          for( int i = 0; i < length; i++ )
  150.             items[i] = new KeyValuePair<double, int>( data[points[i]][d], i );
  151.          int pivot = OrderedSelect( items, d );
  152.          var left = new int[l1];
  153.          for( int i = 0; i < l1; i++ )
  154.             left[i] = items[i].Value;
  155.          var right = new int[l2];
  156.          ++l1;
  157.          for( int i = 0; i < l2; i++ )
  158.             right[i] = items[i+l1].Value;
  159.          KdTreeNode<Point3d> node = new KdTreeNode<Point3d>( data[pivot] );
  160.          KdTreeNode<Point3d> leftchild = null;
  161.          KdTreeNode<Point3d> rightchild = null;
  162.          node.Depth = depth;
  163.          if( useParallel && depth < 4 ) // parallelize up to 4 levels deep
  164.          {
  165.             System.Threading.Tasks.Parallel.Invoke(
  166.                () => leftchild = Construct( data, left, depth + 1 ),
  167.                () => rightchild = Construct( data, right, depth + 1 ) );
  168.          }
  169.          else
  170.          {
  171.             leftchild = Construct( data, left, depth + 1 );
  172.             rightchild = Construct( data, right, depth + 1 );
  173.          }
  174.          node.LeftChild = leftchild;
  175.          node.RightChild = rightchild;
  176.          return node;
  177.       }
  178.  
  179.       // Quick select algorithm operating on KeyValuePair<double,int>
  180.       // where the key is the ordinate/dimension and the value is the
  181.       // index of the associated Point3d in the dataset:
  182.  
  183.       public static int OrderedSelect( KeyValuePair<double, int>[] points, int dim )
  184.       {
  185.          int k = points.Length / 2;
  186.          if( points == null || points.Length <= k )
  187.             throw new ArgumentException( "array" );
  188.  
  189.          int from = 0, to = points.Length - 1;
  190.          while( from < to )
  191.          {
  192.             int r = from, w = to;
  193.             double pivot = points[( r + w ) / 2].Key;
  194.             while( r < w )
  195.             {
  196.                if( points[r].Key >= pivot )
  197.                {
  198.                   KeyValuePair<double, int> tmp = points[w];
  199.                   points[w] = points[r];
  200.                   points[r] = tmp;
  201.                   w--;
  202.                }
  203.                else
  204.                {
  205.                   r++;
  206.                }
  207.             }
  208.             if( points[r].Key > pivot )
  209.             {
  210.                r--;
  211.             }
  212.             if( k <= r )
  213.             {
  214.                to = r;
  215.             }
  216.             else
  217.             {
  218.                from = r + 1;
  219.             }
  220.          }
  221.          return points[k].Value;
  222.       }
  223.  
  224.    }
  225.  
  226.    public class CommandMethods
  227.    {
  228.       struct Point3dPair
  229.       {
  230.          public readonly Point3d Start;
  231.          public readonly Point3d End;
  232.  
  233.          public Point3dPair( Point3d start, Point3d end )
  234.          {
  235.             this.Start = start;
  236.             this.End = end;
  237.          }
  238.       }
  239.  
  240.       [CommandMethod( "CONNECT_PARALLEL" )]
  241.       public static void ConnectParallelSwitch()
  242.       {
  243.          Point3dTree.useParallel ^= true;
  244.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  245.             "\nConnect using Parallel = {0}", Point3dTree.useParallel );
  246.       }
  247.  
  248.       [CommandMethod( "Connect" )]
  249.       public static void Connect()
  250.       {
  251.          Stopwatch sw = new System.Diagnostics.Stopwatch();
  252.          Stopwatch sw2 = null;
  253.          Stopwatch sw3 = null;
  254.          long distinctTime = 0L;
  255.          sw.Start();
  256.          Document doc = AcAp.DocumentManager.MdiActiveDocument;
  257.          Database db = doc.Database;
  258.          Editor ed = doc.Editor;
  259.          if( Point3dTree.useParallel )
  260.             ed.WriteMessage( "\n(Using Parallel Execution)\n" );
  261.          else
  262.             ed.WriteMessage( "\n(Using Sequential Execution)\n" );
  263.          RXClass rxc = RXClass.GetClass( typeof( DBPoint ) );
  264.          int cnt = 0;
  265.  
  266.          using( BlockTableRecord btr = (BlockTableRecord) db.CurrentSpaceId.Open( OpenMode.ForWrite ) )
  267.          {
  268.             var dbpoints = btr.Cast<ObjectId>().Where( id => id.ObjectClass == rxc );
  269.             ObjectId[] ids = dbpoints.ToArray();
  270.             int len = ids.Length;
  271.             Point3d[] array = new Point3d[len];
  272.             for( int i = 0; i < len; i++ )
  273.             {
  274.                using( DBPoint pt = (DBPoint) ids[i].Open( OpenMode.ForRead ) )
  275.                {
  276.                   array[i] = pt.Position;
  277.                }
  278.             }
  279.             sw3 = Stopwatch.StartNew();
  280.             Point3dTree tree = new Point3dTree( array, true );
  281.             sw3.Stop();
  282.             sw2 = Stopwatch.StartNew();
  283.             List<Point3dPair> pairs = new List<Point3dPair>( cnt / 2 );
  284.             ConnectPoints( tree, tree.Root, 70.0, pairs );
  285.             sw2.Stop();
  286.             foreach( Point3dPair pair in pairs )
  287.             {
  288.                using( Line line = new Line( pair.Start, pair.End ) )
  289.                {
  290.                   btr.AppendEntity( line );
  291.                }
  292.             }
  293.          }
  294.          sw.Stop();
  295.          ed.WriteMessage( "\nTotal time:                {0,5}", sw.ElapsedMilliseconds );
  296.          ed.WriteMessage( "\nConstruct time:            {0,5}", sw3.ElapsedMilliseconds );
  297.          ed.WriteMessage( "\nNearestNeighbor time:      {0,5}", sw2.ElapsedMilliseconds );
  298.       }
  299.  
  300.       private static void ConnectPoints( Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs )
  301.       {
  302.          if( node == null ) return;
  303.          node.Processed = true;
  304.          Point3d center = node.Value;
  305.          foreach( Point3d pt in tree.NearestNeighbours( node, dist ) )
  306.          {
  307.             pointPairs.Add( new Point3dPair( center, pt ) );
  308.          }
  309.          ConnectPoints( tree, node.LeftChild, dist, pointPairs );
  310.          ConnectPoints( tree, node.RightChild, dist, pointPairs );
  311.       }
  312.    }
  313. }
  314.  
  315.  
Title: Re: Kd Tree
Post by: gile on April 03, 2013, 07:00:58 AM
Hi,

First of all, thanks to Tony and Paul for their interest and advice.

Parallelization
It was part of my goals. I tried to develop (in parallel) the same thing in F# in order to use these features, but I am still encountering some problems.
I tried to parallelize the search part using a ConcurrentBag, it seems to have much overhead so that it slows down significantly the process.

Unnecessary use of ref
Thank you for pointing that. The back and forth between C# and F#, imperative and functional programming, mutable and immutable data have probably helped to make me miss it.

Array.Copy vs Take / Skip
I changed this as Tony advised.

Array.Sort vs QuickSelect
I tried to implement some other QuickSelect and nth-smallest algorithms. Each time Array.Sort () is (much) faster.
Tony, your OrderedSelect() function is not working here.

Distinct ()
I did not notice any significant speed difference with or without.

Here are some results with my current but not latest version (changes made for points 1, 2 and 3, still using Array.Sort() and Distinct()).

100K points
Code: [Select]
Connect using Parallel = False
Get points:      388
Build tree:      718
Connect points:  659
Draw lines:      869
Total:          2634

Connect using Parallel = True
Get points:      374
Build tree:      287
Connect points:  669
Draw lines:      883
Total:          2213
Title: Re: Kd Tree
Post by: TheMaster on April 03, 2013, 10:09:32 AM

Tony, your OrderedSelect() function is not working here.


Hi Gile, yes, there's a bug in the code I posted.

Code - C#: [Select]
  1.  
  2.      for( int i = 0; i < length; i++ )
  3.           items[i] = new KeyValuePair<double, int>( data[points[i]][d], i );
  4.  
  5. // Should be:
  6.  
  7.      for( int i = 0; i < length; i++ )
  8.      {  
  9.           int index = points[i];
  10.           items[i] = new KeyValuePair<double, int>( data[index][d], index );
  11.      }
  12.  
  13.  

I believe, with that it should work, but I haven't seen much of a difference in the result.

FWIW, the C++ kdtree code I've used in the past uses a bubble-sort, which is probably also not optimal.

[/quote]
Title: Re: Kd Tree
Post by: TheMaster on April 03, 2013, 10:22:47 AM
Here are some results with my current but not latest version (changes made for points 1, 2 and 3, still using Array.Sort() and Distinct()).

100K points
Code: [Select]
Connect using Parallel = False
Get points:      388
Build tree:      718
Connect points:  659
Draw lines:      869
Total:          2634

Connect using Parallel = True
Get points:      374
Build tree:      287
Connect points:  669
Draw lines:      883
Total:          2213

Hi Gile.  In your original code, it wasn't possible to break out the time required to build the tree, because the Linq code that opens each DBPoint and gets its position wasn't executing until the constructor of your tree class called Distinct().ToArray() on the IEnumerable<Point3d> argument.

You might notice that I rewrote that part, so that I could measure the tree construction time, without also including the time required to iterate the BlockTableRecord, open  each DBPoint and get its Position.

So, I'm not entirely sure what to make of your numbers because I don't know if you're using the same code from the above post, or code based on your original post, but because of the Linq deferred execution, the times will be very different.
Title: Re: Kd Tree
Post by: gile on April 03, 2013, 11:02:02 AM
Quote
So, I'm not entirely sure what to make of your numbers because I don't know if you're using the same code from the above post, or code based on your original post, but because of the Linq deferred execution, the times will be very different.
I changed the code the same way you did: build a Point3d array (Get points) and then construct the tree (Build tree).
Title: Re: Kd Tree
Post by: TheMaster on April 03, 2013, 12:39:09 PM

Array.Sort vs QuickSelect
I tried to implement some other QuickSelect and nth-smallest algorithms. Each time Array.Sort () is (much) faster. Tony, your OrderedSelect() function is not working here.

Gile, after applying the bug fix I showed in my other reply, it seems to be working, and I wanted to be sure, so I wrote this quick test, and it also confirms what Paul has suggested (much faster than Array.Sort), but of course, since the time to build the tree is a very small fraction of your overall time, it isn't significant to that, but I would still use it rather than Array.Sort anyways.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Autodesk.AutoCAD.Runtime;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using System.Diagnostics;
  9. using Autodesk.AutoCAD.ApplicationServices;
  10.  
  11. namespace System
  12. {
  13.  
  14.    public static class QuickSelectExtensions
  15.    {
  16.       /// <summary>
  17.       /// Quick Select algorithm generic implementation
  18.       ///
  19.       /// Rearranges the given array up to the k'th element
  20.       /// (the median element) so that all values to the left
  21.       /// are < the median element. Returns the median element.
  22.       ///
  23.       /// </summary>
  24.      
  25.       public static T QuickSelectMedian<T>( this T[] items, Comparison<T> compare )
  26.       {
  27.          if( items == null )
  28.             throw new ArgumentException( "items" );
  29.          int k = items.Length / 2;
  30.          int from = 0;
  31.          int to = items.Length - 1;
  32.          while( from < to )
  33.          {
  34.             int r = from;
  35.             int w = to;
  36.             T current = items[( r + w ) / 2];
  37.             while( r < w )
  38.             {
  39.                if( compare( items[r], current ) > -1 )
  40.                {
  41.                   var tmp = items[w];
  42.                   items[w] = items[r];
  43.                   items[r] = tmp;
  44.                   w--;
  45.                }
  46.                else
  47.                {
  48.                   r++;
  49.                }
  50.             }
  51.             if( compare( items[r], current ) > 0 )
  52.             {
  53.                r--;
  54.             }
  55.             if( k <= r )
  56.             {
  57.                to = r;
  58.             }
  59.             else
  60.             {
  61.                from = r + 1;
  62.             }
  63.          }
  64.          return items[k];
  65.       }
  66.    }
  67.  
  68.    public static class QSelectVsQuickSortCommands
  69.    {
  70.       /// <summary>
  71.       ///
  72.       /// Compare QuickSort (Array.Sort), with QuickSelectMedian()
  73.       /// on a random array of double[]:
  74.       ///
  75.       /// </summary>
  76.      
  77.       [CommandMethod( "QSELVSQSORT" )]
  78.       public static void QuickSelectVersesQuickSort()
  79.       {
  80.          Editor ed = Application.DocumentManager.MdiActiveDocument.Editor;
  81.  
  82. #if( DEBUG )
  83.          ed.WriteMessage( "\nTesting should NOT be done in DEBUG builds 8)");
  84.          return;
  85. #endif
  86.          int count = 1000000;
  87.          Random rnd = new Random();
  88.          double[] array = new double[count];
  89.          for( int i = 0; i < count; i++ )
  90.             array[i] = rnd.NextDouble();
  91.          double[] array2 = new double[count];
  92.          Array.Copy( array, array2, count );
  93.          Stopwatch sw = Stopwatch.StartNew();
  94.          Array.Sort( array, ( a, b ) => a.CompareTo( b ) );
  95.          sw.Stop();
  96.          Stopwatch sw2 = Stopwatch.StartNew();
  97.          double med = array2.QuickSelectMedian(
  98.             Comparer<double>.Default.Compare );
  99.          sw2.Stop();
  100.          ed.WriteMessage( "\nArray.Sort():        {0,6}",
  101.             sw.ElapsedMilliseconds );
  102.          ed.WriteMessage( "\nQuickSelectMedian(): {0,6}",
  103.             sw2.ElapsedMilliseconds );
  104.       }
  105.  
  106.       /// <summary>
  107.       ///  Verify if QuickSelectMedian() is working:
  108.       ///  
  109.       ///  After returning, all values in the array
  110.       ///  whose indices are below the median value
  111.       ///  must be less than the median value. The
  112.       ///  median value is the value at [length/2]
  113.       ///  
  114.       /// </summary>
  115.       [CommandMethod( "QSELTEST" )]
  116.       public static void Qselect()
  117.       {
  118.          int count = 25;
  119.          Random rnd = new Random();
  120.          double[] array = new double[count];
  121.          for( int i = 0; i < count; i++ )
  122.             array[i] = 100 * rnd.NextDouble();
  123.          WriteLine( "\narray Before QuickSelectMedian():\n" );
  124.          Dump( array );
  125.          double med = array.QuickSelectMedian(
  126.             Comparer<double>.Default.Compare );
  127.          WriteLine( "\nQuickSelectMedian(): [{0}] = {1}\n",
  128.             array.Length / 2, med );
  129.          WriteLine( "\narray After QuickSelectMedian():\n" );
  130.          Dump( array );
  131.       }
  132.  
  133.       public static void WriteLine( string msg, params object[] args )
  134.       {
  135.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  136.             "\n" + msg, args );
  137.       }
  138.  
  139.       static void Dump( double[] array )
  140.       {
  141.          int cnt = array.Length;
  142.          for( int i = 0; i < cnt; i++ )
  143.          {
  144.             WriteLine( "array[{0}] = {1:f4}", i, array[i] );
  145.          }
  146.       }
  147.    }
  148. }
  149.  
  150.  


Code - Text: [Select]
  1.  
  2. Command: QSELTEST
  3.  
  4. array Before QuickSelectMedian():
  5.  
  6. array[0] = 99.9556
  7. array[1] = 87.1161
  8. array[2] = 65.3429
  9. array[3] = 95.0263
  10. array[4] = 5.5207
  11. array[5] = 38.1086
  12. array[6] = 82.9302
  13. array[7] = 83.4299
  14. array[8] = 34.1567
  15. array[9] = 46.9312
  16. array[10] = 24.4187
  17. array[11] = 55.2761
  18. array[12] = 9.0754
  19. array[13] = 92.8103
  20. array[14] = 35.5625
  21. array[15] = 88.6875
  22. array[16] = 7.8821
  23. array[17] = 10.9855
  24. array[18] = 69.7092
  25. array[19] = 17.6414
  26. array[20] = 82.9813
  27. array[21] = 12.4893
  28. array[22] = 2.8624
  29. array[23] = 21.3571
  30. array[24] = 74.9873
  31.  
  32. QuickSelectMedian(): [12] = 46.9311974229902
  33.  
  34. array After QuickSelectMedian():
  35.  
  36. array[0] = 2.8624
  37. array[1] = 7.8821
  38. array[2] = 5.5207
  39. array[3] = 21.3571
  40. array[4] = 12.4893
  41. array[5] = 17.6414
  42. array[6] = 10.9855
  43. array[7] = 34.1567
  44. array[8] = 9.0754
  45. array[9] = 24.4187
  46. array[10] = 35.5625
  47. array[11] = 38.1086
  48. array[12] = 46.9312
  49. array[13] = 55.2761
  50. array[14] = 65.3429
  51. array[15] = 69.7092
  52. array[16] = 82.9302
  53. array[17] = 82.9813
  54. array[18] = 74.9873
  55. array[19] = 87.1161
  56. array[20] = 83.4299
  57. array[21] = 88.6875
  58. array[22] = 99.9556
  59. array[23] = 95.0263
  60. array[24] = 92.8103
  61.  
  62.  
  63.  
  64.  
  65. Command: QSELVSQSORT
  66.  
  67. Array.Sort():           379
  68. QuickSelectMedian():     47
  69.  
  70. Command:
  71.  
  72.  
Title: Re: Kd Tree
Post by: gile on April 03, 2013, 02:16:00 PM
Thanks again, Tony.
I'll give a look and try to understand why my QuickSelect was so slow.
Title: Re: Kd Tree
Post by: gile on April 03, 2013, 06:05:14 PM
thanks again Tony, I get some improvement using your QuickSelectMedian<T>() method.

Here is where I am:

Benchmark (Release mode for 100K points)
Code: [Select]
Connect using Parallel = False
Get points:      236
Build tree:      203
Connect points:  265
Draw lines:      722
Total:          1426

Connect using Parallel = True
Get points:      244
Build tree:      105
Connect points:  282
Draw lines:      721
Total:          1352

The Point3dTree class.
I added a NearestNeighbour() method to find the nearest point of a given point.
Code - C#: [Select]
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.Geometry;
  6.  
  7. namespace PointKdTree
  8. {
  9.     public class KdTreeNode<T>
  10.     {
  11.         public KdTreeNode(T value) { this.Value = value; }
  12.  
  13.         public T Value { get; internal set; }
  14.         public KdTreeNode<T> LeftChild { get; internal set; }
  15.         public KdTreeNode<T> RightChild { get; internal set; }
  16.         public int Depth { get; internal set; }
  17.         public bool Processed { get; set; }
  18.     }
  19.  
  20.     public class Point3dTree
  21.     {
  22.         private int dimension;
  23.         internal static bool useParallel = false;
  24.         public KdTreeNode<Point3d> Root { get; private set; }
  25.  
  26.         public Point3dTree(IEnumerable<Point3d> points) : this(points, false) { }
  27.  
  28.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ)
  29.         {
  30.             if (points == null)
  31.                 throw new ArgumentNullException("points");
  32.             this.dimension = ignoreZ ? 2 : 3;
  33.             Point3d[] pts = points.Distinct().ToArray();
  34.             this.Root = Construct(pts, 0);
  35.         }
  36.  
  37.         public List<Point3d> NearestNeighbours(KdTreeNode<Point3d> node, double radius)
  38.         {
  39.             if (node == null)
  40.                 throw new ArgumentNullException("node");
  41.             List<Point3d> result = new List<Point3d>();
  42.             NearestNeighbours(node.Value, radius, this.Root.LeftChild, result);
  43.             NearestNeighbours(node.Value, radius, this.Root.RightChild, result);
  44.             return result;
  45.         }
  46.  
  47.         private void NearestNeighbours(Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result)
  48.         {
  49.             if (node == null) return;
  50.             Point3d pt = node.Value;
  51.             if (!node.Processed && center.DistanceTo(pt) <= radius)
  52.             {
  53.                 result.Add(pt);
  54.             }
  55.             int d = node.Depth % this.dimension;
  56.             double coordCen = center[d];
  57.             double coordPt = pt[d];
  58.             if (Math.Abs(coordCen - coordPt) > radius)
  59.             {
  60.                 if (coordCen < coordPt)
  61.                 {
  62.                     NearestNeighbours(center, radius, node.LeftChild, result);
  63.                 }
  64.                 else
  65.                 {
  66.                     NearestNeighbours(center, radius, node.RightChild, result);
  67.                 }
  68.             }
  69.             else
  70.             {
  71.                 NearestNeighbours(center, radius, node.LeftChild, result);
  72.                 NearestNeighbours(center, radius, node.RightChild, result);
  73.             }
  74.         }
  75.  
  76.         public Point3d NearestNeighbour(Point3d location)
  77.         {
  78.             return NearestNeighbour(location, this.Root, this.Root.Value, double.MaxValue);
  79.         }
  80.  
  81.         private Point3d NearestNeighbour(Point3d location, KdTreeNode<Point3d> node, Point3d currentBest, double bestDistance)
  82.         {
  83.             if (node == null)
  84.                 return currentBest;
  85.             int dim = node.Depth % this.dimension;
  86.             Point3d nodeLocation = node.Value;
  87.             double distance = location.DistanceTo(nodeLocation);
  88.             if (distance >= 0.0 && distance < bestDistance)
  89.             {
  90.                 currentBest = nodeLocation;
  91.                 bestDistance = distance;
  92.             }
  93.             bool isLeftNearer = location[dim] < nodeLocation[dim];
  94.             KdTreeNode<Point3d> nearestChild = isLeftNearer ? node.LeftChild : node.RightChild;
  95.             if (nearestChild != null)
  96.             {
  97.                 currentBest = NearestNeighbour(location, nearestChild, currentBest, bestDistance);
  98.                 bestDistance = currentBest.DistanceTo(location);
  99.             }
  100.             if (bestDistance > Math.Abs(location[dim] - nodeLocation[dim]))
  101.             {
  102.                 KdTreeNode<Point3d> farestChild = isLeftNearer ? node.RightChild : node.LeftChild;
  103.                 if (farestChild != null)
  104.                 {
  105.                     currentBest = NearestNeighbour(location, farestChild, currentBest, bestDistance);
  106.                     bestDistance = currentBest.DistanceTo(location);
  107.                 }
  108.             }
  109.             return currentBest;
  110.         }
  111.  
  112.         private KdTreeNode<Point3d> Construct(Point3d[] points, int depth)
  113.         {
  114.             int length = points.Length;
  115.             if (length == 0) return null;
  116.             int d = depth % this.dimension;
  117.             Point3d median = QuickSelectMedian(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  118.             KdTreeNode<Point3d> node = new KdTreeNode<Point3d>(median);
  119.             node.Depth = depth;
  120.             int mid = length / 2;
  121.             int rlen = length - mid - 1;
  122.             Point3d[] left = new Point3d[mid];
  123.             Point3d[] right = new Point3d[rlen];
  124.             Array.Copy(points, 0, left, 0, mid);
  125.             Array.Copy(points, mid + 1, right, 0, rlen);
  126.             if (useParallel && depth < 4)
  127.             {
  128.                 System.Threading.Tasks.Parallel.Invoke(
  129.                    () => node.LeftChild = Construct(left, depth + 1),
  130.                    () => node.RightChild = Construct(right, depth + 1)
  131.                 );
  132.             }
  133.             else
  134.             {
  135.                 node.LeftChild = Construct(left, depth + 1);
  136.                 node.RightChild = Construct(right, depth + 1);
  137.             }
  138.             return node;
  139.         }
  140.  
  141.         // From Tony Tanzillo
  142.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  143.         private T QuickSelectMedian<T>(T[] items, Comparison<T> compare)
  144.         {
  145.             int l = items.Length;
  146.             int k = l / 2;
  147.             if (items == null || items.Length == 0)
  148.                 throw new ArgumentException("array");
  149.             int from = 0;
  150.             int to = l - 1;
  151.             while (from < to)
  152.             {
  153.                 int r = from;
  154.                 int w = to;
  155.                 T current = items[(r + w) / 2];
  156.                 while (r < w)
  157.                 {
  158.                     if (compare(items[r], current) > -1)
  159.                     {
  160.                         var tmp = items[w];
  161.                         items[w] = items[r];
  162.                         items[r] = tmp;
  163.                         w--;
  164.                     }
  165.                     else
  166.                     {
  167.                         r++;
  168.                     }
  169.                 }
  170.                 if (compare(items[r], current) > 0)
  171.                 {
  172.                     r--;
  173.                 }
  174.                 if (k <= r)
  175.                 {
  176.                     to = r;
  177.                 }
  178.                 else
  179.                 {
  180.                     from = r + 1;
  181.                 }
  182.             }
  183.             return items[k];
  184.         }
  185.     }
  186. }
  187.  

The command used for the tests
Code - C#: [Select]
  1. // (C) Copyright 2012 by Gilles Chanteau
  2. //
  3. using System.Collections.Generic;
  4. using System.Collections.Concurrent;
  5. using System.Linq;
  6. using Autodesk.AutoCAD.ApplicationServices;
  7. using Autodesk.AutoCAD.DatabaseServices;
  8. using Autodesk.AutoCAD.EditorInput;
  9. using Autodesk.AutoCAD.Geometry;
  10. using Autodesk.AutoCAD.Runtime;
  11. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  12. using System;
  13.  
  14. [assembly: CommandClass(typeof(PointKdTree.CommandMethods))]
  15.  
  16. namespace PointKdTree
  17. {
  18.     public class CommandMethods
  19.     {
  20.         struct Point3dPair
  21.         {
  22.             public readonly Point3d Start;
  23.             public readonly Point3d End;
  24.  
  25.             public Point3dPair(Point3d start, Point3d end)
  26.             {
  27.                 this.Start = start;
  28.                 this.End = end;
  29.             }
  30.         }
  31.  
  32.         // From Tony Tanzillo
  33.         [CommandMethod("CONNECT_PARALLEL")]
  34.         public static void ConnectParallelSwitch()
  35.         {
  36.             Point3dTree.useParallel ^= true;
  37.             Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  38.                "\nConnect using Parallel = {0}", Point3dTree.useParallel);
  39.         }
  40.  
  41.         [CommandMethod("CONNECT")]
  42.         public void Connect()
  43.         {
  44.             System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  45.             sw.Start();
  46.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  47.             Database db = doc.Database;
  48.             Editor ed = doc.Editor;
  49.             RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  50.             using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  51.             {
  52.                 long t0 = sw.ElapsedMilliseconds;
  53.                 ObjectId[] ids = btr.Cast<ObjectId>().Where(id => id.ObjectClass == rxc).ToArray();
  54.                 int len = ids.Length;
  55.                 Point3d[] pts = new Point3d[len];
  56.                 for (int i = 0; i < len; i++)
  57.                 {
  58.                     using (DBPoint pt = (DBPoint)ids[i].Open(OpenMode.ForRead))
  59.                     {
  60.                         pts[i] = pt.Position;
  61.                     }
  62.                 }
  63.                 long t1 = sw.ElapsedMilliseconds;
  64.                 Point3dTree tree = new Point3dTree(pts, true);
  65.                 long t2 = sw.ElapsedMilliseconds;
  66.                 List<Point3dPair> pairs = new List<Point3dPair>();
  67.                 ConnectPoints(tree, tree.Root, 70.0, pairs);
  68.                 long t3 = sw.ElapsedMilliseconds;
  69.                 foreach (Point3dPair pair in pairs)
  70.                 {
  71.                     using (Line line = new Line(pair.Start, pair.End))
  72.                     {
  73.                         btr.AppendEntity(line);
  74.                     }
  75.                 }
  76.                 sw.Stop();
  77.                 long t4 = sw.ElapsedMilliseconds;
  78.                 ed.WriteMessage("\nConnect using Parallel = {0}", Point3dTree.useParallel);
  79.                 ed.WriteMessage(
  80.                     "\nGet points:{0,9}\nBuild tree:{1,9}\nConnect points:{2,5}\nDraw lines:{3,9}\nTotal:{4,14}",
  81.                     t1 - t0, t2 - t1, t3 - t2, t4 - t3, t4);
  82.             }
  83.         }
  84.  
  85.         private void ConnectPoints(Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs)
  86.         {
  87.             if (node == null) return;
  88.             node.Processed = true;
  89.             Point3d center = node.Value;
  90.             foreach (Point3d pt in tree.NearestNeighbours(node, dist))
  91.             {
  92.                 pointPairs.Add(new Point3dPair(center, pt));
  93.             }
  94.             ConnectPoints(tree, node.LeftChild, dist, pointPairs);
  95.             ConnectPoints(tree, node.RightChild, dist, pointPairs);
  96.         }
  97.     }
  98. }
  99.  
Title: Re: Kd Tree
Post by: TheMaster on April 03, 2013, 07:42:37 PM
Hi Gile - Great job.

Unless you're running on a box with 8 physical CPU cores, change the test to run in parallel only if depth < 3 for a quad-core system. You could also check the number of physical CPU cores at runtime and adjust that (for example, on a 2-core system, use Parallel.Invoke() only if depth < 2).
Title: Re: Kd Tree
Post by: TheMaster on April 07, 2013, 11:25:11 AM
Hi Gile - Great job.

Unless you're running on a box with 8 physical CPU cores, change the test to run in parallel only if depth < 3 for a quad-core system. You could also check the number of physical CPU cores at runtime and adjust that (for example, on a 2-core system, use Parallel.Invoke() only if depth < 2).

As it turns out, getting the number of physical cores isn't trivial.

Code - Text: [Select]
  1.  
  2. /// <summary>
  3. /// Queries the system for the number of physical
  4. /// CPU cores and logical processors. Requires a
  5. /// reference to System.Management.dll
  6. ///
  7. /// Credit: http://www.stev.org/post/2011/10/27/C-Number-Of-Cores-and-Processors.aspx
  8. ///
  9. /// </summary>
  10.  
  11. public static class ThreadSystemInfo
  12. {
  13.     static ThreadSystemInfo()
  14.     {
  15.         try
  16.         {
  17.             ObjectQuery wql = new ObjectQuery( "SELECT * FROM Win32_Processor" );
  18.             using( var searcher = new ManagementObjectSearcher( wql ) )
  19.             using( var results = searcher.Get() )
  20.             {
  21.                 if( results != null )
  22.                 {
  23.                     var items = results.Cast<ManagementObject>();
  24.                     if( items.Any() )
  25.                     {
  26.                         var item = items.First();
  27.                         cores = int.Parse( item["NumberOfCores"].ToString() );
  28.                         logicalProcessors = int.Parse( item["NumberOfLogicalProcessors"].ToString() );
  29.                     }
  30.                 }
  31.             }
  32.         }
  33.         catch // above can fail on some systems, so kludge it.
  34.         {
  35.             logicalProcessors = System.Environment.ProcessorCount;
  36.             if( logicalProcessors > 1 )
  37.                 cores = logicalProcessors / 2;
  38.         }
  39.     }
  40.  
  41.     private static int cores = 1;
  42.     public static int PhysicalCoreCount
  43.     {
  44.         get
  45.         {
  46.             return cores;
  47.         }
  48.     }
  49.  
  50.     private static int logicalProcessors = 1;
  51.     public static int LogicalProcessorCount
  52.     {
  53.         get
  54.         {
  55.             return logicalProcessors;
  56.         }
  57.     }
  58. }
  59.  
  60.  
Title: Re: Kd Tree
Post by: gile on April 08, 2013, 04:19:01 AM
Quote
Hi Gile - Great job.
"With a Little Help from My Friends", thanks.

For now I care less about what is specific to the "connect points challenge" and I am focusing more on implementing a class with more useful methods.
Title: Re: Kd Tree
Post by: TheMaster on April 08, 2013, 07:15:20 AM
Quote
Hi Gile - Great job.
"With a Little Help from My Friends", thanks.

For now I care less about what is specific to the "connect points challenge" and I am focusing more on implementing a class with more useful methods.

In case anyone wants to know more about the concepts of parallelism and its use in recursive solutions to various programming problems, like that which was used in this thread, this is highly-recommended reading:
 
   http://blogs.msdn.com/b/pfxteam/archive/2008/01/31/7357135.aspx


Title: Re: Kd Tree
Post by: Jeff H on April 09, 2013, 05:55:51 PM
Hi Gile,
I am coming down with something and hopefully will not get too sick, but if you like I could send you more info tomorrow or the next day, when I am feeling better, if you want to analyze it.
 
 
Title: Re: Kd Tree
Post by: gile on April 12, 2013, 12:09:18 PM
Thanks Jeff, I hope you feel better.

I added some public methods to the Point3dTree class.

This class may be used to improve performance when running many queries against a fairly large set of points.
It targets the Framework 4.0 to use some parallelization features.

Depending on the value of the 'ignoreZ' constructor argument, either a 2d tree (ignoreZ = true) or a 3d tree (ignoreZ = false, the default) is built.
Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY or if the points have to be considered as projected on the XY plane.

Public methods:
NearestNeighbour(Point3d) Gets the nearest neighbour.
NearestNeighbours(Point3d, int) Gets the n nearest neighbours.
NearestNeighbours(Point3d, double) Gets the nearest neighbours within the distance.
BoxedRange(Point3d, Point3d) Gets the points in a range.
ConnectAll(double) Gets all the pairs of points whose distance is less than or equal to the specified distance.
The last one was mostly used for performance tests (and to reply to this challenge (http://www.theswamp.org/index.php?topic=32874.msg383557#msg383557)).

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     /// <summary>
  9.     /// Node of a Point3dTree
  10.     /// </summary>
  11.     public class TreeNode
  12.     {
  13.         /// <summary>
  14.         /// Gets the 3d point stored in this node.
  15.         /// </summary>
  16.         public Point3d Value { get; internal set; }
  17.  
  18.         /// <summary>
  19.         /// Gets the depth of the node in the tree.
  20.         /// </summary>
  21.         public int Depth { get; internal set; }
  22.  
  23.         /// <summary>
  24.         /// Gets the node this node is a child of.
  25.         /// </summary>
  26.         public TreeNode Parent { get; internal set; }
  27.  
  28.         /// <summary>
  29.         /// Gets the left child node.
  30.         /// </summary>
  31.         public TreeNode LeftChild { get; internal set; }
  32.  
  33.         /// <summary>
  34.         /// Gets the right child node.
  35.         /// </summary>
  36.         public TreeNode RightChild { get; internal set; }
  37.  
  38.         /// <summary>
  39.         /// Initializes a node holding the specified point.
  40.         /// </summary>
  41.         /// <param name="value">The 3d point value of the node.</param>
  42.         public TreeNode(Point3d value) { Value = value; }
  43.     }
  44.  
  45.     /// <summary>
  46.     /// Provides methods to organize 3d points in a Kd tree structure to speed up the search of neighbours.
  47.     /// A boolean constructor parameter (ignoreZ) indicates if the resulting Kd tree is a 3d tree or a 2d tree.
  48.     /// Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY
  49.     /// or if the points have to be considered as projected on the XY plane.
  50.     /// </summary>
  51.     public class Point3dTree
  52.     {
  53.         #region Private fields
  54.  
  55.         private int dimension;
  56.         private int parallelDepth;
  57.         private bool ignoreZ;
  58.         private Point2d center2d;
  59.         private Point3d center;
  60.         private Point3d current;
  61.         private Func<Point3d, double> distanceTo;
  62.  
  63.         #endregion
  64.  
  65.         #region Constructor
  66.  
  67.         /// <summary>
  68.         /// Builds the Kd tree from the distinct points of the specified collection.
  69.         /// </summary>
  70.         /// <param name="points">The Point3d collection to fill the tree.</param>
  71.         /// <param name="ignoreZ">True to build a 2d tree (X and Y only), false to build a 3d tree.</param>
  72.         /// <exception cref="ArgumentNullException">points is null.</exception>
  73.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ = false)
  74.         {
  75.             if (points == null)
  76.                 throw new ArgumentNullException("points");
  77.             this.dimension = ignoreZ ? 2 : 3;
  78.             this.ignoreZ = ignoreZ;
  79.             // Integer part of log2(processor count): Construct builds the levels above this depth in parallel.
  80.             int levels = 0;
  81.             for (int n = System.Environment.ProcessorCount; n > 1; n >>= 1) levels++;
  82.             this.parallelDepth = levels;
  83.             this.Root = Construct(points.Distinct().ToArray(), 0, null);
  84.         }
  85.  
  86.         #endregion
  87.  
  88.         #region Public properties
  89.  
  90.         /// <summary>
  91.         /// Gets the root node of the tree.
  92.         /// </summary>
  93.         public TreeNode Root { get; private set; }
  94.  
  95.         #endregion
  96.  
  97.         #region Public methods
  98.  
  99.         /// <summary>
  100.         /// Gets the nearest neighbour.
  101.         /// </summary>
  102.         /// <param name="point">The point from which search the nearest neighbour.</param>
  103.         /// <returns>The nearest point in the collection from the specified one.</returns>
  104.         /// <exception cref="InvalidOperationException">The tree contains no point.</exception>
  105.         public Point3d NearestNeighbour(Point3d point)
  106.         {
  107.             // An empty input collection leaves Root null: report it rather than failing on this.Root.Value.
  108.             if (this.Root == null)
  109.                 throw new InvalidOperationException("The tree is empty.");
  110.             this.center = point;
  111.             if (this.ignoreZ)
  112.             {
  113.                 this.center2d = point.Flatten();
  114.                 this.distanceTo = Distance2dTo;
  115.             }
  116.             else this.distanceTo = Distance3dTo;
  117.             return GetNeighbour(this.Root, this.Root.Value, double.MaxValue);
  118.         }
  119.  
  120.         /// <summary>
  121.         /// Gets the neighbours within the specified distance.
  122.         /// </summary>
  123.         /// <param name="point">The point from which search the nearest neighbours.</param>
  124.         /// <param name="radius">The distance in which collect the neighbours.</param>
  125.         /// <returns>The points which distance from the specified point is less or equal to the specified distance.</returns>
  126.         public Point3dCollection NearestNeighbours(Point3d point, double radius)
  127.         {
  128.             Point3dCollection points = new Point3dCollection();
  129.             this.center = point;
  130.             if (this.ignoreZ)
  131.             {
  132.                 this.center2d = point.Flatten();
  133.                 this.distanceTo = Distance2dTo;
  134.             }
  135.             else
  136.                 this.distanceTo = Distance3dTo;
  137.             // Search from the root itself rather than from its two children: the root
  138.             // point is a candidate like any other, and a null root (empty tree) then
  139.             // simply yields an empty collection instead of a null dereference.
  140.             GetNeighboursAtDistance(radius, this.Root, points);
  141.             return points;
  142.         }
  143.  
  144.         /// <summary>
  145.         /// Gets the closest points of the tree, up to the requested number.
  146.         /// </summary>
  147.         /// <param name="point">The point from which search the nearest neighbours.</param>
  148.         /// <param name="number">The number of points to collect.</param>
  149.         /// <returns>The n nearest neighbours of the specified point.</returns>
  150.         public Point3dCollection NearestNeighbours(Point3d point, int number)
  151.         {
  152.             var found = new List<Tuple<double, Point3d>>(number);
  153.             this.center = point;
  154.             if (!this.ignoreZ)
  155.             {
  156.                 this.distanceTo = Distance3dTo;
  157.             }
  158.             else
  159.             {
  160.                 this.center2d = point.Flatten();
  161.                 this.distanceTo = Distance2dTo;
  162.             }
  163.             GetNNeighbours(number, double.MaxValue, this.Root, found);
  164.             Point3dCollection result = new Point3dCollection();
  165.             foreach (Tuple<double, Point3d> pair in found)
  166.             {
  167.                 result.Add(pair.Item2);
  168.             }
  169.             return result;
  170.         }
  171.  
  172.         /// <summary>
  173.         /// Gets the points of the tree lying in the box defined by two opposite corners.
  174.         /// </summary>
  175.         /// <param name="pt1">The first corner of range.</param>
  176.         /// <param name="pt2">The opposite corner of the range.</param>
  177.         /// <returns>All points within the box.</returns>
  178.         public Point3dCollection BoxedRange(Point3d pt1, Point3d pt2)
  179.         {
  180.             Point3dCollection found = new Point3dCollection();
  181.             // Rebuild the minimum and maximum corners coordinate by coordinate,
  182.             // so that the two corners can be supplied in any order.
  183.             Point3d min = new Point3d(Math.Min(pt1.X, pt2.X), Math.Min(pt1.Y, pt2.Y), Math.Min(pt1.Z, pt2.Z));
  184.             Point3d max = new Point3d(Math.Max(pt1.X, pt2.X), Math.Max(pt1.Y, pt2.Y), Math.Max(pt1.Z, pt2.Z));
  185.             FindRange(min, max, this.Root, found);
  186.             return found;
  187.         }
  188.  
  189.         /// <summary>
  190.         /// Gets every pair of points of the tree separated by at most the specified distance.
  191.         /// </summary>
  192.         /// <param name="radius">The maximum distance between two points. </param>
  193.         /// <returns>The pairs of points which distance is less or equal than the specified distance.</returns>
  194.         public List<Tuple<Point3d, Point3d>> ConnectAll(double radius)
  195.         {
  196.             var result = new List<Tuple<Point3d, Point3d>>();
  197.             // The stack starts empty; GetConnexions pushes and pops nodes on it while walking the tree.
  198.             GetConnexions(this.Root, radius, result, new Stack<TreeNode>());
  199.             return result;
  200.         }
  201.  
  202.         #endregion
  203.  
  204.         #region Private methods
  205.  
  206.         private TreeNode Construct(Point3d[] points, int depth, TreeNode parent)
  207.         {
  208.             if (points.Length == 0) return null;
  209.             int axis = depth % this.dimension;
  210.             // QuickSelectMedian reorders 'points' in place and returns the item at index Length / 2.
  211.             Point3d pivot = points.QuickSelectMedian((a, b) => a[axis].CompareTo(b[axis]));
  212.             TreeNode node = new TreeNode(pivot) { Depth = depth, Parent = parent };
  213.             int lowerCount = points.Length / 2;
  214.             int upperCount = points.Length - lowerCount - 1;
  215.             Point3d[] lower = new Point3d[lowerCount];
  216.             Point3d[] upper = new Point3d[upperCount];
  217.             Array.Copy(points, 0, lower, 0, lowerCount);
  218.             Array.Copy(points, lowerCount + 1, upper, 0, upperCount);
  219.             int next = depth + 1;
  220.             if (depth >= this.parallelDepth)
  221.             {
  222.                 node.LeftChild = Construct(lower, next, node);
  223.                 node.RightChild = Construct(upper, next, node);
  224.             }
  225.             else
  226.             {
  227.                 // Levels above parallelDepth: build both subtrees concurrently.
  228.                 System.Threading.Tasks.Parallel.Invoke(
  229.                    () => node.LeftChild = Construct(lower, next, node),
  230.                    () => node.RightChild = Construct(upper, next, node)
  231.                 );
  232.             }
  233.             return node;
  234.         }
  235.  
  236.         // Returns the point of the 'node' subtree closest to this.center, or currentBest if none is closer.
  237.         private Point3d GetNeighbour(TreeNode node, Point3d currentBest, double bestDist)
  238.         {
  239.             if (node == null) return currentBest;
  240.             this.current = node.Value;
  241.             int d = node.Depth % this.dimension;
  242.             double coordCen = center[d];
  243.             double coordCur = current[d];
  244.             double dist = this.distanceTo(this.current);
  245.             if (dist < bestDist)
  246.             {
  247.                 currentBest = this.current;
  248.                 bestDist = dist;
  249.             }
  250.             // Visit first the side of the splitting plane which contains the search
  251.             // center, so that bestDist is as small as possible before the other test.
  252.             bool goLeft = coordCen < coordCur;
  253.             currentBest = GetNeighbour(
  254.                 goLeft ? node.LeftChild : node.RightChild, currentBest, bestDist);
  255.             bestDist = this.distanceTo(currentBest);
  256.             // The opposite side can hold a closer point only if the splitting plane is
  257.             // not farther than the best distance found so far; otherwise it is skipped.
  258.             if (Math.Abs(coordCen - coordCur) <= bestDist)
  259.             {
  260.                 currentBest = GetNeighbour(
  261.                     goLeft ? node.RightChild : node.LeftChild, currentBest, bestDist);
  262.             }
  263.             return currentBest;
  264.         }
  265.  
  266.         private void GetNeighboursAtDistance(double radius, TreeNode node, Point3dCollection points)
  267.         {
  268.             if (node == null)
  269.                 return;
  270.             Point3d candidate = node.Value;
  271.             this.current = candidate;
  272.             if (this.distanceTo(candidate) <= radius)
  273.                 points.Add(candidate);
  274.  
  275.             int axis = node.Depth % this.dimension;
  276.             double centerCoord = this.center[axis];
  277.             double nodeCoord = candidate[axis];
  278.             // When the splitting plane of this node is farther than the radius, only the
  279.             // side containing the search center is visited; otherwise both sides are.
  280.             bool planeInReach = !(Math.Abs(centerCoord - nodeCoord) > radius);
  281.             bool centerOnLeft = centerCoord < nodeCoord;
  282.  
  283.             // Left subtree: always when the plane is in reach, else only if the center is on that side.
  284.             if (planeInReach || centerOnLeft)
  285.             {
  286.                 GetNeighboursAtDistance(radius, node.LeftChild, points);
  287.             }
  288.             // Right subtree: always when the plane is in reach,
  289.             // else only if the center is not on the left side.
  290.             if (planeInReach || !centerOnLeft)
  291.             {
  292.                 GetNeighboursAtDistance(radius, node.RightChild, points);
  293.             }
  294.         }
  295.  
  296.         private void GetNNeighbours(int number, double worstDist, TreeNode node, List<Tuple<double, Point3d>> pairs)
  297.         {
  298.             if (node == null)
  299.                 return;
  300.             Point3d candidate = node.Value;
  301.             this.current = candidate;
  302.             double dist = this.distanceTo(candidate);
  303.  
  304.             // The candidate is kept while the list is not full, or when it is closer than
  305.             // the current worst entry (the last one, which is then removed to make room).
  306.             bool isFull = pairs.Count >= number;
  307.             if (!isFull || dist < worstDist)
  308.             {
  309.                 if (isFull)
  310.                     pairs.RemoveAt(number - 1);
  311.                 pairs.Add(new Tuple<double, Point3d>(dist, candidate));
  312.                 pairs.Sort((a, b) => a.Item1.CompareTo(b.Item1));
  313.                 worstDist = pairs[pairs.Count - 1].Item1;
  314.             }
  315.  
  316.             int axis = node.Depth % this.dimension;
  317.             double centerCoord = this.center[axis];
  318.             double nodeCoord = candidate[axis];
  319.             if (Math.Abs(centerCoord - nodeCoord) > worstDist)
  320.             {
  321.                 // The splitting plane is farther than the worst kept distance:
  322.                 // only the side containing the search center is visited.
  323.                 TreeNode near = centerCoord < nodeCoord ? node.LeftChild : node.RightChild;
  324.                 GetNNeighbours(number, worstDist, near, pairs);
  325.             }
  326.             else
  327.             {
  328.                 // Both sides are visited; the right side starts from the worst
  329.                 // distance as it stands once the left side has been processed.
  330.                 GetNNeighbours(number, worstDist, node.LeftChild, pairs);
  331.                 double updatedWorst = pairs[pairs.Count - 1].Item1;
  332.                 GetNNeighbours(number, updatedWorst, node.RightChild, pairs);
  333.             }
  334.         }
  335.  
  336.         private void FindRange(Point3d lowerLeft, Point3d upperRight, TreeNode node, Point3dCollection points)
  337.         {
  338.             if (node == null)
  339.                 return;
  340.             Point3d pt = node.Value;
  341.             this.current = pt;
  342.             // X and Y are always tested; Z only when the tree is a 3d tree.
  343.             bool inside =
  344.                 pt.X >= lowerLeft.X && pt.X <= upperRight.X &&
  345.                 pt.Y >= lowerLeft.Y && pt.Y <= upperRight.Y &&
  346.                 (this.ignoreZ || (pt.Z >= lowerLeft.Z && pt.Z <= upperRight.Z));
  347.             if (inside)
  348.             {
  349.                 points.Add(pt);
  350.             }
  351.  
  352.             // Compare the range with the node coordinate on its splitting axis:
  353.             // a side is skipped when the whole range lies on the other side.
  354.             int axis = node.Depth % this.dimension;
  355.             double split = pt[axis];
  356.             bool skipRight = upperRight[axis] < split;
  357.             bool skipLeft = !skipRight && lowerLeft[axis] > split;
  358.             // Left subtree, unless the range starts above the node coordinate.
  359.             if (!skipLeft)
  360.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  361.             // Right subtree, unless the range ends below the node coordinate.
  362.             if (!skipRight)
  363.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  364.         }
  365.  
  366.         private void GetConnexions(TreeNode node, double radius, List<Tuple<Point3d, Point3d>> connexions, Stack<TreeNode> nodes) // adds to 'connexions' the pairs found for 'node', then recurses into its children
  367.         {
  368.             if (node == null) return;
  369.             Point3dCollection points = new Point3dCollection(); // neighbours found for the current node
  370.             this.center = node.Value; // the current node point becomes the search center
  371.             if (ignoreZ)
  372.             {
  373.                 this.center2d = this.center.Flatten();
  374.                 this.distanceTo = Distance2dTo;
  375.             }
  376.             else
  377.             {
  378.                 this.distanceTo = Distance3dTo;
  379.             }
  380.             foreach (TreeNode tn in nodes) // stacked nodes are right children pushed below (so tn.Parent is set)
  381.             {
  382.                 TreeNode parent = tn.Parent;
  383.                 int d = parent.Depth % this.dimension;
  384.                 if (Math.Abs(this.center[d] - parent.Value[d]) <= radius) // search the stacked subtree only if its parent's splitting plane is within the radius
  385.                 {
  386.                     GetNeighboursAtDistance(radius, tn, points);
  387.                 }
  388.             }
  389.             GetNeighboursAtDistance(radius, node.LeftChild, points); // the node's own subtrees are always searched
  390.             GetNeighboursAtDistance(radius, node.RightChild, points);
  391.             for (int i = 0; i < points.Count; i++)
  392.             {
  393.                 connexions.Add(new Tuple<Point3d, Point3d>(this.center, points[i])); // one pair per neighbour, the current node point first
  394.             }
  395.  
  396.             if (node.RightChild != null)
  397.             {
  398.                 nodes.Push(node.RightChild); // makes the right subtree searchable while the left subtree is processed
  399.             }
  400.             else if (node.LeftChild == null && nodes.Count > 0) // leaf node
  401.             {
  402.                 nodes.Pop(); // NOTE(review): presumably removes the subtree processed next so it is not searched against itself - confirm
  403.             }
  404.             GetConnexions(node.LeftChild, radius, connexions, nodes);
  405.             GetConnexions(node.RightChild, radius, connexions, nodes);
  406.         }
  407.  
  408.         private double Distance2dTo(Point3d pt) // distance from center2d to 'pt', Z ignored
  409.         {
  410.             return this.center2d.GetDistanceTo(new Point2d(pt.X, pt.Y));
  411.         }
  412.  
  413.         private double Distance3dTo(Point3d pt) // 3d distance between 'pt' and the current search center
  414.         {
  415.             return pt.DistanceTo(this.center);
  416.         }
  417.  
  418.         #endregion
  419.     }
  420.  
  421.     static class Extensions
  422.     {
  423.         // Returns the 2d point made of the X and Y coordinates of the 3d point.
  424.         public static Point2d Flatten(this Point3d pt)
  425.         {
  426.             return new Point2d(pt.X, pt.Y);
  427.         }
  428.  
  429.         // Credit: Tony Tanzillo
  430.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  431.         // Returns the item found at index Length / 2 once the array has been partially
  432.         // reordered in place around that index, using 'compare' to order the items.
  433.         // Throws ArgumentNullException if items is null, ArgumentException if items is empty.
  434.         public static T QuickSelectMedian<T>(this T[] items, Comparison<T> compare)
  435.         {
  436.             // Validate before reading items.Length, so that a null array is reported
  437.             // as an argument error instead of a NullReferenceException.
  438.             if (items == null) throw new ArgumentNullException("items");
  439.             if (items.Length == 0) throw new ArgumentException("The array is empty.", "items");
  440.             int k = items.Length / 2;
  441.             int from = 0;
  442.             int to = items.Length - 1;
  443.             while (from < to)
  444.             {
  445.                 int r = from;
  446.                 int w = to;
  447.                 T current = items[(r + w) / 2]; // pivot value for this pass
  448.                 while (r < w)
  449.                 {
  450.                     if (compare(items[r], current) > -1)
  451.                     {
  452.                         var tmp = items[w];
  453.                         items[w] = items[r];
  454.                         items[r] = tmp;
  455.                         w--;
  456.                     }
  457.                     else
  458.                     {
  459.                         r++;
  460.                     }
  461.                 }
  462.                 if (compare(items[r], current) > 0)
  463.                 {
  464.                     r--;
  465.                 }
  466.                 // Continue only with the part of the range which contains index k.
  467.                 if (k <= r) to = r;
  468.                 else from = r + 1;
  469.             }
  470.             return items[k];
  471.         }
  472.     }
  473. }
  474.  
Title: Re: Kd Tree
Post by: TheMaster on April 12, 2013, 03:09:57 PM
The last one was was almost used for performance tests (and to reply to this challenge (http://www.theswamp.org/index.php?topic=32874.msg383557#msg383557)).

I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?

I also noticed in one post, that someone is using (getvar "DATE") to measure times - which is not very accurate (for sub-second times).
Title: Re: Kd Tree
Post by: gile on April 12, 2013, 04:12:46 PM
I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?
I don't think so, this thread is quite old and all the C/C++ part was/is over my head. I remember there were impressive LISP results from Evgeniy.
Title: Re: Kd Tree
Post by: gile on April 13, 2013, 05:32:02 PM
I changed the way duplicated lines are avoided in ConnectAll. Rather than a boolean 'Processed' property in the TreeNode, I now use a Stack<TreeNode> containing the right children of the parents of the current node, in which to search (or not, if they are far enough) for neighbours. This slightly improves the speed of the 'connect points' part.
The tree is now double linked.

Some results:
Code - Text: [Select]
  1. 100K points (128 026 lines)
  2. Get points:      220
  3. Build tree:      114
  4. Connect points:  199
  5. Draw lines:      764
  6. Total:          1297
  7.  
  8. 1 Million points (1 280 741 lines)
  9. Get points:     2118
  10. Build tree:     1388
  11. Connect points: 2107
  12. Draw lines:    12046
  13. Total:         17659

The testing command:
Code - C#: [Select]
  1.         [CommandMethod("CONNECT")]
  2.         public void Connect() // test command: reads the DBPoint entities of the current space, builds the tree, draws a line per pair and prints timings
  3.         {
  4.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  5.             Database db = doc.Database;
  6.             Editor ed = doc.Editor;
  7.             try
  8.             {
  9.                 System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  10.                 sw.Start();
  11.                 RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  12.                 using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  13.                 {
  14.                     long t0 = sw.ElapsedMilliseconds; // start of the "get points" step
  15.                     ObjectId[] ids = btr.Cast<ObjectId>().Where(id => id.ObjectClass == rxc).ToArray(); // keep only the ids whose class is DBPoint
  16.                     int len = ids.Length;
  17.                     Point3d[] pts = new Point3d[len];
  18.                     for (int i = 0; i < len; i++)
  19.                     {
  20.                         using (DBPoint pt = (DBPoint)ids[i].Open(OpenMode.ForRead))
  21.                         {
  22.                             pts[i] = pt.Position;
  23.                         }
  24.                     }
  25.                     long t1 = sw.ElapsedMilliseconds; // points collected
  26.                     Point3dTree tree = new Point3dTree(pts, true); // <- true = 2d tree, false = 3d tree
  27.                     long t2 = sw.ElapsedMilliseconds; // tree built
  28.                     List<Tuple<Point3d, Point3d>> pairs = tree.ConnectAll(70.0); // pairs of points at most 70.0 units apart
  29.                     long t3 = sw.ElapsedMilliseconds; // pairs computed
  30.                     foreach (var pair in pairs)
  31.                     {
  32.                         using (Line line = new Line(pair.Item1, pair.Item2))
  33.                         {
  34.                             btr.AppendEntity(line);
  35.                         }
  36.                     }
  37.                     db.TransactionManager.QueueForGraphicsFlush();
  38.                     sw.Stop();
  39.                     long t4 = sw.ElapsedMilliseconds; // lines drawn; also reported as the total
  40.                     ed.WriteMessage(
  41.                         "\nGet points:{0,9}\nBuild tree:{1,9}\nConnect points:{2,5}\nDraw lines:{3,9}\nTotal:{4,14}",
  42.                         t1 - t0, t2 - t1, t3 - t2, t4 - t3, t4);
  43.                 }
  44.             }
  45.             catch (System.Exception ex) // any failure is reported on the command line with its stack trace
  46.             {
  47.                 ed.WriteMessage("\n{0}\n{1}", ex.Message, ex.StackTrace);
  48.             }
  49.         }
Title: Re: Kd Tree
Post by: TheMaster on April 14, 2013, 01:25:23 PM
I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?
I don't think so, this thread is quite old and all the C/C++ part was/is other my head. I remember there was impressive LISP results from Evgeniy.

I would expect that once a dataset is indexed, locating points within a given distance should not be too difficult in any language, since the whole point to indexing the data is to minimize the amount of work needed to find nearest coordinates.

LISP is well-suited to tree-based, recursive algorithms, so it wouldn't surprise me that it can find points in an indexed dataset quickly. But I would be surprised if he was able to construct the tree as fast as it can be done in .NET or native code. I very much doubt that, if for no other reason than the inherent overhead of dealing with immutable lists.
Title: Re: Kd Tree
Post by: gile on April 14, 2013, 01:52:17 PM
Hi Tony,

It seems to me he didn't use a tree structure but a 'divide and conquer' algorithm as shown by this picture (http://divide and conqueer).
For my part, I tried to implement a more reusable (and maybe useful) structure.

I found another way to avoid duplicated connexions with the ConnectAll() method. Rather than storing nodes in a stack, I can get the 'right parents nodes' of the current node using a new property in the TreeNode which indicates whether the node is a left child node or not. This seems to slightly improve ConnectAll() speed.

Here's the new implementation.

New version:
- Replaced the use of global fields with function parameters so that the public methods can be used in parallel execution.
- Use of squared distances

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     /// <summary>
  9.     /// Node of a Point3dTree
  10.     /// </summary>
  11.     public class TreeNode
  12.     {
  13.         /// <summary>
  14.         /// Gets the 3d point stored in this node.
  15.         /// </summary>
  16.         public Point3d Value { get; internal set; }
  17.  
  18.         /// <summary>
  19.         /// Gets the depth of the node in the tree.
  20.         /// </summary>
  21.         public int Depth { get; internal set; }
  22.  
  23.         /// <summary>
  24.         /// Gets a value indicating whether this node is the LeftChild of its parent.
  25.         /// </summary>
  26.         public bool IsLeft { get; internal set; }
  27.  
  28.         /// <summary>
  29.         /// Gets the node this node is a child of.
  30.         /// </summary>
  31.         public TreeNode Parent { get; internal set; }
  32.  
  33.         /// <summary>
  34.         /// Gets the left child node.
  35.         /// </summary>
  36.         public TreeNode LeftChild { get; internal set; }
  37.  
  38.         /// <summary>
  39.         /// Gets the right child node.
  40.         /// </summary>
  41.         public TreeNode RightChild { get; internal set; }
  42.  
  43.         /// <summary>
  44.         /// Initializes a node holding the specified point.
  45.         /// </summary>
  46.         /// <param name="value">The 3d point value of the node.</param>
  47.         public TreeNode(Point3d value) { Value = value; }
  48.     }
  49.  
  50.     /// <summary>
  51.     /// Provides methods to organize 3d points in a Kd tree structure to speed up the search of neighbours.
  52.     /// A boolean constructor parameter (ignoreZ) indicates if the resulting Kd tree is a 3d tree or a 2d tree.
  53.     /// Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY
  54.     /// or if the points have to be considered as projected on the XY plane.
  55.     /// </summary>
  56.     public class Point3dTree
  57.     {
  58.         #region Private fields
  59.  
  60.         private int dimension;
  61.         private int parallelDepth;
  62.         private bool ignoreZ;
  63.         private Func<Point3d, Point3d, double> sqrDist;
  64.  
  65.         #endregion
  66.  
  67.         #region Constructor
  68.  
  69.         /// <summary>
  70.         /// Builds the Kd tree from the distinct points of the specified collection.
  71.         /// </summary>
  72.         /// <param name="points">The Point3d collection to fill the tree.</param>
  73.         /// <param name="ignoreZ">True to build a 2d tree (X and Y only), false to build a 3d tree.</param>
  74.         /// <exception cref="ArgumentNullException">points is null.</exception>
  75.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ = false)
  76.         {
  77.             if (points == null)
  78.                 throw new ArgumentNullException("points");
  79.             this.dimension = ignoreZ ? 2 : 3;
  80.             this.ignoreZ = ignoreZ;
  81.             // The squared-distance function follows the tree dimension.
  82.             if (!ignoreZ)
  83.                 this.sqrDist = SqrDistance3d;
  84.             else
  85.                 this.sqrDist = SqrDistance2d;
  86.             int levels = 0; // integer part of log2(processor count)
  87.             for (int n = System.Environment.ProcessorCount; n > 1; n >>= 1) levels++;
  88.             this.parallelDepth = levels;
  89.             this.Root = Create(points.Distinct().ToArray(), 0, null, false);
  90.         }
  91.  
  92.         #endregion
  93.  
  94.         #region Public properties
  95.  
  96.         /// <summary>
  97.         /// Gets the root node of the tree.
  98.         /// </summary>
  99.         public TreeNode Root { get; private set; }
  100.  
  101.         #endregion
  102.  
  103.         #region Public methods
  104.  
  105.         /// <summary>
  106.         /// Gets the nearest neighbour.
  107.         /// </summary>
  108.         /// <param name="point">The point from which search the nearest neighbour.</param>
  109.         /// <returns>The nearest point in the collection from the specified one.</returns>
  110.         public Point3d NearestNeighbour(Point3d point)
  111.         {   // Seeds the recursive search with the root's point and an unbounded best (squared) distance.
  112.             return GetNeighbour(point, this.Root, this.Root.Value, double.MaxValue); // NOTE: Root is null for an empty tree, so this throws NullReferenceException in that case
  113.         }
  114.  
  115.         /// <summary>
  116.         /// Gets the neighbours within the specified distance.
  117.         /// </summary>
  118.         /// <param name="point">The point from which search the nearest neighbours.</param>
  119.         /// <param name="radius">The distance in which collect the neighbours.</param>
  120.         /// <returns>The points which distance from the specified point is less or equal to the specified distance.</returns>
  121.         public Point3dCollection NearestNeighbours(Point3d point, double radius)
  122.         {
  123.             Point3dCollection found = new Point3dCollection();
  124.             GetNeighboursAtDistance(point, radius * radius, Root, found); // squared radius: the search compares squared distances
  125.             return found;
  126.         }
  127.  
  128.         /// <summary>
  129.         /// Gets the given number of nearest neighbours.
  130.         /// </summary>
  131.         /// <param name="point">The point from which search the nearest neighbours.</param>
  132.         /// <param name="number">The number of points to collect.</param>
  133.         /// <returns>The n nearest neighbours of the specified point.</returns>
  134.         public Point3dCollection NearestNeighbours(Point3d point, int number)
  135.         {
  136.             List<Tuple<double, Point3d>> candidates = new List<Tuple<double, Point3d>>(number); // (squared distance, point) pairs
  137.             GetKNeighbours(point, number, Root, candidates);
  138.             Point3dCollection result = new Point3dCollection();
  139.             foreach (Tuple<double, Point3d> candidate in candidates)
  140.             {
  141.                 result.Add(candidate.Item2); // keep the point, drop its distance
  142.             }
  143.             return result;
  144.         }
  145.  
  146.         /// <summary>
  147.         /// Gets the points in a range.
  148.         /// </summary>
  149.         /// <param name="pt1">The first corner of range.</param>
  150.         /// <param name="pt2">The opposite corner of the range.</param>
  151.         /// <returns>All points within the box.</returns>
  152.         public Point3dCollection BoxedRange(Point3d pt1, Point3d pt2)
  153.         {
  154.             // Normalise the corners: 'min' receives the smallest and 'max' the largest value on every axis.
  155.             Point3d min = new Point3d(Math.Min(pt1.X, pt2.X), Math.Min(pt1.Y, pt2.Y), Math.Min(pt1.Z, pt2.Z));
  156.             Point3d max = new Point3d(Math.Max(pt1.X, pt2.X), Math.Max(pt1.Y, pt2.Y), Math.Max(pt1.Z, pt2.Z));
  157.             Point3dCollection inside = new Point3dCollection();
  158.             FindRange(min, max, Root, inside);
  159.             // When the tree ignores Z, FindRange only tests the X and Y bounds.
  160.             return inside;
  161.         }
  162.  
  163.         /// <summary>
  164.         /// Gets all the pairs of points whose distance is less than or equal to the specified distance.
  165.         /// </summary>
  166.         /// <param name="radius">The maximum distance between two points. </param>
  167.         /// <returns>The pairs of points whose distance is less than or equal to the specified distance.</returns>
  168.         public List<Tuple<Point3d, Point3d>> ConnectAll(double radius)
  169.         {
  170.             List<Tuple<Point3d, Point3d>> pairs = new List<Tuple<Point3d, Point3d>>();
  171.             GetConnexions(Root, radius * radius, pairs); // squared radius: distances are compared squared
  172.             return pairs;
  173.         }
  174.  
  175.         #endregion
  176.  
  177.         #region Private methods
  178.  
  179.         private TreeNode Create(Point3d[] points, int depth, TreeNode parent, bool isLeft)
  180.         {   // Recursively builds the subtree holding 'points'; an empty array yields null (no node).
  181.             int length = points.Length;
  182.             if (length == 0) return null;
  183.             int d = depth % this.dimension; // splitting axis index, cycling through 0..dimension-1 as depth grows
  184.             Point3d median = points.QuickSelectMedian((p1, p2) => p1[d].CompareTo(p2[d])); // also reorders 'points' in place around index length / 2
  185.             TreeNode node = new TreeNode(median);
  186.             node.Depth = depth;
  187.             node.Parent = parent;
  188.             node.IsLeft = isLeft;
  189.             int mid = length / 2; // same index QuickSelectMedian took the median from
  190.             int rlen = length - mid - 1;
  191.             Point3d[] left = new Point3d[mid];
  192.             Point3d[] right = new Point3d[rlen];
  193.             Array.Copy(points, 0, left, 0, mid); // elements before the median go to the left subtree
  194.             Array.Copy(points, mid + 1, right, 0, rlen); // elements after the median go to the right subtree
  195.             if (depth < this.parallelDepth) // shallow levels: build both subtrees concurrently
  196.             {
  197.                 System.Threading.Tasks.Parallel.Invoke(
  198.                    () => node.LeftChild = Create(left, depth + 1, node, true),
  199.                    () => node.RightChild = Create(right, depth + 1, node, false)
  200.                 );
  201.             }
  202.             else // deeper levels: plain sequential recursion
  203.             {
  204.                 node.LeftChild = Create(left, depth + 1, node, true);
  205.                 node.RightChild = Create(right, depth + 1, node, false);
  206.             }
  207.             return node;
  208.         }
  209.  
  210.         private Point3d GetNeighbour(Point3d center, TreeNode node, Point3d currentBest, double bestDist)
  211.         {   // Returns the closest point found so far; bestDist is the squared distance from 'center' to currentBest.
  212.             if (node == null)
  213.                 return currentBest;
  214.             Point3d current = node.Value;
  215.             int d = node.Depth % this.dimension; // axis this node splits on
  216.             double coordCen = center[d];
  217.             double coordCur = current[d];
  218.             double dist = this.sqrDist(center, current);
  219.             if (dist >= 0.0 && dist < bestDist) // this node is closer than the current candidate
  220.             {
  221.                 currentBest = current;
  222.                 bestDist = dist;
  223.             }
  224.             dist = coordCen - coordCur; // signed offset of the search point from the splitting plane
  225.             if (bestDist < dist * dist) // plane is farther than the best hit: only the side holding 'center' can improve it
  226.             {
  227.                 currentBest = GetNeighbour(
  228.                     center, coordCen < coordCur ? node.LeftChild : node.RightChild, currentBest, bestDist);
  229.                 bestDist = this.sqrDist(center, currentBest);
  230.             }
  231.             else // otherwise both subtrees are searched, refreshing bestDist after each
  232.             {
  233.                 currentBest = GetNeighbour(center, node.LeftChild, currentBest, bestDist);
  234.                 bestDist = this.sqrDist(center, currentBest);
  235.                 currentBest = GetNeighbour(center, node.RightChild, currentBest, bestDist);
  236.                 bestDist = this.sqrDist(center, currentBest);
  237.             }
  238.             return currentBest;
  239.         }
  240.  
  241.         private void GetNeighboursAtDistance(Point3d center, double radius, TreeNode node, Point3dCollection points)
  242.         {   // NOTE: 'radius' is a squared distance here; the visible callers pass radius * radius.
  243.             if (node == null) return;
  244.             Point3d current = node.Value;
  245.             double dist = this.sqrDist(center, current);
  246.             if (dist <= radius) // inclusive: points exactly at the radius are collected
  247.             {
  248.                 points.Add(current);
  249.             }
  250.             int d = node.Depth % this.dimension; // axis this node splits on
  251.             double coordCen = center[d];
  252.             double coordCur = current[d];
  253.             dist = coordCen - coordCur; // signed offset of the search point from the splitting plane
  254.             if (dist * dist > radius) // the search ball does not reach the plane: one side is enough
  255.             {
  256.                 if (coordCen < coordCur)
  257.                 {
  258.                     GetNeighboursAtDistance(center, radius, node.LeftChild, points);
  259.                 }
  260.                 else
  261.                 {
  262.                     GetNeighboursAtDistance(center, radius, node.RightChild, points);
  263.                 }
  264.             }
  265.             else // the ball reaches the plane: both subtrees may contain hits
  266.             {
  267.                 GetNeighboursAtDistance(center, radius, node.LeftChild, points);
  268.                 GetNeighboursAtDistance(center, radius, node.RightChild, points);
  269.             }
  270.         }
  271.  
  272.         private void GetKNeighbours(Point3d center, int number, TreeNode node, List<Tuple<double, Point3d>> pairs)
  273.         {   // 'pairs' holds (squared distance, point) candidates; pairs[0] is always the farthest one kept.
  274.             if (node == null || number < 1) return; // asking for fewer than one neighbour must collect nothing
  275.             Point3d current = node.Value;
  276.             double dist = this.sqrDist(center, current);
  277.             int cnt = pairs.Count;
  278.             if (cnt == 0) // first candidate
  279.             {
  280.                 pairs.Add(new Tuple<double, Point3d>(dist, current));
  281.             }
  282.             else if (cnt < number) // list not full yet: keep the farthest candidate at index 0
  283.             {
  284.                 if (dist > pairs[0].Item1)
  285.                 {
  286.                     pairs.Insert(0, new Tuple<double, Point3d>(dist, current));
  287.                 }
  288.                 else
  289.                 {
  290.                     pairs.Add(new Tuple<double, Point3d>(dist, current));
  291.                 }
  292.             }
  293.             else if (dist < pairs[0].Item1) // list full: replace the farthest candidate, then restore descending order
  294.             {
  295.                 pairs[0] = new Tuple<double, Point3d>(dist, current);
  296.                 pairs.Sort((p1, p2) => p2.Item1.CompareTo(p1.Item1));
  297.             }
  298.             int d = node.Depth % this.dimension; // axis this node splits on
  299.             double coordCen = center[d];
  300.             double coordCur = current[d];
  301.             dist = coordCen - coordCur; // signed offset of the search point from the splitting plane
  302.             if (dist * dist > pairs[0].Item1) // cannot fire while the list is filling: pairs[0] >= distance to 'current' >= plane offset
  303.             {
  304.                 if (coordCen < coordCur)
  305.                 {
  306.                     GetKNeighbours(center, number, node.LeftChild, pairs);
  307.                 }
  308.                 else
  309.                 {
  310.                     GetKNeighbours(center, number, node.RightChild, pairs);
  311.                 }
  312.             }
  313.             else // the far side may still hold a closer point: search both subtrees
  314.             {
  315.                 GetKNeighbours(center, number, node.LeftChild, pairs);
  316.                 GetKNeighbours(center, number, node.RightChild, pairs);
  317.             }
  318.         }
  319.  
  320.         private void FindRange(Point3d lowerLeft, Point3d upperRight, TreeNode node, Point3dCollection points)
  321.         {   // Collects every point of the subtree lying inside the axis-aligned box (bounds inclusive).
  322.             if (node == null)
  323.                 return;
  324.             Point3d current = node.Value;
  325.             if (ignoreZ) // tree built with ignoreZ: only X and Y are tested
  326.             {
  327.                 if (current.X >= lowerLeft.X && current.X <= upperRight.X &&
  328.                     current.Y >= lowerLeft.Y && current.Y <= upperRight.Y)
  329.                     points.Add(current);
  330.             }
  331.             else
  332.             {
  333.                 if (current.X >= lowerLeft.X && current.X <= upperRight.X &&
  334.                     current.Y >= lowerLeft.Y && current.Y <= upperRight.Y &&
  335.                     current.Z >= lowerLeft.Z && current.Z <= upperRight.Z)
  336.                     points.Add(current);
  337.             }
  338.             int d = node.Depth % this.dimension; // axis this node splits on
  339.             if (upperRight[d] < current[d]) // box lies entirely below the splitting value: left subtree only
  340.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  341.             else if (lowerLeft[d] > current[d]) // box lies entirely above it: right subtree only
  342.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  343.             else // the box spans the splitting value: search both sides
  344.             {
  345.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  346.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  347.             }
  348.         }
  349.  
  350.         private void GetConnexions(TreeNode node, double radius, List<Tuple<Point3d, Point3d>> connexions)
  351.         {   // Pairs 'node' with every neighbour not already paired with it; 'radius' is a squared distance.
  352.             if (node == null) return;
  353.             Point3dCollection points = new Point3dCollection();
  354.             Point3d center = node.Value;
  355.             // Always scan the right-hand subtrees of the ancestors reached through a left branch (2D and 3D trees alike).
  356.             GetRightParentsNeighbours(center, node, radius, points);
  357.             GetNeighboursAtDistance(center, radius, node.LeftChild, points); // neighbours below this node, left side
  358.             GetNeighboursAtDistance(center, radius, node.RightChild, points); // neighbours below this node, right side
  359.             for (int i = 0; i < points.Count; i++)
  360.             {
  361.                 connexions.Add(new Tuple<Point3d, Point3d>(center, points[i]));
  362.             }
  363.             GetConnexions(node.LeftChild, radius, connexions); // then repeat for every descendant
  364.             GetConnexions(node.RightChild, radius, connexions);
  365.         }
  366.  
  367.         private void GetRightParentsNeighbours(Point3d center, TreeNode node, double radius, Point3dCollection points)
  368.         {
  369.             // Visit, bottom-up, every ancestor that is entered through a left branch on the way to the root.
  370.             for (TreeNode parent = GetRightParent(node); parent != null; parent = GetRightParent(parent))
  371.             {
  372.                 double offset = center[parent.Depth % this.dimension] - parent.Value[parent.Depth % this.dimension];
  373.                 if (offset * offset <= radius) // 'radius' is squared; the ancestor's right subtree is within reach
  374.                 {
  375.                     GetNeighboursAtDistance(center, radius, parent.RightChild, points);
  376.                 }
  377.             }
  378.         }
  379.  
  380.         private TreeNode GetRightParent(TreeNode node)
  381.         {
  382.             // Climb for as long as the current node has a parent and is not a left child.
  383.             while (node.Parent != null && !node.IsLeft) node = node.Parent;
  384.             // Stopped on a left child: its parent is the answer. Stopped on the root: Parent is null.
  385.             return node.Parent;
  386.         }
  387.  
  388.         private double SqrDistance2d(Point3d p1, Point3d p2)
  389.         {
  390.             double dx = p1.X - p2.X, dy = p1.Y - p2.Y; // Z is left out on purpose
  391.             return dx * dx + dy * dy;
  392.         }
  393.  
  394.         private double SqrDistance3d(Point3d p1, Point3d p2)
  395.         {
  396.             double dx = p1.X - p2.X, dy = p1.Y - p2.Y, dz = p1.Z - p2.Z;
  397.             // Squared distance between the two points; no square root is taken.
  398.             return dx * dx + dy * dy + dz * dz;
  399.         }
  400.  
  401.         #endregion
  402.     }
  403.  
  404.     static class Extensions
  405.     {
  406.         // Credit: Tony Tanzillo
  407.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  408.         public static T QuickSelectMedian<T>(this T[] items, Comparison<T> compare)
  409.         {   // Quickselect: reorders 'items' in place and returns the element ending up at index Length / 2.
  410.             if (items == null || items.Length == 0) // validate before reading items.Length (a null array used to throw NullReferenceException)
  411.                 throw new ArgumentException("array");
  412.             int l = items.Length;
  413.             int k = l / 2; // target (median) position
  414.             int from = 0;
  415.             int to = l - 1;
  416.             while (from < to) // from..to is the index window still being narrowed around k
  417.             {
  418.                 int r = from;
  419.                 int w = to;
  420.                 T current = items[(r + w) / 2]; // pivot value taken from the middle of the window
  421.                 while (r < w)
  422.                 {
  423.                     if (compare(items[r], current) > -1) // items[r] >= pivot: move it to the right end of the window
  424.                     {
  425.                         var tmp = items[w];
  426.                         items[w] = items[r];
  427.                         items[r] = tmp;
  428.                         w--;
  429.                     }
  430.                     else // items[r] < pivot: it may stay on the left
  431.                     {
  432.                         r++;
  433.                     }
  434.                 }
  435.                 if (compare(items[r], current) > 0) // the meeting element belongs to the right part
  436.                 {
  437.                     r--;
  438.                 }
  439.                 if (k <= r) // keep the half of the window that contains index k
  440.                 {
  441.                     to = r;
  442.                 }
  443.                 else
  444.                 {
  445.                     from = r + 1;
  446.                 }
  447.             }
  448.             return items[k];
  449.         }
  450.     }
  451. }
  452.  
Title: Re: Kd Tree
Post by: TheMaster on April 14, 2013, 02:59:15 PM
Hi Tony,

It seems to me he didn't use a tree structure but a 'divide and conquer' algorithm as shown by this picture (http://divide and conqueer).
For my part, I tried to implement a more reusable (and maybe useful) structure.

I found another way to avoid duplicated connexions with the ConnectAll() method. Rather than storing nodes in a stack, I can get the 'right parent nodes' of the current node using a new property in the TreeNode which indicates whether the node is a left child node or not. This seems to improve ConnectAll() speed a little.


Hi Gile. Very good.

Here's one more suggestion: Depending on a number of factors like point distribution, there can be a very large number of leaf nodes in the tree (leaf nodes are nodes with no child nodes). You can save some memory by not having fields for child nodes on leaf nodes, by using two different classes for leaf and branch nodes, using something like this:

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6.  
  7. namespace Graph.Nodes
  8. {
  9.  
  10.    /// Use for leaf nodes with no children:
  11.    /// (contains no fields for child nodes)
  12.    
  13.    public class Node<T>
  14.    {
  15.       /// <summary>
  16.       /// Initializes a node holding the given value. Protected: outside code goes through Create.
  17.       /// </summary>
  18.       protected Node( T value )
  19.       {
  20.          Value = value;
  21.       }
  22.  
  23.       /// <summary>
  24.       /// Gets or sets the value stored in this node.
  25.       /// </summary>
  26.       public T Value { get; set; }
  27.  
  28.       /// <summary>
  29.       /// Left child. This base class stores none: the getter returns null and the setter discards its argument.
  30.       /// </summary>
  31.       public virtual Node<T> Left
  32.       {
  33.          get { return null; }
  34.          protected set { }
  35.       }
  36.  
  37.       /// <summary>
  38.       /// Right child. This base class stores none: the getter returns null and the setter discards its argument.
  39.       /// </summary>
  40.       public virtual Node<T> Right
  41.       {
  42.          get { return null; }
  43.          protected set { }
  44.       }
  45.  
  46.       /// <summary>
  47.       /// Creates a plain Node when both children are null, otherwise a BranchNode holding them.
  48.       /// </summary>
  49.       public static Node<T> Create( T value, Node<T> left = null, Node<T> right = null )
  50.       {
  51.          bool isLeaf = left == null && right == null;
  52.          if( isLeaf )
  53.          {
  54.             return new Node<T>( value );
  55.          }
  56.          return new BranchNode<T>( value, left, right );
  57.       }
  58.    }
  59.    
  60.    /// Used for branch nodes with 1 or 2 children
  61.  
  62.    public class BranchNode<T> : Node<T>
  63.    {
  64.       // Backing fields for the two child references;
  65.       // they are what distinguishes a branch node from the field-less base Node.
  66.       private Node<T> _left;
  67.       private Node<T> _right;
  68.  
  69.       /// <summary>
  70.       /// Initializes a branch node with its value and its two children.
  71.       /// The value is forwarded to the base Node constructor.
  72.       /// </summary>
  73.       protected internal BranchNode( T value, Node<T> left, Node<T> right )
  74.          : base( value )
  75.       {
  76.          _left = left;
  77.          _right = right;
  78.       }
  79.  
  80.       /// <summary>
  81.       /// Gets the stored left child; the setter is protected.
  82.       /// </summary>
  83.       public override Node<T> Left
  84.       {
  85.          get { return _left; }
  86.          protected set { _left = value; }
  87.       }
  88.  
  89.       /// <summary>
  90.       /// Gets the stored right child; the setter is protected.
  91.       /// </summary>
  92.       public override Node<T> Right
  93.       {
  94.          get { return _right; }
  95.          protected set { _right = value; }
  96.       }
  97.    }
  98.  
  99. }
  100.  
  101.  
Title: Re: Kd Tree
Post by: gile on April 14, 2013, 04:46:21 PM
Thanks Tony, I'll give a try.
As the tree is balanced, there would never be more leaf nodes than half the number of points plus one.
Title: Re: Kd Tree
Post by: Jerry J on April 16, 2013, 04:07:34 AM
Thanks for stretching my peanut of a brain.
Title: Re: Kd Tree
Post by: gile on April 19, 2013, 11:32:50 AM
I made some changes in the upper code.
I replaced the using of fields (global) by method parameters so that the public method can run in parallel execution.
I replaced the DistanceTo/GetDistanceTo methods by square distances.
Title: Re: Kd Tree
Post by: WILL HATCH on April 19, 2013, 12:16:44 PM
Great work gile!
Title: Re: Kd Tree
Post by: gile on April 26, 2013, 02:47:20 PM
Here's the F# version.
The main interest is that it allows parallel execution for building the tree while targeting the Framework 2.0, so that it can be used with AutoCAD versions from 2007 to 2011.

The F# specific constructs and types are only used internally.
The public methods (the same as in the C# version) are written to be used with other .NET languages (C# or VB); they return Point3d arrays rather than Point3dCollections. A 'Pair' type (class) is used to replace the missing Tuple class in .NET Frameworks prior to 4.0.

Attached the FsPoint3dTree.dll which can be referenced in C# or VB projects (it may be needed to install the F# runtime (http://msdn.microsoft.com/en-us/library/ee829875%28v=vs.100%29.aspx)).

Note: I can't understand why, but the building of the tree runs about ten times slower in A 2012 and later than in prior versions, so for using with A2012 or later, use the C# version instead.

Code - F#: [Select]
  1. namespace Gile.Point3dTree
  2.  
  3. open System
  4. open Autodesk.AutoCAD.Geometry
  5.  
  6. module private Array =
  7.     let median (compare: 'a -> 'a -> int) (items: 'a[])= // quickselect around index Length / 2; reorders 'items' in place
  8.        if items = null || items.Length = 0 then
  9.            failwith "items"
  10.        let l = items.Length
  11.        let k = l / 2 // index of the element returned as the median
  12.  
  13.        let rec loop f t = // f..t is the index window still being narrowed around k
  14.            if f < t then swap f t items.[(f+t) / 2] f t
  15.            else items.[.. k-1], items.[k], items.[k+1 ..] // result: (items before k, median, items after k)
  16.        and swap a b c f t = // one partition pass over a..b with pivot value c; f and t remember the window bounds
  17.            if a < b then
  18.                if compare items.[a] c > -1 then // items.[a] >= pivot: swap it to the right end
  19.                    let tmp = items.[b]
  20.                    items.[b] <- items.[a]
  21.                    items.[a] <- tmp
  22.                    swap a (b-1) c f t
  23.                else
  24.                    swap (a+1) b c f t
  25.            else
  26.                let n = if compare items.[a] c > 0 then a - 1 else a // last index of the left part
  27.                if k <= n
  28.                then loop f n
  29.                else loop (n+1) t
  30.  
  31.        loop 0 (l-1)
  32.  
  33. type private TreeNode =
  34.    | Empty // no node (absent child or empty tree)
  35.    | Node of int * Point3d option * Point3d * TreeNode * TreeNode  // depth, parent's point (None at the root), this node's point, left child, right child
  36.  
  37. /// <summary>
  38. /// Defines a pair (2-tuple) to be used with versions of the .NET Framework prior to 4.0.
  39. /// </summary>
  40. /// <typeparam name="T1">Type of the first item.</typeparam>
  41. /// <typeparam name="T2">Type of the second item.</typeparam>
  42. /// <param name="item1">First item.</param>
  43. /// <param name="item2">Second item.</param>
  44. type Pair<'T1, 'T2>(item1, item2) =
  45.  
  46.    /// <summary>
  47.    /// Gets the first item of the pair.
  48.    /// </summary>
  49.    member this.Item1 with get() = item1
  50.  
  51.    /// <summary>
  52.    /// Gets the second item of the pair.
  53.    /// </summary>
  54.    member this.Item2 with get() = item2
  55.  
  56.    /// <summary>
  57.    /// Creates a new instance of Pair (static factory equivalent to calling the constructor).
  58.    /// </summary>
  59.    /// <param name="item1">First item of the pair.</param>
  60.    /// <param name="item2">Second item of the pair.</param>
  61.    /// <returns>A new Pair containing the items.</returns>
  62.    static member Create(item1, item2) = Pair(item1, item2)
  63.  
  64.  
  65. /// <summary>
  66. /// Creates an new instance of Point3dTree.
  67. /// </summary>
  68. /// <param name="points">The Point3d collection to fill the tree.</param>
  69. /// <param name="ignoreZ">A value indicating if the Z coordinate of points is ignored
  70. /// (as if all points were projected to the XY plane).</param>
  71. type Point3dTree (points: Point3d seq, ignoreZ: bool) =
  72.    do if points = null then raise (System.ArgumentNullException("points"))
  73.  
  74.    let dimension = if ignoreZ then 2 else 3
  75.    let sqrDist (p1: Point3d) (p2: Point3d) =
  76.        if ignoreZ
  77.        then (p1.X - p2.X) * (p1.X - p2.X) + (p1.Y - p2.Y) * (p1.Y - p2.Y)
  78.        else (p1.X - p2.X) * (p1.X - p2.X) + (p1.Y - p2.Y) * (p1.Y - p2.Y) + (p1.Z - p2.Z) * (p1.Z - p2.Z)
  79.  
  80.    let rec shift n d =
  81.        if n >>> d > 1 then shift n (d+1) else d
  82.    let pDepth = shift System.Environment.ProcessorCount 0
  83.  
  84.    let create pts =
  85.        let rec loop depth parent = function
  86.            | [||] -> Empty
  87.            | pts ->
  88.                let d = depth % dimension
  89.                let left, median, right =
  90.                    pts |> Array.median(fun (p1: Point3d) p2 -> compare p1.[d] p2.[d])
  91.                let children =
  92.                    if depth < pDepth then
  93.                        [ async { return loop (depth + 1) (Some(median)) left };
  94.                          async { return loop (depth + 1) (Some(median)) right } ]
  95.                        |> Async.Parallel
  96.                        |> Async.RunSynchronously
  97.                    else
  98.                        [| loop (depth + 1) (Some(median)) left;
  99.                           loop (depth + 1) (Some(median)) right |]
  100.                Node(depth, parent, median, children.[0], children.[1])
  101.        loop 0 None pts
  102.  
  103.    let root = points |> Seq.distinct |> Seq.toArray |> create
  104.  
  105.    let rec findNeighbour location node (current, bestDist) =
  106.        match node with
  107.        | Empty -> (current, bestDist)
  108.        | Node(depth, _, point, left, right) ->
  109.            let dist = sqrDist point location
  110.            let d = depth % dimension
  111.            let bestPair =
  112.                if dist < bestDist
  113.                then point, dist
  114.                else current, bestDist
  115.            if bestDist < (location.[d] - point.[d]) * (location.[d] - point.[d]) then
  116.                findNeighbour location (if location.[d] < point.[d] then left else right) bestPair
  117.            else
  118.                findNeighbour location left bestPair
  119.                |> findNeighbour location right
  120.  
  121.    let rec getNeighbours center radius node acc =
  122.        match node with
  123.        | Empty -> acc
  124.        | Node(depth, _, point, left, right) ->
  125.            let acc = if sqrDist center point <= radius then point :: acc else acc
  126.            let d= depth % dimension;
  127.            let coordCen, coordPt = center.[d], point.[d]
  128.            if (coordCen - coordPt) * (coordCen - coordPt) > radius then
  129.                getNeighbours center radius (if coordCen < coordPt then left else right) acc
  130.            else
  131.                getNeighbours center radius left acc
  132.                |> getNeighbours center radius right
  133.  
  134.    let rec getKNeighbours center number node (pairs: (float * Point3d) list) =
  135.        match node with
  136.        | Empty -> pairs
  137.        | Node(depth, _, point, left, right) ->
  138.            let dist = sqrDist center point
  139.            let pairs =
  140.                match pairs.Length with
  141.                | 0 -> [ (dist, point) ]
  142.                | l when l < number ->
  143.                    if (dist > fst pairs.Head)
  144.                    then (dist, point) :: pairs
  145.                    else pairs.Head :: (dist, point) :: pairs.Tail
  146.                | _ ->
  147.                    if dist < fst pairs.Head
  148.                    then ((dist, point) :: pairs.Tail) |> List.sortBy(fun p -> -fst p)
  149.                    else pairs
  150.            let d = depth % dimension
  151.            let coordCen, coordCur = center.[d], point.[d]
  152.            if (coordCen - coordCur) * (coordCen - coordCur) > fst pairs.Head then
  153.                getKNeighbours center number (if coordCen < coordCur then left else right) pairs
  154.            else
  155.                getKNeighbours center number left pairs
  156.                |> getKNeighbours center number right
  157.  
  158.    let rec findRange (lowerLeft: Point3d) (upperRight: Point3d) node (acc: Point3d list) =
  159.        match node with
  160.        | Empty -> acc
  161.        | Node(depth, _, point, left, right) ->
  162.            let acc =
  163.                if ignoreZ then
  164.                    if point.X >= lowerLeft.X && point.X <= upperRight.X &&
  165.                       point.Y >= lowerLeft.Y && point.Y <= upperRight.Y
  166.                    then point :: acc else acc
  167.                else
  168.                    if point.X >= lowerLeft.X && point.X <= upperRight.X &&
  169.                       point.Y >= lowerLeft.Y && point.Y <= upperRight.Y &&
  170.                       point.Z >= lowerLeft.Z && point.Z <= upperRight.Z
  171.                    then point :: acc else acc
  172.            let d = depth % dimension
  173.            if upperRight.[d] < point.[d]
  174.            then findRange lowerLeft upperRight left acc
  175.            elif lowerLeft.[d] > point.[d]
  176.            then findRange lowerLeft upperRight right acc
  177.            else findRange lowerLeft upperRight left acc
  178.                 |> findRange lowerLeft upperRight right
  179.                
  180.    let rec getConnexions node radius (nodes, pairs) =
  181.        match node with
  182.        | Empty -> (nodes, pairs)
  183.        | Node(depth, parent, point, left, right) ->
  184.            let pairs =
  185.                nodes
  186.                |> List.fold(fun acc (n: TreeNode) ->
  187.                    match n with
  188.                    | Empty -> acc
  189.                    | Node(dep, par, _, _, _) ->
  190.                        let pt = par.Value
  191.                        let d = (dep + 1) % dimension
  192.