Author Topic: Kd Tree  (Read 20503 times)

0 Members and 1 Guest are viewing this topic.

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Kd Tree
« on: April 01, 2013, 05:08:56 PM »
Hi,

I'm trying to implement a .NET Kd Tree for AutoCAD Point3d with a NearestNeighbours method.
I tried my class with this (old but successfull) challenge.
Using my Point3dTree is not very faster than the quite naive algorithm I posted in reply #31 with 10000 points (
notwithstanding 3 times faster with 100000 points).
Is this due to a bad implementation of the Kd Tree or to the fact this kind of data structure is not so good to solve this type of challenge ?

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     public class KdTreeNode<T>
  9.     {
  10.         public T Value { get; internal set; }
  11.         public KdTreeNode<T> LeftChild { get; internal set; }
  12.         public KdTreeNode<T> RightChild { get; internal set; }
  13.         public int Depth { get; internal set; }
  14.         public bool Processed { get; set; }
  15.  
  16.         public KdTreeNode(T value) { this.Value = value; }
  17.     }
  18.  
  19.     public class Point3dTree
  20.     {
  21.         private int dimension;
  22.         public KdTreeNode<Point3d> Root { get; private set; }
  23.  
  24.         public Point3dTree(IEnumerable<Point3d> points) : this(points, false) { }
  25.  
  26.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ)
  27.         {
  28.             if (points == null)
  29.                 throw new ArgumentNullException("points");
  30.             this.dimension = ignoreZ ? 2 : 3;
  31.             Point3d[] pts = points.Distinct().ToArray();
  32.             this.Root = Construct(pts, 0);
  33.         }
  34.  
  35.         public List<Point3d> NearestNeighbours(KdTreeNode<Point3d> node, double radius)
  36.         {
  37.             if (node == null)
  38.                 throw new ArgumentNullException("node");
  39.             List<Point3d> result = new List<Point3d>();
  40.             GetNeighbours(node.Value, radius, this.Root.LeftChild, ref result);
  41.             GetNeighbours(node.Value, radius, this.Root.RightChild, ref result);
  42.             return result;
  43.         }
  44.  
  45.         private void GetNeighbours(Point3d center, double radius, KdTreeNode<Point3d> node, ref List<Point3d> result)
  46.         {
  47.             if (node == null) return;
  48.             Point3d pt = node.Value;
  49.             if (!node.Processed && center.DistanceTo(pt) <= radius)
  50.             {
  51.                 result.Add(pt);
  52.             }
  53.             int d = node.Depth % this.dimension;
  54.             double coordCen = center[d];
  55.             double coordPt = pt[d];
  56.             if (Math.Abs(coordCen - coordPt) > radius)
  57.             {
  58.                 if (coordCen < coordPt)
  59.                 {
  60.                     GetNeighbours(center, radius, node.LeftChild, ref result);
  61.                 }
  62.                 else
  63.                 {
  64.                     GetNeighbours(center, radius, node.RightChild, ref result);
  65.                 }
  66.             }
  67.             else
  68.             {
  69.                 GetNeighbours(center, radius, node.LeftChild, ref result);
  70.                 GetNeighbours(center, radius, node.RightChild, ref result);
  71.             }
  72.         }
  73.  
  74.         private KdTreeNode<Point3d> Construct(Point3d[] points, int depth)
  75.         {
  76.             int length = points.Length;
  77.             if (length == 0) return null;
  78.             int d = depth % this.dimension;
  79.             Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  80.             int med = length / 2;
  81.             KdTreeNode<Point3d> node = new KdTreeNode<Point3d>(points[med]);
  82.             node.Depth = depth;
  83.             node.LeftChild = Construct(points.Take(med).ToArray(), depth + 1);
  84.             node.RightChild = Construct(points.Skip(med + 1).ToArray(), depth + 1);
  85.             return node;
  86.         }
  87.     }
  88. }

I also improve the speed for drawing lines by using the ObjectId.Open() method instead of using a transaction.
Code - C#: [Select]
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using Autodesk.AutoCAD.ApplicationServices;
  4. using Autodesk.AutoCAD.DatabaseServices;
  5. using Autodesk.AutoCAD.EditorInput;
  6. using Autodesk.AutoCAD.Geometry;
  7. using Autodesk.AutoCAD.Runtime;
  8. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  9.  
  10. [assembly: CommandClass(typeof(PointKdTree.CommandMethods))]
  11.  
  12. namespace PointKdTree
  13. {
  14.     public class CommandMethods
  15.     {
  16.         struct Point3dPair
  17.         {
  18.             public readonly Point3d Start;
  19.             public readonly Point3d End;
  20.  
  21.             public Point3dPair(Point3d start, Point3d end)
  22.             {
  23.                 this.Start = start;
  24.                 this.End = end;
  25.             }
  26.         }
  27.  
  28.         [CommandMethod("Connect")]
  29.         public void Connect()
  30.         {
  31.             System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  32.             sw.Start();
  33.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  34.             Database db = doc.Database;
  35.             Editor ed = doc.Editor;
  36.             RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  37.             using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  38.             {
  39.                 var pts = btr.Cast<ObjectId>()
  40.                     .Where(id => id.ObjectClass == rxc)
  41.                     .Select(id =>
  42.                     {
  43.                         using (DBPoint pt = (DBPoint)id.Open(OpenMode.ForRead))
  44.                         { return pt.Position; }
  45.                     });
  46.                 Point3dTree tree = new Point3dTree(pts, true);
  47.                 List<Point3dPair> pairs = new List<Point3dPair>();
  48.                 ConnectPoints(tree, tree.Root, 70.0, ref pairs);
  49.                 foreach (Point3dPair pair in pairs)
  50.                 {
  51.                     using (Line line = new Line(pair.Start, pair.End))
  52.                     {
  53.                         btr.AppendEntity(line);
  54.                     }
  55.                 }
  56.             }
  57.             sw.Stop();
  58.             ed.WriteMessage("\nElapsed milliseconds: {0}", sw.ElapsedMilliseconds);
  59.         }
  60.  
  61.         private void ConnectPoints(Point3dTree tree, KdTreeNode<Point3d> node, double dist, ref List<Point3dPair> pointPairs)
  62.         {
  63.             if (node == null) return;
  64.             node.Processed = true;
  65.             Point3d center = node.Value;
  66.             foreach (Point3d pt in tree.NearestNeighbours(node, dist))
  67.             {
  68.                 pointPairs.Add(new Point3dPair(center, pt));
  69.             }
  70.             ConnectPoints(tree, node.LeftChild, dist, ref pointPairs);
  71.             ConnectPoints(tree, node.RightChild, dist, ref pointPairs);
  72.         }
  73.     }
  74. }
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #1 on: April 01, 2013, 05:21:03 PM »
Hi Gile. You could try using some execution profiling tools to identify bottlenecks, and I can't tell you much about the overall performance characteristics without some analysis, but just skimming your code, this sticks out:

Code - C#: [Select]
  1.  
  2.     node.LeftChild = Construct(points.Take(med).ToArray(), depth + 1);
  3.     node.RightChild = Construct(points.Skip(med + 1).ToArray(), depth + 1);
  4.  
  5.  

In this case, I would probably not use Linq, because Take/Skip(...).ToArray() can't produce a result in a single allocation, and will have to do it incrementally (starting with 4 elements, and doubling it each time more capacity is needed), which depending on how many elements there are, can be a major performance hit.

I would use Array.Copy() or Array.Constrained.Copy() instead.

And one more possible optimization:

Code - C#: [Select]
  1.  
  2.     Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  3.  
  4.  

The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.

« Last Edit: April 01, 2013, 05:50:34 PM by TT »

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #2 on: April 01, 2013, 06:07:29 PM »
Thanks Tony,

Quote
I would use Array.Copy() or Array.Constrained.Copy() instead.
That makes sense.

Quote
The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.
I do not agree. At each call 'd' is changing according to the depth in tree and the tree dimension (2d or 3d):
first call: d = 0, the points are sorted by X,
second  call: d = 1 => the points are sorted b Y
third call: d = 0 if dimension = 2 (ignore Z) => the points are sorted by X or d = 2 if dimension = 3 => the points are sorted by Z
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #3 on: April 01, 2013, 08:03:21 PM »
Thanks Tony,

Quote
I would use Array.Copy() or Array.Constrained.Copy() instead.
That makes sense.

Quote
The above (from your Construct() method), appears to be redundant, since Construct() is recursive, and each call is sorting its argument, which for all but the outer-most call, should not be necessary.
I do not agree. At each call 'd' is changing according to the depth in tree and the tree dimension (2d or 3d):
first call: d = 0, the points are sorted by X,
second  call: d = 1 => the points are sorted b Y
third call: d = 0 if dimension = 2 (ignore Z) => the points are sorted by X or d = 2 if dimension = 3 => the points are sorted by Z

Hi Gile. Yes, I didn't look close enough, and was thinking 'b-tree', not 'kd-tree' :laugh:


pkohut

  • Bull Frog
  • Posts: 483
Re: Kd Tree
« Reply #4 on: April 02, 2013, 01:05:34 AM »
Hi Gile,

Profiling will probably reveal the bottleneck to be in the Construct method, and the call to sort being a big contributor.
Code: [Select]
Array.Sort(points, (p1, p2) => p1[d].CompareTo(p2[d]));Change it to a lighter weight partial sort, like C++'s std::nth_element (http://stackoverflow.com/questions/2540602/does-c-sharp-have-a-stdnth-element-equivalent)


Here is a code snippet from http://www.theswamp.org/index.php?topic=32874.msg401826#msg401826, which is similar to your Construct method.
Code: [Select]
    void OptimizeKdTree(typename std::vector<NODE_POINT<T, DIM>>::iterator & i1,
        typename std::vector<NODE_POINT<T, DIM>>::iterator & i2,
        const size_t dir)
    {
        if(i1 == i2)
            return;

        // Create instance of the compare function and set with the KDTree direction
        NodeCompare<NODE_POINT<T, DIM>, DIM> compare(dir % DIM);

        // Find the halfway mark between the 2 iterators
        std::vector<NODE_POINT<T, DIM>>::iterator iMid = i1 + (i2 - i1) / 2;

        /*
        * rearrange so elements left of iMid are less than iMid.pos
        * and elements right of iMid are greater than iMid.pos
        */
        std::nth_element(i1, iMid, i2, compare);

        // Insert the iMid into the KDTree
        Insert(iMid->Data(), 0);

        // Continue until no more elements to insert
        if(iMid != i1)
            OptimizeKdTree(i1, iMid, dir + 1);
        if(++iMid != i2)
            OptimizeKdTree(iMid, i2, dir + 1);
    }
New tread (not retired) - public repo at https://github.com/pkohut

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #5 on: April 02, 2013, 01:32:40 AM »
Thanks Paul.

That makes sense, the issue now is: am I able to build a quick select routine faster than Array.Sort() ?
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #6 on: April 02, 2013, 06:35:00 AM »
Thanks Paul.

That makes sense, the issue now is: am I able to build a quick select routine faster than Array.Sort() ?

Purely out of curiosity, I tried to parallelize your Construct() method, but only saw a minor improvement of around 20-30%.

The search can also be parallelized, but it requires major surgery (it needs to return results rather than add them to a single list, or you have to use a collection from System.Collections.Concurrent).

Also, you don't need to use 'ref' parameters to pass and use reference types, like List<T>. Use 'ref' only when you're going to actually change the value of the variable in the caller, rather than operate on an instance of a passed reference type. See my note in the code.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.ApplicationServices;
  6. using Autodesk.AutoCAD.DatabaseServices;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using Autodesk.AutoCAD.Geometry;
  9. using Autodesk.AutoCAD.Runtime;
  10. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  11. using System.Diagnostics;
  12.  
  13. namespace PointKdTree
  14. {
  15.    public class KdTreeNode<T>
  16.    {
  17.       public T Value
  18.       {
  19.          get;
  20.          internal set;
  21.       }
  22.       public KdTreeNode<T> LeftChild
  23.       {
  24.          get;
  25.          internal set;
  26.       }
  27.       public KdTreeNode<T> RightChild
  28.       {
  29.          get;
  30.          internal set;
  31.       }
  32.       public int Depth
  33.       {
  34.          get;
  35.          internal set;
  36.       }
  37.       public bool Processed
  38.       {
  39.          get;
  40.          set;
  41.       }
  42.  
  43.       public KdTreeNode( T value )
  44.       {
  45.          this.Value = value;
  46.       }
  47.    }
  48.  
  49.    public class Point3dNode : KdTreeNode<Point3d>
  50.    {
  51.       public Point3dNode( Point3d value ) : base(value)
  52.       {
  53.       }
  54.  
  55.       public double this[int dim]
  56.       {
  57.          get
  58.          {
  59.             return this.Value[this.Depth % dim];
  60.          }
  61.       }
  62.    }
  63.  
  64.    public class Point3dTree
  65.    {
  66.       private int dimension;
  67.       public KdTreeNode<Point3d> Root
  68.       {
  69.          get;
  70.          private set;
  71.       }
  72.  
  73.       public Point3dTree( IEnumerable<Point3d> points ) : this( points, false )
  74.       {
  75.       }
  76.  
  77.       public Point3dTree( IEnumerable<Point3d> points, bool ignoreZ )
  78.       {
  79.          if( points == null )
  80.             throw new ArgumentNullException( "points" );
  81.          this.dimension = ignoreZ ? 2 : 3;
  82.          Point3d[] pts = points.Distinct().ToArray();
  83.          this.Root = Construct( pts, 0 );
  84.       }
  85.  
  86.       public List<Point3d> NearestNeighbours( KdTreeNode<Point3d> node, double radius )
  87.       {
  88.          if( node == null )
  89.             throw new ArgumentNullException( "node" );
  90.          List<Point3d> result = new List<Point3d>();
  91.          GetNeighbours( node.Value, radius, this.Root.LeftChild, result );
  92.          GetNeighbours( node.Value, radius, this.Root.RightChild, result );
  93.          return result;
  94.       }
  95.  
  96.       private void GetNeighbours( Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result )
  97.       {
  98.          if( node == null ) return;
  99.          Point3d pt = node.Value;
  100.          if( !node.Processed && center.DistanceTo( pt ) <= radius )
  101.          {
  102.             result.Add( pt );
  103.          }
  104.          int d = node.Depth % this.dimension;
  105.          double coordCen = center[d];
  106.          double coordPt = pt[d];
  107.          if( Math.Abs( coordCen - coordPt ) > radius )
  108.          {
  109.             if( coordCen < coordPt )
  110.             {
  111.                GetNeighbours( center, radius, node.LeftChild, result );
  112.             }
  113.             else
  114.             {
  115.                GetNeighbours( center, radius, node.RightChild, result );
  116.             }
  117.          }
  118.          else
  119.          {
  120.             GetNeighbours( center, radius, node.LeftChild, result );
  121.             GetNeighbours( center, radius, node.RightChild, result );
  122.          }
  123.       }
  124.  
  125.       internal static bool useParallel = false;
  126.  
  127.       private KdTreeNode<Point3d> Construct( Point3d[] points, int depth )
  128.       {
  129.          int length = points.Length;
  130.          if( length == 0 ) return null;
  131.          int d = depth % this.dimension;
  132.          Array.Sort( points, ( p1, p2 ) => p1[d].CompareTo( p2[d] ) );
  133.          int med = length / 2;
  134.          KdTreeNode<Point3d> node = new KdTreeNode<Point3d>( points[med] );
  135.          node.Depth = depth;
  136.          var left = points.Take( med ).ToArray();
  137.          var right = points.Skip( med + 1 ).ToArray();\
  138.          
  139.          /////////////////////////////////////////////////////////////
  140.          /// Do construction in parallel up to 8 levels (it
  141.          /// makes little sense to parallelize any more than
  142.          /// that, unless you got dozens of CPUs):
  143.          
  144.          if( useParallel && depth < 9 )
  145.          {
  146.             System.Threading.Tasks.Parallel.Invoke(
  147.                () => node.LeftChild = Construct( left, depth + 1 ),
  148.                () => node.RightChild = Construct( right, depth + 1 )
  149.             );
  150.          }
  151.          else
  152.          {
  153.             node.LeftChild = Construct( left, depth + 1 );
  154.             node.RightChild = Construct( right, depth + 1 );
  155.          }
  156.          return node;
  157.       }
  158.  
  159.    }
  160.  
  161.    public class CommandMethods
  162.    {
  163.       struct Point3dPair
  164.       {
  165.          public readonly Point3d Start;
  166.          public readonly Point3d End;
  167.  
  168.          public Point3dPair( Point3d start, Point3d end )
  169.          {
  170.             this.Start = start;
  171.             this.End = end;
  172.          }
  173.       }
  174.  
  175.       /// Toggle parallel execution on/off:
  176.      
  177.       [CommandMethod( "CONNECT_PARALLEL" )]
  178.       public static void ConnectParallelSwitch()
  179.       {
  180.          Point3dTree.useParallel ^= true;
  181.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  182.             "\nConnect using Parallel = {0}", Point3dTree.useParallel );
  183.       }
  184.  
  185.       [CommandMethod( "Connect" )]
  186.       public static void Connect()
  187.       {
  188.          Stopwatch sw = new System.Diagnostics.Stopwatch();
  189.          Stopwatch sw2 = null;
  190.          sw.Start();
  191.          Document doc = AcAp.DocumentManager.MdiActiveDocument;
  192.          Database db = doc.Database;
  193.          Editor ed = doc.Editor;
  194.          if( Point3dTree.useParallel )
  195.             ed.WriteMessage( "\n(Using Parallel Execution)\n" );
  196.          else
  197.             ed.WriteMessage( "\n(Using Sequential Execution)\n" );
  198.          RXClass rxc = RXClass.GetClass( typeof( DBPoint ) );
  199.          using( BlockTableRecord btr = (BlockTableRecord) db.CurrentSpaceId.Open( OpenMode.ForWrite ) )
  200.          {
  201.             var pts = btr.Cast<ObjectId>()
  202.                 .Where( id => id.ObjectClass == rxc )
  203.                 .Select( id =>
  204.                 {
  205.                    using( DBPoint pt = (DBPoint) id.Open( OpenMode.ForRead ) )
  206.                    {
  207.                       return pt.Position;
  208.                    }
  209.                 } );
  210.  
  211.             sw2 = Stopwatch.StartNew();
  212.             Point3dTree tree = new Point3dTree( pts, true );
  213.             List<Point3dPair> pairs = new List<Point3dPair>();
  214.             ConnectPoints( tree, tree.Root, 70.0, pairs );
  215.             sw2.Stop();
  216.             foreach( Point3dPair pair in pairs )
  217.             {
  218.                using( Line line = new Line( pair.Start, pair.End ) )
  219.                {
  220.                   btr.AppendEntity( line );
  221.                }
  222.             }
  223.          }
  224.          sw.Stop();
  225.          ed.WriteMessage( "\nElapsed milliseconds: {0}", sw.ElapsedMilliseconds );
  226.          ed.WriteMessage( "\nNearestNeighbor time: {0}", sw2.ElapsedMilliseconds );
  227.       }
  228.  
  229.  
  230.       /// Don't use 'ref' parameters for reference types, unless
  231.       /// your intention is to modify the value of the variable
  232.       /// in the caller. If you are only accessing the reference
  233.       /// type (e.g., adding items to a List<T>), you don't need
  234.       /// 'ref', which has some performance overhead.
  235.      
  236.       private static void ConnectPoints( Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs )
  237.       {
  238.          if( node == null ) return;
  239.          node.Processed = true;
  240.          Point3d center = node.Value;
  241.          foreach( Point3d pt in tree.NearestNeighbours( node, dist ) )
  242.          {
  243.             pointPairs.Add( new Point3dPair( center, pt ) );
  244.          }
  245.          ConnectPoints( tree, node.LeftChild, dist, pointPairs );
  246.          ConnectPoints( tree, node.RightChild, dist, pointPairs );
  247.       }
  248.    }
  249. }
  250.  
  251.  

Quick test with 50K points:

Code - Text: [Select]
  1.  
  2. Command: CONNECT
  3. (Using Sequential Execution)
  4. Elapsed milliseconds: 875
  5. NearestNeighbor time: 858
  6.  
  7. Command: CONNECT_PARALLEL
  8. Connect using Parallel = True
  9.  
  10. Command: CONNECT
  11. (Using Parallel Execution)
  12. Elapsed milliseconds: 604
  13. NearestNeighbor time: 602
  14.  
  15.  
« Last Edit: April 02, 2013, 06:38:06 AM by TT »

pkohut

  • Bull Frog
  • Posts: 483
Re: Kd Tree
« Reply #7 on: April 02, 2013, 10:00:17 AM »
Nice speed bump and coding Tony.

FWIW, Array.Sort is O(n log n), while nth_element is O(n).
New tread (not retired) - public repo at https://github.com/pkohut

TheMaster

  • Guest
Re: Kd Tree
« Reply #8 on: April 02, 2013, 06:29:50 PM »
Nice speed bump and coding Tony.

FWIW, Array.Sort is O(n log n), while nth_element is O(n).

Hi Paul, and thanks.

I posted what I thought was a notable improvement to Gile's original code, but as it turned out, it was Linq's deferred execution that was corrupting the timings, so I retracted it.

I tried using the QuickSelect algorithm in place of Array.Sort(), but that didn't seem to help much, and I think that's mainly because the majority of the execution time of Gile's code is consumed by getting the coordinate data from the DBPoints, drawing the Lines, and most-notably, by a call to Enumerable.Distinct() to remove all duplicates. Only a small fraction of it is consumed by building the tree, and I don't know if that's as important as searching it.

Below is Gile's code with a few ideas on how he might improve overall performance implemented, which has various tweaks, such as avoiding the use of ToArray() when the quantity is known, avoiding use of Enumerable.Distinct() (which could be dealt with in the process of building the tree), and storing the entire dataset of points in array, and operating on int[] arrays of indices into the array, rather than directly on the Point3ds themselves.  So, the Construct() method was refactored to pass arrays of int[], that contain indices of elements in the Point3d[] array, which is never duplicated.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.ApplicationServices;
  6. using Autodesk.AutoCAD.DatabaseServices;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using Autodesk.AutoCAD.Geometry;
  9. using Autodesk.AutoCAD.Runtime;
  10. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  11. using System.Diagnostics;
  12.  
  13. namespace PointKdTree
  14. {
  15.    public class KdTreeNode<T>
  16.    {
  17.       public T Value
  18.       {
  19.          get;
  20.          internal set;
  21.       }
  22.       public KdTreeNode<T> LeftChild
  23.       {
  24.          get;
  25.          internal set;
  26.       }
  27.       public KdTreeNode<T> RightChild
  28.       {
  29.          get;
  30.          internal set;
  31.       }
  32.       public int Depth
  33.       {
  34.          get;
  35.          internal set;
  36.       }
  37.       public bool Processed
  38.       {
  39.          get;
  40.          set;
  41.       }
  42.  
  43.       public KdTreeNode( T value )
  44.       {
  45.          this.Value = value;
  46.       }
  47.    }
  48.  
  49.    public class Point3dNode : KdTreeNode<Point3d>
  50.    {
  51.       public Point3dNode( Point3d value ) : base(value)
  52.       {
  53.       }
  54.  
  55.       public double this[int dim]
  56.       {
  57.          get
  58.          {
  59.             return this.Value[this.Depth % dim];
  60.          }
  61.       }
  62.    }
  63.  
  64.    public class Point3dTree
  65.    {
  66.       private int dimension;
  67.  
  68.       public KdTreeNode<Point3d> Root
  69.       {
  70.          get;
  71.          private set;
  72.       }
  73.  
  74.       public Point3dTree( IEnumerable<Point3d> points )
  75.          : this( points, false )
  76.       {
  77.       }
  78.  
  79.       /// <summary>
  80.       /// Eliminated call to Enumerable.Distinct() which
  81.       /// was consuming most of the time required here:
  82.       /// </summary>
  83.       public Point3dTree( IEnumerable<Point3d> points, bool ignoreZ )
  84.       {
  85.          if( points == null )
  86.             throw new ArgumentNullException( "points" );
  87.          this.dimension = ignoreZ ? 2 : 3;
  88.          var data = points as Point3d[] ?? points.ToArray(); //
  89.          int[] indices = new int[data.Length];
  90.          for( int i = 0; i < data.Length; i++ )
  91.             indices[i] = i;
  92.          this.Root = Construct( data, indices, 0 );
  93.       }
  94.  
  95.       public List<Point3d> NearestNeighbours( KdTreeNode<Point3d> node, double radius )
  96.       {
  97.          if( node == null )
  98.             throw new ArgumentNullException( "node" );
  99.          List<Point3d> result = new List<Point3d>();
  100.          GetNeighbours( node.Value, radius, this.Root.LeftChild, result );
  101.          GetNeighbours( node.Value, radius, this.Root.RightChild, result );
  102.          return result;
  103.       }
  104.  
  105.       private void GetNeighbours( Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result )
  106.       {
  107.          if( node == null ) return;
  108.          Point3d pt = node.Value;
  109.          if( !node.Processed && center.DistanceTo( pt ) <= radius )
  110.          {
  111.             result.Add( pt );
  112.          }
  113.          int d = node.Depth % this.dimension;
  114.          double coordCen = center[d];
  115.          double coordPt = pt[d];
  116.          if( Math.Abs( coordCen - coordPt ) > radius )
  117.          {
  118.             if( coordCen < coordPt )
  119.             {
  120.                GetNeighbours( center, radius, node.LeftChild, result );
  121.             }
  122.             else
  123.             {
  124.                GetNeighbours( center, radius, node.RightChild, result );
  125.             }
  126.          }
  127.          else
  128.          {
  129.             GetNeighbours( center, radius, node.LeftChild, result );
  130.             GetNeighbours( center, radius, node.RightChild, result );
  131.          }
  132.       }
  133.  
  134.       internal static bool useParallel = false;
  135.  
  136.       /// <summary>
  137.       /// Refactored to pass entire dataset of Point3d in data,
  138.       /// and to manipulate int[] arrays of indices.
  139.       /// </summary>
  140.      
  141.       private KdTreeNode<Point3d> Construct( Point3d[] data, int[] points, int depth )
  142.       {
  143.          int length = points.Length;
  144.          if( length == 0 ) return null;
  145.          int d = depth % this.dimension;
  146.          int l1 = length / 2;
  147.          int l2 = length - l1 - 1;
  148.          var items = new KeyValuePair<double, int>[length];
  149.          for( int i = 0; i < length; i++ )
  150.             items[i] = new KeyValuePair<double, int>( data[points[i]][d], i );
  151.          int pivot = OrderedSelect( items, d );
  152.          var left = new int[l1];
  153.          for( int i = 0; i < l1; i++ )
  154.             left[i] = items[i].Value;
  155.          var right = new int[l2];
  156.          ++l1;
  157.          for( int i = 0; i < l2; i++ )
  158.             right[i] = items[i+l1].Value;
  159.          KdTreeNode<Point3d> node = new KdTreeNode<Point3d>( data[pivot] );
  160.          KdTreeNode<Point3d> leftchild = null;
  161.          KdTreeNode<Point3d> rightchild = null;
  162.          node.Depth = depth;
  163.          if( useParallel && depth < 4 ) // parallelize up to 4 levels deep
  164.          {
  165.             System.Threading.Tasks.Parallel.Invoke(
  166.                () => leftchild = Construct( data, left, depth + 1 ),
  167.                () => rightchild = Construct( data, right, depth + 1 ) );
  168.          }
  169.          else
  170.          {
  171.             leftchild = Construct( data, left, depth + 1 );
  172.             rightchild = Construct( data, right, depth + 1 );
  173.          }
  174.          node.LeftChild = leftchild;
  175.          node.RightChild = rightchild;
  176.          return node;
  177.       }
  178.  
  179.       // Quick select algorithm operating on KeyValuePair<double,int>
  180.       // where the key is the ordinate/dimension and the value is the
  181.       // index of the associated Point3d in the dataset:
  182.  
  183.       public static int OrderedSelect( KeyValuePair<double, int>[] points, int dim )
  184.       {
  185.          int k = points.Length / 2;
  186.          if( points == null || points.Length <= k )
  187.             throw new ArgumentException( "array" );
  188.  
  189.          int from = 0, to = points.Length - 1;
  190.          while( from < to )
  191.          {
  192.             int r = from, w = to;
  193.             double pivot = points[( r + w ) / 2].Key;
  194.             while( r < w )
  195.             {
  196.                if( points[r].Key >= pivot )
  197.                {
  198.                   KeyValuePair<double, int> tmp = points[w];
  199.                   points[w] = points[r];
  200.                   points[r] = tmp;
  201.                   w--;
  202.                }
  203.                else
  204.                {
  205.                   r++;
  206.                }
  207.             }
  208.             if( points[r].Key > pivot )
  209.             {
  210.                r--;
  211.             }
  212.             if( k <= r )
  213.             {
  214.                to = r;
  215.             }
  216.             else
  217.             {
  218.                from = r + 1;
  219.             }
  220.          }
  221.          return points[k].Value;
  222.       }
  223.  
  224.    }
  225.  
  226.    public class CommandMethods
  227.    {
  228.       struct Point3dPair
  229.       {
  230.          public readonly Point3d Start;
  231.          public readonly Point3d End;
  232.  
  233.          public Point3dPair( Point3d start, Point3d end )
  234.          {
  235.             this.Start = start;
  236.             this.End = end;
  237.          }
  238.       }
  239.  
  240.       [CommandMethod( "CONNECT_PARALLEL" )]
  241.       public static void ConnectParallelSwitch()
  242.       {
  243.          Point3dTree.useParallel ^= true;
  244.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  245.             "\nConnect using Parallel = {0}", Point3dTree.useParallel );
  246.       }
  247.  
  248.       [CommandMethod( "Connect" )]
  249.       public static void Connect()
  250.       {
  251.          Stopwatch sw = new System.Diagnostics.Stopwatch();
  252.          Stopwatch sw2 = null;
  253.          Stopwatch sw3 = null;
  254.          long distinctTime = 0L;
  255.          sw.Start();
  256.          Document doc = AcAp.DocumentManager.MdiActiveDocument;
  257.          Database db = doc.Database;
  258.          Editor ed = doc.Editor;
  259.          if( Point3dTree.useParallel )
  260.             ed.WriteMessage( "\n(Using Parallel Execution)\n" );
  261.          else
  262.             ed.WriteMessage( "\n(Using Sequential Execution)\n" );
  263.          RXClass rxc = RXClass.GetClass( typeof( DBPoint ) );
  264.          int cnt = 0;
  265.  
  266.          using( BlockTableRecord btr = (BlockTableRecord) db.CurrentSpaceId.Open( OpenMode.ForWrite ) )
  267.          {
  268.             var dbpoints = btr.Cast<ObjectId>().Where( id => id.ObjectClass == rxc );
  269.             ObjectId[] ids = dbpoints.ToArray();
  270.             int len = ids.Length;
  271.             Point3d[] array = new Point3d[len];
  272.             for( int i = 0; i < len; i++ )
  273.             {
  274.                using( DBPoint pt = (DBPoint) ids[i].Open( OpenMode.ForRead ) )
  275.                {
  276.                   array[i] = pt.Position;
  277.                }
  278.             }
  279.             sw3 = Stopwatch.StartNew();
  280.             Point3dTree tree = new Point3dTree( array, true );
  281.             sw3.Stop();
  282.             sw2 = Stopwatch.StartNew();
  283.             List<Point3dPair> pairs = new List<Point3dPair>( cnt / 2 );
  284.             ConnectPoints( tree, tree.Root, 70.0, pairs );
  285.             sw2.Stop();
  286.             foreach( Point3dPair pair in pairs )
  287.             {
  288.                using( Line line = new Line( pair.Start, pair.End ) )
  289.                {
  290.                   btr.AppendEntity( line );
  291.                }
  292.             }
  293.          }
  294.          sw.Stop();
  295.          ed.WriteMessage( "\nTotal time:                {0,5}", sw.ElapsedMilliseconds );
  296.          ed.WriteMessage( "\nConstruct time:            {0,5}", sw3.ElapsedMilliseconds );
  297.          ed.WriteMessage( "\nNearestNeighbor time:      {0,5}", sw2.ElapsedMilliseconds );
  298.       }
  299.  
  300.       private static void ConnectPoints( Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs )
  301.       {
  302.          if( node == null ) return;
  303.          node.Processed = true;
  304.          Point3d center = node.Value;
  305.          foreach( Point3d pt in tree.NearestNeighbours( node, dist ) )
  306.          {
  307.             pointPairs.Add( new Point3dPair( center, pt ) );
  308.          }
  309.          ConnectPoints( tree, node.LeftChild, dist, pointPairs );
  310.          ConnectPoints( tree, node.RightChild, dist, pointPairs );
  311.       }
  312.    }
  313. }
  314.  
  315.  
« Last Edit: April 02, 2013, 10:17:43 PM by TT »

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #9 on: April 03, 2013, 07:00:58 AM »
Hi,

First of all, thanks to Tony and Paul for their interest and advice.

Parallelization
It was part of my goals. I tried to develop (in parallel) the same thing in F# in order to use these features but I still encountering some problems.
I tried to parallelize the search part using a ConcurrentBag, it seems to have much overhead so that it slows down significantly the process.

Unnecessary use of ref
Thank you for pointing that. The back and forth between C# and F#, imperative and functional programming, mutable and immutable data have probably helped to make me miss it.

Array.Copy vs Take / Skip
I changed this as Tony's advice.

Array.Sort vs QuickSelect
I tried to implement some other QuickSelect and nth-smallest algorithms. Each time Array.Sort () is (much) faster.
Tony, your OrderedSelect() function is not working here.

Distinct ()
I did not notice any significant speed difference with or without.

Here're somme results with my current but not last version (changes made for point 1, 2 and 3, keeping using Array.Sort() and Distinct()).

100K points
Code: [Select]
Connect using Parallel = False
Get points:      388
Build tree:      718
Connect points:  659
Draw lines:      869
Total:          2634

Connect using Parallel = True
Get points:      374
Build tree:      287
Connect points:  669
Draw lines:      883
Total:          2213
« Last Edit: April 03, 2013, 07:31:53 AM by gile »
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #10 on: April 03, 2013, 10:09:32 AM »

Tony, your OrderedSelect() function is not working here.


Hi Gile, yes, there's a bug in the code I posted.

Code - C#: [Select]
  1.  
  2.      for( int i = 0; i < length; i++ )
  3.           items[i] = new KeyValuePair<double, int>( data[points[i]][d], i );
  4.  
  5. // Should be:
  6.  
  7.      for( int i = 0; i < length; i++ )
  8.      {  
  9.           int index = points[i];
  10.           items[i] = new KeyValuePair<double, int>( data[index][d], index );
  11.      }
  12.  
  13.  

I believe, with that it should work, but I haven't seen much of a difference in the result.

FWIW, the C++ kdtree code I've used in the past uses a bubble-sort, which is probably also not optimal.

[/quote]

TheMaster

  • Guest
Re: Kd Tree
« Reply #11 on: April 03, 2013, 10:22:47 AM »
Here're somme results with my current but not last version (changes made for point 1, 2 and 3, keeping using Array.Sort() and Distinct()).

100K points
Code: [Select]
Connect using Parallel = False
Get points:      388
Build tree:      718
Connect points:  659
Draw lines:      869
Total:          2634

Connect using Parallel = True
Get points:      374
Build tree:      287
Connect points:  669
Draw lines:      883
Total:          2213

Hi Gile.  In your original code, it wasn't possible to break out the time required to build the tree, because the Linq code that opens each DBPoint and gets its position wasn't executing until the constructor of your tree class called Distinct().ToArray() on the IEnumerable<Point3d> argument.

You might notice that I rewrote that part, so that I could measure the tree construction time, without also including the time required to iterate the BlockTableRecord, open  each DBPoint and get its Position.

So, I'm not entirely sure what to make of your numbers because I don't know if you're using the same code from the above post, or code based on your original post, but because of the Linq deferred execution, the times will be very different.

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #12 on: April 03, 2013, 11:02:02 AM »
Quote
So, I'm not entirely sure what to make of your numbers because I don't know if you're using the same code from the above post, or code based on your original post, but because of the Linq deferred execution, the times will be very different.
I changed the code the same way you did: build a Point3d array (Get points) and then corstruvt the tree (Build tree).
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #13 on: April 03, 2013, 12:39:09 PM »

Array.Sort vs QuickSelect
I tried to implement some other QuickSelect and nth-smallest algorithms. Each time Array.Sort () is (much) faster. Tony, your OrderedSelect() function is not working here.

Gile, after applying the bug fix I showed in my other reply, it seems to be working, and I wanted to be sure, so I wrote this quick test, and it also confirms what Paul has suggested (much faster than Array.Sort), but of course, since the time to build the tree is a very small fraction of your overall time, it isn't significant to that, but I would still use it rather than Array.Sort anyways.

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Autodesk.AutoCAD.Runtime;
  7. using Autodesk.AutoCAD.EditorInput;
  8. using System.Diagnostics;
  9. using Autodesk.AutoCAD.ApplicationServices;
  10.  
  11. namespace System
  12. {
  13.  
  14.    public static class QuickSelectExtensions
  15.    {
  16.       /// <summary>
  17.       /// Quick Select algorithm generic implementation
  18.       ///
  19.       /// Rearranges the given array up to the k'th element
  20.       /// (the median element) so that all values to the left
  21.       /// are < the median element. Returns the median element.
  22.       ///
  23.       /// </summary>
  24.      
  25.       public static T QuickSelectMedian<T>( this T[] items, Comparison<T> compare )
  26.       {
  27.          if( items == null )
  28.             throw new ArgumentException( "items" );
  29.          int k = items.Length / 2;
  30.          int from = 0;
  31.          int to = items.Length - 1;
  32.          while( from < to )
  33.          {
  34.             int r = from;
  35.             int w = to;
  36.             T current = items[( r + w ) / 2];
  37.             while( r < w )
  38.             {
  39.                if( compare( items[r], current ) > -1 )
  40.                {
  41.                   var tmp = items[w];
  42.                   items[w] = items[r];
  43.                   items[r] = tmp;
  44.                   w--;
  45.                }
  46.                else
  47.                {
  48.                   r++;
  49.                }
  50.             }
  51.             if( compare( items[r], current ) > 0 )
  52.             {
  53.                r--;
  54.             }
  55.             if( k <= r )
  56.             {
  57.                to = r;
  58.             }
  59.             else
  60.             {
  61.                from = r + 1;
  62.             }
  63.          }
  64.          return items[k];
  65.       }
  66.    }
  67.  
  68.    public static class QSelectVsQuickSortCommands
  69.    {
  70.       /// <summary>
  71.       ///
  72.       /// Compare QuickSort (Array.Sort), with QuickSelectMedian()
  73.       /// on a random array of double[]:
  74.       ///
  75.       /// </summary>
  76.      
  77.       [CommandMethod( "QSELVSQSORT" )]
  78.       public static void QuickSelectVersesQuickSort()
  79.       {
  80.          Editor ed = Application.DocumentManager.MdiActiveDocument.Editor;
  81.  
  82. #if( DEBUG )
  83.          ed.WriteMessage( "\nTesting should NOT be done in DEBUG builds 8)");
  84.          return;
  85. #endif
  86.          int count = 1000000;
  87.          Random rnd = new Random();
  88.          double[] array = new double[count];
  89.          for( int i = 0; i < count; i++ )
  90.             array[i] = rnd.NextDouble();
  91.          double[] array2 = new double[count];
  92.          Array.Copy( array, array2, count );
  93.          Stopwatch sw = Stopwatch.StartNew();
  94.          Array.Sort( array, ( a, b ) => a.CompareTo( b ) );
  95.          sw.Stop();
  96.          Stopwatch sw2 = Stopwatch.StartNew();
  97.          double med = array2.QuickSelectMedian(
  98.             Comparer<double>.Default.Compare );
  99.          sw2.Stop();
  100.          ed.WriteMessage( "\nArray.Sort():        {0,6}",
  101.             sw.ElapsedMilliseconds );
  102.          ed.WriteMessage( "\nQuickSelectMedian(): {0,6}",
  103.             sw2.ElapsedMilliseconds );
  104.       }
  105.  
  106.       /// <summary>
  107.       ///  Verify if QuickSelectMedian() is working:
  108.       ///  
  109.       ///  After returning, all values in the array
  110.       ///  whose indices are below the median value
  111.       ///  must be less than the median value. The
  112.       ///  median value is the value at [length/2]
  113.       ///  
  114.       /// </summary>
  115.       [CommandMethod( "QSELTEST" )]
  116.       public static void Qselect()
  117.       {
  118.          int count = 25;
  119.          Random rnd = new Random();
  120.          double[] array = new double[count];
  121.          for( int i = 0; i < count; i++ )
  122.             array[i] = 100 * rnd.NextDouble();
  123.          WriteLine( "\narray Before QuickSelectMedian():\n" );
  124.          Dump( array );
  125.          double med = array.QuickSelectMedian(
  126.             Comparer<double>.Default.Compare );
  127.          WriteLine( "\nQuickSelectMedian(): [{0}] = {1}\n",
  128.             array.Length / 2, med );
  129.          WriteLine( "\narray After QuickSelectMedian():\n" );
  130.          Dump( array );
  131.       }
  132.  
  133.       public static void WriteLine( string msg, params object[] args )
  134.       {
  135.          Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  136.             "\n" + msg, args );
  137.       }
  138.  
  139.       static void Dump( double[] array )
  140.       {
  141.          int cnt = array.Length;
  142.          for( int i = 0; i < cnt; i++ )
  143.          {
  144.             WriteLine( "array[{0}] = {1:f4}", i, array[i] );
  145.          }
  146.       }
  147.    }
  148. }
  149.  
  150.  


Code - Text: [Select]
  1.  
  2. Command: QSELTEST
  3.  
  4. array Before QuickSelectMedian():
  5.  
  6. array[0] = 99.9556
  7. array[1] = 87.1161
  8. array[2] = 65.3429
  9. array[3] = 95.0263
  10. array[4] = 5.5207
  11. array[5] = 38.1086
  12. array[6] = 82.9302
  13. array[7] = 83.4299
  14. array[8] = 34.1567
  15. array[9] = 46.9312
  16. array[10] = 24.4187
  17. array[11] = 55.2761
  18. array[12] = 9.0754
  19. array[13] = 92.8103
  20. array[14] = 35.5625
  21. array[15] = 88.6875
  22. array[16] = 7.8821
  23. array[17] = 10.9855
  24. array[18] = 69.7092
  25. array[19] = 17.6414
  26. array[20] = 82.9813
  27. array[21] = 12.4893
  28. array[22] = 2.8624
  29. array[23] = 21.3571
  30. array[24] = 74.9873
  31.  
  32. QuickSelectMedian(): [12] = 46.9311974229902
  33.  
  34. array After QuickSelectMedian():
  35.  
  36. array[0] = 2.8624
  37. array[1] = 7.8821
  38. array[2] = 5.5207
  39. array[3] = 21.3571
  40. array[4] = 12.4893
  41. array[5] = 17.6414
  42. array[6] = 10.9855
  43. array[7] = 34.1567
  44. array[8] = 9.0754
  45. array[9] = 24.4187
  46. array[10] = 35.5625
  47. array[11] = 38.1086
  48. array[12] = 46.9312
  49. array[13] = 55.2761
  50. array[14] = 65.3429
  51. array[15] = 69.7092
  52. array[16] = 82.9302
  53. array[17] = 82.9813
  54. array[18] = 74.9873
  55. array[19] = 87.1161
  56. array[20] = 83.4299
  57. array[21] = 88.6875
  58. array[22] = 99.9556
  59. array[23] = 95.0263
  60. array[24] = 92.8103
  61.  
  62.  
  63.  
  64.  
  65. Command: QSELVSQSORT
  66.  
  67. Array.Sort():           379
  68. QuickSelectMedian():     47
  69.  
  70. Command:
  71.  
  72.  
« Last Edit: April 03, 2013, 09:50:05 PM by TT »

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #14 on: April 03, 2013, 02:16:00 PM »
Thanks again, Tony.
I'll give a look and try to understand why my QuickSelect was so slow.
Speaking English as a French Frog