Author Topic: Kd Tree  (Read 20501 times)

0 Members and 1 Guest are viewing this topic.

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #15 on: April 03, 2013, 06:05:14 PM »
Thanks again Tony, I got some improvement using your QuickSelectMedian<T>() method.

Here is where I am:

Benchmark (Release mode for 100K points)
Code: [Select]
Connect using Parallel = False
Get points:      236
Build tree:      203
Connect points:  265
Draw lines:      722
Total:          1426

Connect using Parallel = True
Get points:      244
Build tree:      105
Connect points:  282
Draw lines:      721
Total:          1352

The Point3dTree class.
I added a NearestNeighbour() method to find the point nearest to a given point.
Code - C#: [Select]
  1. using System;
  2. using System.Collections.Concurrent;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Autodesk.AutoCAD.Geometry;
  6.  
  7. namespace PointKdTree
  8. {
  9.     public class KdTreeNode<T>
  10.     {
  11.         public KdTreeNode(T value) { this.Value = value; }
  12.  
  13.         public T Value { get; internal set; }
  14.         public KdTreeNode<T> LeftChild { get; internal set; }
  15.         public KdTreeNode<T> RightChild { get; internal set; }
  16.         public int Depth { get; internal set; }
  17.         public bool Processed { get; set; }
  18.     }
  19.  
  20.     public class Point3dTree
  21.     {
  22.         private int dimension;
  23.         internal static bool useParallel = false;
  24.         public KdTreeNode<Point3d> Root { get; private set; }
  25.  
  26.         public Point3dTree(IEnumerable<Point3d> points) : this(points, false) { }
  27.  
  28.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ)
  29.         {
  30.             if (points == null)
  31.                 throw new ArgumentNullException("points");
  32.             this.dimension = ignoreZ ? 2 : 3;
  33.             Point3d[] pts = points.Distinct().ToArray();
  34.             this.Root = Construct(pts, 0);
  35.         }
  36.  
  37.         public List<Point3d> NearestNeighbours(KdTreeNode<Point3d> node, double radius)
  38.         {
  39.             if (node == null)
  40.                 throw new ArgumentNullException("node");
  41.             List<Point3d> result = new List<Point3d>();
  42.             NearestNeighbours(node.Value, radius, this.Root.LeftChild, result);
  43.             NearestNeighbours(node.Value, radius, this.Root.RightChild, result);
  44.             return result;
  45.         }
  46.  
  47.         private void NearestNeighbours(Point3d center, double radius, KdTreeNode<Point3d> node, List<Point3d> result)
  48.         {
  49.             if (node == null) return;
  50.             Point3d pt = node.Value;
  51.             if (!node.Processed && center.DistanceTo(pt) <= radius)
  52.             {
  53.                 result.Add(pt);
  54.             }
  55.             int d = node.Depth % this.dimension;
  56.             double coordCen = center[d];
  57.             double coordPt = pt[d];
  58.             if (Math.Abs(coordCen - coordPt) > radius)
  59.             {
  60.                 if (coordCen < coordPt)
  61.                 {
  62.                     NearestNeighbours(center, radius, node.LeftChild, result);
  63.                 }
  64.                 else
  65.                 {
  66.                     NearestNeighbours(center, radius, node.RightChild, result);
  67.                 }
  68.             }
  69.             else
  70.             {
  71.                 NearestNeighbours(center, radius, node.LeftChild, result);
  72.                 NearestNeighbours(center, radius, node.RightChild, result);
  73.             }
  74.         }
  75.  
  76.         public Point3d NearestNeighbour(Point3d location)
  77.         {
  78.             return NearestNeighbour(location, this.Root, this.Root.Value, double.MaxValue);
  79.         }
  80.  
  81.         private Point3d NearestNeighbour(Point3d location, KdTreeNode<Point3d> node, Point3d currentBest, double bestDistance)
  82.         {
  83.             if (node == null)
  84.                 return currentBest;
  85.             int dim = node.Depth % this.dimension;
  86.             Point3d nodeLocation = node.Value;
  87.             double distance = location.DistanceTo(nodeLocation);
  88.             if (distance >= 0.0 && distance < bestDistance)
  89.             {
  90.                 currentBest = nodeLocation;
  91.                 bestDistance = distance;
  92.             }
  93.             bool isLeftNearer = location[dim] < nodeLocation[dim];
  94.             KdTreeNode<Point3d> nearestChild = isLeftNearer ? node.LeftChild : node.RightChild;
  95.             if (nearestChild != null)
  96.             {
  97.                 currentBest = NearestNeighbour(location, nearestChild, currentBest, bestDistance);
  98.                 bestDistance = currentBest.DistanceTo(location);
  99.             }
  100.             if (bestDistance > Math.Abs(location[dim] - nodeLocation[dim]))
  101.             {
  102.                 KdTreeNode<Point3d> farestChild = isLeftNearer ? node.RightChild : node.LeftChild;
  103.                 if (farestChild != null)
  104.                 {
  105.                     currentBest = NearestNeighbour(location, farestChild, currentBest, bestDistance);
  106.                     bestDistance = currentBest.DistanceTo(location);
  107.                 }
  108.             }
  109.             return currentBest;
  110.         }
  111.  
  112.         private KdTreeNode<Point3d> Construct(Point3d[] points, int depth)
  113.         {
  114.             int length = points.Length;
  115.             if (length == 0) return null;
  116.             int d = depth % this.dimension;
  117.             Point3d median = QuickSelectMedian(points, (p1, p2) => p1[d].CompareTo(p2[d]));
  118.             KdTreeNode<Point3d> node = new KdTreeNode<Point3d>(median);
  119.             node.Depth = depth;
  120.             int mid = length / 2;
  121.             int rlen = length - mid - 1;
  122.             Point3d[] left = new Point3d[mid];
  123.             Point3d[] right = new Point3d[rlen];
  124.             Array.Copy(points, 0, left, 0, mid);
  125.             Array.Copy(points, mid + 1, right, 0, rlen);
  126.             if (useParallel && depth < 4)
  127.             {
  128.                 System.Threading.Tasks.Parallel.Invoke(
  129.                    () => node.LeftChild = Construct(left, depth + 1),
  130.                    () => node.RightChild = Construct(right, depth + 1)
  131.                 );
  132.             }
  133.             else
  134.             {
  135.                 node.LeftChild = Construct(left, depth + 1);
  136.                 node.RightChild = Construct(right, depth + 1);
  137.             }
  138.             return node;
  139.         }
  140.  
  141.         // From Tony Tanzillo
  142.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  143.         private T QuickSelectMedian<T>(T[] items, Comparison<T> compare)
  144.         {
  145.             int l = items.Length;
  146.             int k = l / 2;
  147.             if (items == null || items.Length == 0)
  148.                 throw new ArgumentException("array");
  149.             int from = 0;
  150.             int to = l - 1;
  151.             while (from < to)
  152.             {
  153.                 int r = from;
  154.                 int w = to;
  155.                 T current = items[(r + w) / 2];
  156.                 while (r < w)
  157.                 {
  158.                     if (compare(items[r], current) > -1)
  159.                     {
  160.                         var tmp = items[w];
  161.                         items[w] = items[r];
  162.                         items[r] = tmp;
  163.                         w--;
  164.                     }
  165.                     else
  166.                     {
  167.                         r++;
  168.                     }
  169.                 }
  170.                 if (compare(items[r], current) > 0)
  171.                 {
  172.                     r--;
  173.                 }
  174.                 if (k <= r)
  175.                 {
  176.                     to = r;
  177.                 }
  178.                 else
  179.                 {
  180.                     from = r + 1;
  181.                 }
  182.             }
  183.             return items[k];
  184.         }
  185.     }
  186. }
  187.  

The command used for the tests
Code - C#: [Select]
  1. // (C) Copyright 2012 by Gilles Chanteau
  2. //
  3. using System.Collections.Generic;
  4. using System.Collections.Concurrent;
  5. using System.Linq;
  6. using Autodesk.AutoCAD.ApplicationServices;
  7. using Autodesk.AutoCAD.DatabaseServices;
  8. using Autodesk.AutoCAD.EditorInput;
  9. using Autodesk.AutoCAD.Geometry;
  10. using Autodesk.AutoCAD.Runtime;
  11. using AcAp = Autodesk.AutoCAD.ApplicationServices.Application;
  12. using System;
  13.  
  14. [assembly: CommandClass(typeof(PointKdTree.CommandMethods))]
  15.  
  16. namespace PointKdTree
  17. {
  18.     public class CommandMethods
  19.     {
  20.         struct Point3dPair
  21.         {
  22.             public readonly Point3d Start;
  23.             public readonly Point3d End;
  24.  
  25.             public Point3dPair(Point3d start, Point3d end)
  26.             {
  27.                 this.Start = start;
  28.                 this.End = end;
  29.             }
  30.         }
  31.  
  32.         // From Tony Tanzillo
  33.         [CommandMethod("CONNECT_PARALLEL")]
  34.         public static void ConnectParallelSwitch()
  35.         {
  36.             Point3dTree.useParallel ^= true;
  37.             Application.DocumentManager.MdiActiveDocument.Editor.WriteMessage(
  38.                "\nConnect using Parallel = {0}", Point3dTree.useParallel);
  39.         }
  40.  
  41.         [CommandMethod("CONNECT")]
  42.         public void Connect()
  43.         {
  44.             System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  45.             sw.Start();
  46.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  47.             Database db = doc.Database;
  48.             Editor ed = doc.Editor;
  49.             RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  50.             using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  51.             {
  52.                 long t0 = sw.ElapsedMilliseconds;
  53.                 ObjectId[] ids = btr.Cast<ObjectId>().Where(id => id.ObjectClass == rxc).ToArray();
  54.                 int len = ids.Length;
  55.                 Point3d[] pts = new Point3d[len];
  56.                 for (int i = 0; i < len; i++)
  57.                 {
  58.                     using (DBPoint pt = (DBPoint)ids[i].Open(OpenMode.ForRead))
  59.                     {
  60.                         pts[i] = pt.Position;
  61.                     }
  62.                 }
  63.                 long t1 = sw.ElapsedMilliseconds;
  64.                 Point3dTree tree = new Point3dTree(pts, true);
  65.                 long t2 = sw.ElapsedMilliseconds;
  66.                 List<Point3dPair> pairs = new List<Point3dPair>();
  67.                 ConnectPoints(tree, tree.Root, 70.0, pairs);
  68.                 long t3 = sw.ElapsedMilliseconds;
  69.                 foreach (Point3dPair pair in pairs)
  70.                 {
  71.                     using (Line line = new Line(pair.Start, pair.End))
  72.                     {
  73.                         btr.AppendEntity(line);
  74.                     }
  75.                 }
  76.                 sw.Stop();
  77.                 long t4 = sw.ElapsedMilliseconds;
  78.                 ed.WriteMessage("\nConnect using Parallel = {0}", Point3dTree.useParallel);
  79.                 ed.WriteMessage(
  80.                     "\nGet points:{0,9}\nBuild tree:{1,9}\nConnect points:{2,5}\nDraw lines:{3,9}\nTotal:{4,14}",
  81.                     t1 - t0, t2 - t1, t3 - t2, t4 - t3, t4);
  82.             }
  83.         }
  84.  
  85.         private void ConnectPoints(Point3dTree tree, KdTreeNode<Point3d> node, double dist, List<Point3dPair> pointPairs)
  86.         {
  87.             if (node == null) return;
  88.             node.Processed = true;
  89.             Point3d center = node.Value;
  90.             foreach (Point3d pt in tree.NearestNeighbours(node, dist))
  91.             {
  92.                 pointPairs.Add(new Point3dPair(center, pt));
  93.             }
  94.             ConnectPoints(tree, node.LeftChild, dist, pointPairs);
  95.             ConnectPoints(tree, node.RightChild, dist, pointPairs);
  96.         }
  97.     }
  98. }
  99.  
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #16 on: April 03, 2013, 07:42:37 PM »
Hi Gile - Great job.

Unless you're running on a box with 8 physical CPU cores, change the test to run in parallel only if depth < 3 for a quad-core system. You could also check the number of physical CPU cores at runtime and adjust that (for example, on a 2-core system, use Parallel.Invoke() only if depth < 2).

TheMaster

  • Guest
Re: Kd Tree
« Reply #17 on: April 07, 2013, 11:25:11 AM »
Hi Gile - Great job.

Unless you're running on a box with 8 physical CPU cores, change the test to run in parallel only if depth < 3 for a quad-core system. You could also check the number of physical CPU cores at runtime and adjust that (for example, on a 2-core system, use Parallel.Invoke() only if depth < 2).

As it turns out, getting the number of physical cores isn't trivial.

Code - Text: [Select]
  1.  
  2. /// <summary>
  3. /// Queries the system for the number of physical
  4. /// CPU cores and logical processors. Requires a
  5. /// reference to System.Management.dll
  6. ///
  7. /// Credit: http://www.stev.org/post/2011/10/27/C-Number-Of-Cores-and-Processors.aspx
  8. ///
  9. /// </summary>
  10.  
  11. public static class ThreadSystemInfo
  12. {
  13.     static ThreadSystemInfo()
  14.     {
  15.         try
  16.         {
  17.             ObjectQuery wql = new ObjectQuery( "SELECT * FROM Win32_Processor" );
  18.             using( var searcher = new ManagementObjectSearcher( wql ) )
  19.             using( var results = searcher.Get() )
  20.             {
  21.                 if( results != null )
  22.                 {
  23.                     var items = results.Cast<ManagementObject>();
  24.                     if( items.Any() )
  25.                     {
  26.                         var item = items.First();
  27.                         cores = int.Parse( item["NumberOfCores"].ToString() );
  28.                         logicalProcessors = int.Parse( item["NumberOfLogicalProcessors"].ToString() );
  29.                     }
  30.                 }
  31.             }
  32.         }
  33.         catch // above can fail on some systems, so kludge it.
  34.         {
  35.             logicalProcessors = System.Environment.ProcessorCount;
  36.             if( logicalProcessors > 1 )
  37.                 cores = logicalProcessors / 2;
  38.         }
  39.     }
  40.  
  41.     private static int cores = 1;
  42.     public static int PhysicalCoreCount
  43.     {
  44.         get
  45.         {
  46.             return cores;
  47.         }
  48.     }
  49.  
  50.     private static int logicalProcessors = 1;
  51.     public static int LogicalProcessorCount
  52.     {
  53.         get
  54.         {
  55.             return logicalProcessors;
  56.         }
  57.     }
  58. }
  59.  
  60.  
« Last Edit: April 07, 2013, 11:45:21 AM by TT »

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #18 on: April 08, 2013, 04:19:01 AM »
Quote
Hi Gile - Great job.
"With a Little Help from My Friends", thanks.

For now I care less about what is specific to the "connect points challenge" and I focus more on implementing a class with more useful methods.
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #19 on: April 08, 2013, 07:15:20 AM »
Quote
Hi Gile - Great job.
"With a Little Help from My Friends", thanks.

For now I care less about what is specific to the "connect points challenge" and I focus more on implementing a class with more useful methods.

In case anyone wants to know more about the concepts of parallelism and its use in recursive solutions to various programming problems, like that which was used in this thread, this is highly-recommended reading:
 
   http://blogs.msdn.com/b/pfxteam/archive/2008/01/31/7357135.aspx



Jeff H

  • Needs a day job
  • Posts: 6144
Re: Kd Tree
« Reply #20 on: April 09, 2013, 05:55:51 PM »
Hi Gile,
I am coming down with something and hopefully will not get too sick, but if you like I could send you more info tomorrow or the next day, when I am feeling better, if you want to analyze it.
 
 

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #21 on: April 12, 2013, 12:09:18 PM »
Thanks Jeff, I hope you feel better.

I added some public methods to the Point3dTree class.

This class may be used to improve performance when many queries are run against a fairly large set of points.
It targets .NET Framework 4.0 in order to use some parallelization features.

According to the value of the 'ignoreZ' constructor argument, the class builds a 2d tree (ignoreZ = true) or a 3d tree (ignoreZ = false, default).
Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY or if the points have to be considered as projected on the XY plane.

Public methods:
NearestNeighbour(Point3d) Gets the nearest neighbour.
NearestNeighbours(Point3d, int) Gets the n nearest neighbours.
NearestNeighbours(Point3d, double) Gets the nearest neighbours within the distance.
BoxedRange(Point3d, Point3d) Gets the points in a range.
ConnectAll(double) Gets all the pairs of points whose distance is less than or equal to the specified distance.
The last one was mostly used for performance tests (and to reply to this challenge).

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     /// <summary>
  9.     /// Node of a Point3dTree
  10.     /// </summary>
  11.     public class TreeNode
  12.     {
  13.         /// <summary>
  14.         /// Creates a new instance of TreeNode
  15.         /// </summary>
  16.         /// <param name="value">The 3d point value of the node.</param>
  17.         public TreeNode(Point3d value) { this.Value = value; }
  18.  
  19.         /// <summary>
  20.         /// Gets the value (Point3d) of the node.
  21.         /// </summary>
  22.         public Point3d Value { get; internal set; }
  23.  
  24.         /// <summary>
  25.         /// Gets the parent node.
  26.         /// </summary>
  27.         public TreeNode Parent { get; internal set; }
  28.  
  29.         /// <summary>
  30.         /// Gets the left child node.
  31.         /// </summary>
  32.         public TreeNode LeftChild { get; internal set; }
  33.  
  34.         /// <summary>
  35.         /// Gets the right child node.
  36.         /// </summary>
  37.         public TreeNode RightChild { get; internal set; }
  38.  
  39.         /// <summary>
  40.         /// Gets the depth of the node in tree.
  41.         /// </summary>
  42.         public int Depth { get; internal set; }
  43.     }
  44.  
  45.     /// <summary>
  46.     /// Provides methods to organize 3d points in a Kd tree structure to speed up the search of neighbours.
  47.     /// A boolean constructor parameter (ignoreZ) indicates if the resulting Kd tree is a 3d tree or a 2d tree.
  48.     /// Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY
  49.     /// or if the points have to be considered as projected on the XY plane.
  50.     /// </summary>
  51.     public class Point3dTree
  52.     {
  53.         #region Private fields
  54.  
  55.         private int dimension;
  56.         private int parallelDepth;
  57.         private bool ignoreZ;
  58.         private Point2d center2d;
  59.         private Point3d center;
  60.         private Point3d current;
  61.         private Func<Point3d, double> distanceTo;
  62.  
  63.         #endregion
  64.  
  65.         #region Constructor
  66.  
  67.         /// <summary>
  68.         /// Creates an new instance of Point3dTree.
  69.         /// </summary>
  70.         /// <param name="points">The Point3d collection to fill the tree.</param>
  71.         /// <param name="ignoreZ">A value indicating if the Z coordinate of points is ignored
  72.         /// (as if all points were projected to the XY plane).</param>
  73.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ = false)
  74.         {
  75.             if (points == null)
  76.                 throw new ArgumentNullException("points");
  77.             this.ignoreZ = ignoreZ;
  78.             this.dimension = ignoreZ ? 2 : 3;
  79.             int numProc = System.Environment.ProcessorCount;
  80.             this.parallelDepth = -1;
  81.             while (numProc >> ++this.parallelDepth > 1) ;
  82.             Point3d[] pts = points.Distinct().ToArray();
  83.             this.Root = Construct(pts, 0, null);
  84.         }
  85.  
  86.         #endregion
  87.  
  88.         #region Public properties
  89.  
  90.         /// <summary>
  91.         /// Gets the root node of the tree.
  92.         /// </summary>
  93.         public TreeNode Root { get; private set; }
  94.  
  95.         #endregion
  96.  
  97.         #region Public methods
  98.  
  99.         /// <summary>
  100.         /// Gets the nearest neighbour.
  101.         /// </summary>
  102.         /// <param name="point">The point from which search the nearest neighbour.</param>
  103.         /// <returns>The nearest point in the collection from the specified one.</returns>
  104.         public Point3d NearestNeighbour(Point3d point)
  105.         {
  106.             this.center = point;
  107.             this.current = this.Root.Value;
  108.             if (ignoreZ)
  109.             {
  110.                 center2d = point.Flatten();
  111.                 this.distanceTo = Distance2dTo;
  112.             }
  113.             else
  114.             {
  115.                 this.distanceTo = Distance3dTo;
  116.             }
  117.             return GetNeighbour(this.Root, this.Root.Value, double.MaxValue);
  118.         }
  119.  
  120.         /// <summary>
  121.         /// Gets the neighbours within the specified distance.
  122.         /// </summary>
  123.         /// <param name="point">The point from which search the nearest neighbours.</param>
  124.         /// <param name="radius">The distance in which collect the neighbours.</param>
  125.         /// <returns>The points which distance from the specified point is less or equal to the specified distance.</returns>
  126.         public Point3dCollection NearestNeighbours(Point3d point, double radius)
  127.         {
  128.             Point3dCollection points = new Point3dCollection();
  129.             this.center = point;
  130.             if (this.ignoreZ)
  131.             {
  132.                 this.center2d = point.Flatten();
  133.                 this.distanceTo = Distance2dTo;
  134.             }
  135.             else
  136.             {
  137.                 this.distanceTo = Distance3dTo;
  138.             }
  139.             GetNeighboursAtDistance(radius, this.Root.LeftChild, points);
  140.             GetNeighboursAtDistance(radius, this.Root.RightChild, points);
  141.             return points;
  142.         }
  143.  
  144.         /// <summary>
  145.         /// Gets the n nearest neighbours.
  146.         /// </summary>
  147.         /// <param name="point">The point from which search the nearest neighbours.</param>
  148.         /// <param name="number">The number of points to collect.</param>
  149.         /// <returns>The n nearest neighbours of the specified point.</returns>
  150.         public Point3dCollection NearestNeighbours(Point3d point, int number)
  151.         {
  152.             List<Tuple<double, Point3d>> pairs = new List<Tuple<double, Point3d>>(number);
  153.             this.center = point;
  154.             if (ignoreZ)
  155.             {
  156.                 this.center2d = point.Flatten();
  157.                 this.distanceTo = Distance2dTo;
  158.             }
  159.             else
  160.             {
  161.                 this.distanceTo = Distance3dTo;
  162.             }
  163.             GetNNeighbours(number, double.MaxValue, this.Root, pairs);
  164.             Point3dCollection points = new Point3dCollection();
  165.             for (int i = 0; i < pairs.Count; i++)
  166.             {
  167.                 points.Add(pairs[i].Item2);
  168.             }
  169.             return points;
  170.         }
  171.  
  172.         /// <summary>
  173.         /// Gets the points in a range.
  174.         /// </summary>
  175.         /// <param name="pt1">The first corner of range.</param>
  176.         /// <param name="pt2">The opposite corner of the range.</param>
  177.         /// <returns>All points within the box.</returns>
  178.         public Point3dCollection BoxedRange(Point3d pt1, Point3d pt2)
  179.         {
  180.             Point3d lowerLeft = new Point3d(
  181.                 Math.Min(pt1.X, pt2.X), Math.Min(pt1.Y, pt2.Y), Math.Min(pt1.Z, pt2.Z));
  182.             Point3d upperRight = new Point3d(
  183.                 Math.Max(pt1.X, pt2.X), Math.Max(pt1.Y, pt2.Y), Math.Max(pt1.Z, pt2.Z));
  184.             Point3dCollection points = new Point3dCollection();
  185.             FindRange(lowerLeft, upperRight, this.Root, points);
  186.             return points;
  187.         }
  188.  
  189.         /// <summary>
  190.         /// Gets all the pairs of points which distance is less or equal than the specified distance.
  191.         /// </summary>
  192.         /// <param name="radius">The maximum distance between two points. </param>
  193.         /// <returns>The pairs of points which distance is less or equal than the specified distance.</returns>
  194.         public List<Tuple<Point3d, Point3d>> ConnectAll(double radius)
  195.         {
  196.             List<Tuple<Point3d, Point3d>> connexions = new List<Tuple<Point3d, Point3d>>();
  197.             Stack<TreeNode> nodes = new Stack<TreeNode>();
  198.             GetConnexions(this.Root, radius, connexions, nodes);
  199.             return connexions;
  200.         }
  201.  
  202.         #endregion
  203.  
  204.         #region Private methods
  205.  
  206.         private TreeNode Construct(Point3d[] points, int depth, TreeNode parent)
  207.         {
  208.             int length = points.Length;
  209.             if (length == 0) return null;
  210.             int d = depth % this.dimension;
  211.             Point3d median = points.QuickSelectMedian((p1, p2) => p1[d].CompareTo(p2[d]));
  212.             TreeNode node = new TreeNode(median);
  213.             node.Depth = depth;
  214.             node.Parent = parent;
  215.             int mid = length / 2;
  216.             int rlen = length - mid - 1;
  217.             Point3d[] left = new Point3d[mid];
  218.             Point3d[] right = new Point3d[rlen];
  219.             Array.Copy(points, 0, left, 0, mid);
  220.             Array.Copy(points, mid + 1, right, 0, rlen);
  221.             if (depth < this.parallelDepth)
  222.             {
  223.                 System.Threading.Tasks.Parallel.Invoke(
  224.                    () => node.LeftChild = Construct(left, depth + 1, node),
  225.                    () => node.RightChild = Construct(right, depth + 1, node)
  226.                 );
  227.             }
  228.             else
  229.             {
  230.                 node.LeftChild = Construct(left, depth + 1, node);
  231.                 node.RightChild = Construct(right, depth + 1, node);
  232.             }
  233.             return node;
  234.         }
  235.  
  236.         private Point3d GetNeighbour(TreeNode node, Point3d currentBest, double bestDist)
  237.         {
  238.             if (node == null)
  239.                 return currentBest;
  240.             this.current = node.Value;
  241.             int d = node.Depth % this.dimension;
  242.             double coordCen = center[d];
  243.             double coordCur = current[d];
  244.             double dist = this.distanceTo(this.current);
  245.             if (dist >= 0.0 && dist < bestDist)
  246.             {
  247.                 currentBest = this.current;
  248.                 bestDist = dist;
  249.             }
  250.             if (bestDist >= Math.Abs(coordCen - coordCur))
  251.             {
  252.                 currentBest = GetNeighbour(
  253.                     coordCen < coordCur ? node.LeftChild : node.RightChild, currentBest, bestDist);
  254.                 bestDist = this.distanceTo(currentBest);
  255.             }
  256.             else
  257.             {
  258.                 currentBest = GetNeighbour(node.LeftChild, currentBest, bestDist);
  259.                 bestDist = this.distanceTo(currentBest);
  260.                 currentBest = GetNeighbour(node.RightChild, currentBest, bestDist);
  261.                 bestDist = this.distanceTo(currentBest);
  262.             }
  263.             return currentBest;
  264.         }
  265.  
  266.         private void GetNeighboursAtDistance(double radius, TreeNode node, Point3dCollection points)
  267.         {
  268.             if (node == null) return;
  269.             this.current = node.Value;
  270.             double dist = this.distanceTo(this.current);
  271.             if (dist <= radius)
  272.             {
  273.                 points.Add(this.current);
  274.             }
  275.             int d = node.Depth % this.dimension;
  276.             double coordCen = this.center[d];
  277.             double coordCur = this.current[d];
  278.             if (Math.Abs(coordCen - coordCur) > radius)
  279.             {
  280.                 if (coordCen < coordCur)
  281.                 {
  282.                     GetNeighboursAtDistance(radius, node.LeftChild, points);
  283.                 }
  284.                 else
  285.                 {
  286.                     GetNeighboursAtDistance(radius, node.RightChild, points);
  287.                 }
  288.             }
  289.             else
  290.             {
  291.                 GetNeighboursAtDistance(radius, node.LeftChild, points);
  292.                 GetNeighboursAtDistance(radius, node.RightChild, points);
  293.             }
  294.         }
  295.  
  296.         private void GetNNeighbours(int number, double worstDist, TreeNode node, List<Tuple<double, Point3d>> pairs)
  297.         {
  298.             if (node == null) return;
  299.             this.current = node.Value;
  300.             double dist = this.distanceTo(this.current);
  301.             int cnt = pairs.Count;
  302.             if (cnt < number)
  303.             {
  304.                 pairs.Add(new Tuple<double, Point3d>(dist, this.current));
  305.                 pairs.Sort((p1, p2) => p1.Item1.CompareTo(p2.Item1));
  306.                 worstDist = pairs[cnt].Item1;
  307.             }
  308.             else if (dist < worstDist)
  309.             {
  310.                 pairs.RemoveAt(number - 1);
  311.                 pairs.Add(new Tuple<double, Point3d>(dist, this.current));
  312.                 pairs.Sort((p1, p2) => p1.Item1.CompareTo(p2.Item1));
  313.                 worstDist = pairs[number - 1].Item1;
  314.             }
  315.             int d = node.Depth % this.dimension;
  316.             double coordCen = center[d];
  317.             double coordCur = current[d];
  318.             if (Math.Abs(coordCen - coordCur) > worstDist)
  319.             {
  320.                 if (coordCen < coordCur)
  321.                 {
  322.                     GetNNeighbours(number, worstDist, node.LeftChild, pairs);
  323.                 }
  324.                 else
  325.                 {
  326.                     GetNNeighbours(number, worstDist, node.RightChild, pairs);
  327.                 }
  328.             }
  329.             else
  330.             {
  331.                 GetNNeighbours(number, worstDist, node.LeftChild, pairs);
  332.                 GetNNeighbours(number, pairs[pairs.Count - 1].Item1, node.RightChild, pairs);
  333.             }
  334.         }
  335.  
  336.         private void FindRange(Point3d lowerLeft, Point3d upperRight, TreeNode node, Point3dCollection points)
  337.         {
  338.             if (node == null)
  339.                 return;
  340.             this.current = node.Value;
  341.             if (ignoreZ)
  342.             {
  343.                 if (this.current.X >= lowerLeft.X && this.current.X <= upperRight.X &&
  344.                     this.current.Y >= lowerLeft.Y && this.current.Y <= upperRight.Y)
  345.                     points.Add(this.current);
  346.             }
  347.             else
  348.             {
  349.                 if (this.current.X >= lowerLeft.X && this.current.X <= upperRight.X &&
  350.                     this.current.Y >= lowerLeft.Y && this.current.Y <= upperRight.Y &&
  351.                     this.current.Z >= lowerLeft.Z && this.current.Z <= upperRight.Z)
  352.                     points.Add(this.current);
  353.             }
  354.             int d = node.Depth % this.dimension;
  355.             if (upperRight[d] < this.current[d])
  356.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  357.             else if (lowerLeft[d] > this.current[d])
  358.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  359.             else
  360.             {
  361.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  362.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  363.             }
  364.         }
  365.  
  366.         private void GetConnexions(TreeNode node, double radius, List<Tuple<Point3d, Point3d>> connexions, Stack<TreeNode> nodes)
  367.         {
  368.             if (node == null) return;
  369.             Point3dCollection points = new Point3dCollection();
  370.             this.center = node.Value;
  371.             if (ignoreZ)
  372.             {
  373.                 this.center2d = this.center.Flatten();
  374.                 this.distanceTo = Distance2dTo;
  375.             }
  376.             else
  377.             {
  378.                 this.distanceTo = Distance3dTo;
  379.             }
  380.             foreach (TreeNode tn in nodes)
  381.             {
  382.                 TreeNode parent = tn.Parent;
  383.                 int d = parent.Depth % this.dimension;
  384.                 if (Math.Abs(this.center[d] - parent.Value[d]) <= radius)
  385.                 {
  386.                     GetNeighboursAtDistance(radius, tn, points);
  387.                 }
  388.             }
  389.             GetNeighboursAtDistance(radius, node.LeftChild, points);
  390.             GetNeighboursAtDistance(radius, node.RightChild, points);
  391.             for (int i = 0; i < points.Count; i++)
  392.             {
  393.                 connexions.Add(new Tuple<Point3d, Point3d>(this.center, points[i]));
  394.             }
  395.  
  396.             if (node.RightChild != null)
  397.             {
  398.                 nodes.Push(node.RightChild);
  399.             }
  400.             else if (node.LeftChild == null && nodes.Count > 0)
  401.             {
  402.                 nodes.Pop();
  403.             }
  404.             GetConnexions(node.LeftChild, radius, connexions, nodes);
  405.             GetConnexions(node.RightChild, radius, connexions, nodes);
  406.         }
  407.  
  408.         private double Distance2dTo(Point3d pt)
  409.         {
  410.             return this.center2d.GetDistanceTo(pt.Flatten());
  411.         }
  412.  
  413.         private double Distance3dTo(Point3d pt)
  414.         {
  415.             return this.center.DistanceTo(pt);
  416.         }
  417.  
  418.         #endregion
  419.     }
  420.  
  421.     static class Extensions
  422.     {
  423.         public static Point2d Flatten(this Point3d pt)
  424.         {
  425.             return new Point2d(pt.X, pt.Y);
  426.         }
  427.  
  428.         // Quickselect: partially reorders 'items' in place and returns the element left at index Length / 2. Credit: Tony Tanzillo
  429.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  430.         public static T QuickSelectMedian<T>(this T[] items, Comparison<T> compare)
  431.         {
  432.             if (items == null || items.Length == 0) // the null test has to come before any other use of 'items'
  433.                 throw new ArgumentException("array");
  434.             int l = items.Length;
  435.             int k = l / 2; // index of the element that is returned
  436.             int from = 0;
  437.             int to = l - 1;
  438.             while (from < to) // shrink the window [from, to] around k
  439.             {
  440.                 int r = from;
  441.                 int w = to;
  442.                 T current = items[(r + w) / 2]; // pivot value taken from the middle of the window
  443.                 while (r < w)
  444.                 {
  445.                     if (compare(items[r], current) > -1) // items[r] is not smaller than the pivot: send it to the right end
  446.                     {
  447.                         var tmp = items[w];
  448.                         items[w] = items[r];
  449.                         items[r] = tmp;
  450.                         w--;
  451.                     }
  452.                     else
  453.                     {
  454.                         r++;
  455.                     }
  456.                 }
  457.                 if (compare(items[r], current) > 0)
  458.                 {
  459.                     r--;
  460.                 }
  461.                 if (k <= r) // keep the half of the window that contains k
  462.                 {
  463.                     to = r;
  464.                 }
  465.                 else
  466.                 {
  467.                     from = r + 1;
  468.                 }
  469.             }
  470.             return items[k];
  471.         }
  472.     }
  473. }
  474.  
« Last Edit: April 13, 2013, 05:41:57 PM by gile »
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #22 on: April 12, 2013, 03:09:57 PM »
The last one was almost used for performance tests (and to reply to this challenge).

I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?

I also noticed in one post, that someone is using (getvar "DATE") to measure times - which is not very accurate (for sub-second times).

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #23 on: April 12, 2013, 04:12:46 PM »
I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?
I don't think so; this thread is quite old and the whole C/C++ part was/is over my head. I remember there were impressive LISP results from Evgeniy.
Speaking English as a French Frog

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #24 on: April 13, 2013, 05:32:02 PM »
I changed the way duplicated lines are avoided in ConnectAll. Rather than a boolean 'Processed' property in the TreeNode, I now use a Stack<TreeNode> containing the right children of the parents of the current node, in which to search (or not, if far enough) for neighbours. This slightly improves the speed of the 'connect points' part.
The tree is now doubly linked.

Some results:
Code - Text: [Select]
  1. 100K points (128 026 lines)
  2. Get points:      220
  3. Build tree:      114
  4. Connect points:  199
  5. Draw lines:      764
  6. Total:          1297
  7.  
  8. 1 Million points (1 280 741 lines)
  9. Get points:     2118
  10. Build tree:     1388
  11. Connect points: 2107
  12. Draw lines:    12046
  13. Total:         17659

The testing command:
Code - C#: [Select]
  1.         [CommandMethod("CONNECT")]
  2.         public void Connect()
  3.         {
  4.             Document doc = AcAp.DocumentManager.MdiActiveDocument;
  5.             Database db = doc.Database;
  6.             Editor ed = doc.Editor;
  7.             try
  8.             {
  9.                 System.Diagnostics.Stopwatch sw = new System.Diagnostics.Stopwatch();
  10.                 sw.Start();
  11.                 RXClass rxc = RXClass.GetClass(typeof(DBPoint));
  12.                 using (BlockTableRecord btr = (BlockTableRecord)db.CurrentSpaceId.Open(OpenMode.ForWrite))
  13.                 {
  14.                     long t0 = sw.ElapsedMilliseconds;
  15.                     ObjectId[] ids = btr.Cast<ObjectId>().Where(id => id.ObjectClass == rxc).ToArray();
  16.                     int len = ids.Length;
  17.                     Point3d[] pts = new Point3d[len];
  18.                     for (int i = 0; i < len; i++)
  19.                     {
  20.                         using (DBPoint pt = (DBPoint)ids[i].Open(OpenMode.ForRead))
  21.                         {
  22.                             pts[i] = pt.Position;
  23.                         }
  24.                     }
  25.                     long t1 = sw.ElapsedMilliseconds;
  26.                     Point3dTree tree = new Point3dTree(pts, true); // <- true = 2d tree, false = 3d tree
  27.                     long t2 = sw.ElapsedMilliseconds;
  28.                     List<Tuple<Point3d, Point3d>> pairs = tree.ConnectAll(70.0);
  29.                     long t3 = sw.ElapsedMilliseconds;
  30.                     foreach (var pair in pairs)
  31.                     {
  32.                         using (Line line = new Line(pair.Item1, pair.Item2))
  33.                         {
  34.                             btr.AppendEntity(line);
  35.                         }
  36.                     }
  37.                     db.TransactionManager.QueueForGraphicsFlush();
  38.                     sw.Stop();
  39.                     long t4 = sw.ElapsedMilliseconds;
  40.                     ed.WriteMessage(
  41.                         "\nGet points:{0,9}\nBuild tree:{1,9}\nConnect points:{2,5}\nDraw lines:{3,9}\nTotal:{4,14}",
  42.                         t1 - t0, t2 - t1, t3 - t2, t4 - t3, t4);
  43.                 }
  44.             }
  45.             catch (System.Exception ex)
  46.             {
  47.                 ed.WriteMessage("\n{0}\n{1}", ex.Message, ex.StackTrace);
  48.             }
  49.         }
« Last Edit: April 13, 2013, 05:36:49 PM by gile »
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #25 on: April 14, 2013, 01:25:23 PM »
I read that thread, but has anyone compared your implementation to the C/C++ based ones from that thread?
I don't think so; this thread is quite old and the whole C/C++ part was/is over my head. I remember there were impressive LISP results from Evgeniy.

I would expect that once a dataset is indexed, locating points within a given distance should not be too difficult in any language, since the whole point to indexing the data is to minimize the amount of work needed to find nearest coordinates.

LISP is well-suited to tree-based, recursive algorithms, so it wouldn't surprise me that it can find points in an indexed dataset quickly. But I would be surprised if he was able to construct the tree as fast as it can be done in .NET or native code. I very much doubt that, if for no other reason than the inherent overhead of dealing with immutable lists.
« Last Edit: April 14, 2013, 01:34:20 PM by TT »

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #26 on: April 14, 2013, 01:52:17 PM »
Hi Tony,

It seems to me he didn't use a tree structure but a 'divide and conquer' algorithm as shown by this picture.
For my part, I tried to implement a more reusable (and maybe useful) structure.

I found another way to avoid duplicated connexions with the ConnectAll() method. Rather than storing nodes in a stack, I can get the 'right parent nodes' of the current node using a new property in the TreeNode which indicates whether the node is a left child node or not. This seems to slightly improve ConnectAll() speed.

Here's the new implementation.

New version:
- Replaced the global fields with function parameters so that the public methods can be used in parallel execution.
- Use of squared distances

Code - C#: [Select]
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Autodesk.AutoCAD.Geometry;
  5.  
  6. namespace PointKdTree
  7. {
  8.     /// <summary>
  9.     /// Node of a Point3dTree
  10.     /// </summary>
  11.     public class TreeNode
  12.     {
  13.         /// <summary>
  14.         /// Creates a new instance of TreeNode
  15.         /// </summary>
  16.         /// <param name="value">The 3d point value of the node.</param>
  17.         public TreeNode(Point3d value) { this.Value = value; }
  18.  
  19.         /// <summary>
  20.         /// Gets the value (Point3d) of the node.
  21.         /// </summary>
  22.         public Point3d Value { get; internal set; }
  23.  
  24.         /// <summary>
  25.         /// Gets the parent node.
  26.         /// </summary>
  27.         public TreeNode Parent { get; internal set; }
  28.  
  29.         /// <summary>
  30.         /// Gets the left child node.
  31.         /// </summary>
  32.         public TreeNode LeftChild { get; internal set; }
  33.  
  34.         /// <summary>
  35.         /// Gets the right child node.
  36.         /// </summary>
  37.         public TreeNode RightChild { get; internal set; }
  38.  
  39.         /// <summary>
  40.         /// Gets the depth of the node in tree.
  41.         /// </summary>
  42.         public int Depth { get; internal set; }
  43.  
  44.         /// <summary>
  45.         /// Gets a value indicating if the current node is a LeftChild node.
  46.         /// </summary>
  47.         public bool IsLeft { get; internal set; }
  48.     }
  49.  
  50.     /// <summary>
  51.     /// Provides methods to organize 3d points in a Kd tree structure to speed up the search of neighbours.
  52.     /// A boolean constructor parameter (ignoreZ) indicates if the resulting Kd tree is a 3d tree or a 2d tree.
  53.     /// Use ignoreZ = true if all points in the input collection lie on a plane parallel to XY
  54.     /// or if the points have to be considered as projected on the XY plane.
  55.     /// </summary>
  56.     public class Point3dTree
  57.     {
  58.         #region Private fields
  59.  
  60.         private int dimension;
  61.         private int parallelDepth;
  62.         private bool ignoreZ;
  63.         private Func<Point3d, Point3d, double> sqrDist;
  64.  
  65.         #endregion
  66.  
  67.         #region Constructor
  68.  
  69.         /// <summary>
  70.         /// Creates a new instance of Point3dTree.
  71.         /// </summary>
  72.         /// <param name="points">The Point3d collection to fill the tree.</param>
  73.         /// <param name="ignoreZ">A value indicating if the Z coordinate of points is ignored
  74.         /// (as if all points were projected to the XY plane).</param>
  75.         public Point3dTree(IEnumerable<Point3d> points, bool ignoreZ = false)
  76.         {
  77.             if (points == null)
  78.                 throw new ArgumentNullException("points");
  79.             this.ignoreZ = ignoreZ;
  80.             this.dimension = ignoreZ ? 2 : 3; // number of axes the tree cycles through (X, Y and optionally Z)
  81.             if (ignoreZ)
  82.                 this.sqrDist = SqrDistance2d; // squared distance measured in XY only
  83.             else
  84.                 this.sqrDist = SqrDistance3d;
  85.             int numProc = System.Environment.ProcessorCount;
  86.             this.parallelDepth = -1;
  87.             while (numProc >> ++this.parallelDepth > 1) ; // leaves parallelDepth = floor(log2(numProc)); Create() uses Parallel.Invoke above that depth
  88.             Point3d[] pts = points.Distinct().ToArray(); // duplicated points are dropped before the tree is built
  89.             this.Root = Create(pts, 0, null, false); // Root stays null when 'points' is empty
  90.         }
  91.  
  92.         #endregion
  93.  
  94.         #region Public properties
  95.  
  96.         /// <summary>
  97.         /// Gets the root node of the tree.
  98.         /// </summary>
  99.         public TreeNode Root { get; private set; }
  100.  
  101.         #endregion
  102.  
  103.         #region Public methods
  104.  
  105.         /// <summary>
  106.         /// Gets the nearest neighbour.
  107.         /// </summary>
  108.         /// <param name="point">The point from which search the nearest neighbour.</param>
  109.         /// <returns>The nearest point in the collection from the specified one.</returns>
  110.         public Point3d NearestNeighbour(Point3d point)
  111.         {
  112.             return GetNeighbour(point, this.Root, this.Root.Value, double.MaxValue);
  113.         }
  114.  
  115.         /// <summary>
  116.         /// Gets the neighbours within the specified distance.
  117.         /// </summary>
  118.         /// <param name="point">The point from which search the nearest neighbours.</param>
  119.         /// <param name="radius">The distance in which collect the neighbours.</param>
  120.         /// <returns>The points which distance from the specified point is less or equal to the specified distance.</returns>
  121.         public Point3dCollection NearestNeighbours(Point3d point, double radius)
  122.         {
  123.             Point3dCollection points = new Point3dCollection();
  124.             GetNeighboursAtDistance(point, radius * radius, this.Root, points);
  125.             return points;
  126.         }
  127.  
  128.         /// <summary>
  129.         /// Gets the given number of nearest neighbours.
  130.         /// </summary>
  131.         /// <param name="point">The point from which search the nearest neighbours.</param>
  132.         /// <param name="number">The number of points to collect.</param>
  133.         /// <returns>The n nearest neighbours of the specified point.</returns>
  134.         public Point3dCollection NearestNeighbours(Point3d point, int number)
  135.         {
  136.             List<Tuple<double, Point3d>> pairs = new List<Tuple<double, Point3d>>(number);
  137.             GetKNeighbours(point, number, this.Root, pairs);
  138.             Point3dCollection points = new Point3dCollection();
  139.             for (int i = 0; i < pairs.Count; i++)
  140.             {
  141.                 points.Add(pairs[i].Item2);
  142.             }
  143.             return points;
  144.         }
  145.  
  146.         /// <summary>
  147.         /// Gets the points in a range.
  148.         /// </summary>
  149.         /// <param name="pt1">The first corner of range.</param>
  150.         /// <param name="pt2">The opposite corner of the range.</param>
  151.         /// <returns>All points within the box.</returns>
  152.         public Point3dCollection BoxedRange(Point3d pt1, Point3d pt2)
  153.         {
  154.             Point3d lowerLeft = new Point3d(
  155.                 Math.Min(pt1.X, pt2.X), Math.Min(pt1.Y, pt2.Y), Math.Min(pt1.Z, pt2.Z));
  156.             Point3d upperRight = new Point3d(
  157.                 Math.Max(pt1.X, pt2.X), Math.Max(pt1.Y, pt2.Y), Math.Max(pt1.Z, pt2.Z));
  158.             Point3dCollection points = new Point3dCollection();
  159.             FindRange(lowerLeft, upperRight, this.Root, points);
  160.             return points;
  161.         }
  162.  
  163.         /// <summary>
  164.         /// Gets all the pairs of points which distance is less or equal than the specified distance.
  165.         /// </summary>
  166.         /// <param name="radius">The maximum distance between two points. </param>
  167.         /// <returns>The pairs of points which distance is less or equal than the specified distance.</returns>
  168.         public List<Tuple<Point3d, Point3d>> ConnectAll(double radius)
  169.         {
  170.             List<Tuple<Point3d, Point3d>> connexions = new List<Tuple<Point3d, Point3d>>();
  171.             GetConnexions(this.Root, radius * radius, connexions);
  172.             return connexions;
  173.         }
  174.  
  175.         #endregion
  176.  
  177.         #region Private methods
  178.  
  179.         private TreeNode Create(Point3d[] points, int depth, TreeNode parent, bool isLeft)
  180.         { // Recursively builds the subtree holding 'points'; returns its root node, or null for an empty array.
  181.             int length = points.Length;
  182.             if (length == 0) return null;
  183.             int d = depth % this.dimension; // axis compared at this depth
  184.             Point3d median = points.QuickSelectMedian((p1, p2) => p1[d].CompareTo(p2[d])); // also reorders 'points' in place
  185.             TreeNode node = new TreeNode(median);
  186.             node.Depth = depth;
  187.             node.Parent = parent;
  188.             node.IsLeft = isLeft;
  189.             int mid = length / 2; // same index QuickSelectMedian returns the element of
  190.             int rlen = length - mid - 1;
  191.             Point3d[] left = new Point3d[mid];
  192.             Point3d[] right = new Point3d[rlen];
  193.             Array.Copy(points, 0, left, 0, mid); // elements before the median index
  194.             Array.Copy(points, mid + 1, right, 0, rlen); // elements after the median index
  195.             if (depth < this.parallelDepth) // near the root, build both children concurrently
  196.             {
  197.                 System.Threading.Tasks.Parallel.Invoke(
  198.                    () => node.LeftChild = Create(left, depth + 1, node, true),
  199.                    () => node.RightChild = Create(right, depth + 1, node, false)
  200.                 );
  201.             }
  202.             else
  203.             {
  204.                 node.LeftChild = Create(left, depth + 1, node, true);
  205.                 node.RightChild = Create(right, depth + 1, node, false);
  206.             }
  207.             return node;
  208.         }
  209.  
  210.         private Point3d GetNeighbour(Point3d center, TreeNode node, Point3d currentBest, double bestDist)
  211.         {
  212.             if (node == null)
  213.                 return currentBest;
  214.             Point3d current = node.Value;
  215.             int d = node.Depth % this.dimension;
  216.             double coordCen = center[d];
  217.             double coordCur = current[d];
  218.             double dist = this.sqrDist(center, current);
  219.             if (dist >= 0.0 && dist < bestDist)
  220.             {
  221.                 currentBest = current;
  222.                 bestDist = dist;
  223.             }
  224.             dist = coordCen - coordCur;
  225.             if (bestDist < dist * dist)
  226.             {
  227.                 currentBest = GetNeighbour(
  228.                     center, coordCen < coordCur ? node.LeftChild : node.RightChild, currentBest, bestDist);
  229.                 bestDist = this.sqrDist(center, currentBest);
  230.             }
  231.             else
  232.             {
  233.                 currentBest = GetNeighbour(center, node.LeftChild, currentBest, bestDist);
  234.                 bestDist = this.sqrDist(center, currentBest);
  235.                 currentBest = GetNeighbour(center, node.RightChild, currentBest, bestDist);
  236.                 bestDist = this.sqrDist(center, currentBest);
  237.             }
  238.             return currentBest;
  239.         }
  240.  
  241.         private void GetNeighboursAtDistance(Point3d center, double radius, TreeNode node, Point3dCollection points)
  242.         {
  243.             if (node == null) return;
  244.             Point3d current = node.Value;
  245.             double dist = this.sqrDist(center, current);
  246.             if (dist <= radius)
  247.             {
  248.                 points.Add(current);
  249.             }
  250.             int d = node.Depth % this.dimension;
  251.             double coordCen = center[d];
  252.             double coordCur = current[d];
  253.             dist = coordCen - coordCur;
  254.             if (dist * dist > radius)
  255.             {
  256.                 if (coordCen < coordCur)
  257.                 {
  258.                     GetNeighboursAtDistance(center, radius, node.LeftChild, points);
  259.                 }
  260.                 else
  261.                 {
  262.                     GetNeighboursAtDistance(center, radius, node.RightChild, points);
  263.                 }
  264.             }
  265.             else
  266.             {
  267.                 GetNeighboursAtDistance(center, radius, node.LeftChild, points);
  268.                 GetNeighboursAtDistance(center, radius, node.RightChild, points);
  269.             }
  270.         }
  271.  
  272.         private void GetKNeighbours(Point3d center, int number, TreeNode node, List<Tuple<double, Point3d>> pairs)
  273.         { // Collects into 'pairs' up to 'number' (squared distance, point) candidates nearest to 'center'; pairs[0] is always the farthest one kept.
  274.             if (node == null || number < 1) return; // nothing to collect when no neighbour is requested
  275.             Point3d current = node.Value;
  276.             double dist = this.sqrDist(center, current);
  277.             int cnt = pairs.Count;
  278.             if (cnt == 0)
  279.             {
  280.                 pairs.Add(new Tuple<double, Point3d>(dist, current));
  281.             }
  282.             else if (cnt < number) // list not full yet: always keep the point
  283.             {
  284.                 if (dist > pairs[0].Item1)
  285.                 {
  286.                     pairs.Insert(0, new Tuple<double, Point3d>(dist, current)); // new farthest candidate goes to the front
  287.                 }
  288.                 else
  289.                 {
  290.                     pairs.Add(new Tuple<double, Point3d>(dist, current));
  291.                 }
  292.             }
  293.             else if (dist < pairs[0].Item1) // list full: replace the farthest candidate if this point is closer
  294.             {
  295.                 pairs[0] = new Tuple<double, Point3d>(dist, current);
  296.                 pairs.Sort((p1, p2) => p2.Item1.CompareTo(p1.Item1)); // descending, so the farthest is back at index 0
  297.             }
  298.             int d = node.Depth % this.dimension; // splitting axis of this node
  299.             double coordCen = center[d];
  300.             double coordCur = current[d];
  301.             dist = coordCen - coordCur;
  302.             if (dist * dist > pairs[0].Item1) // splitting plane farther than the worst candidate: search only the side of 'center'
  303.             {
  304.                 if (coordCen < coordCur)
  305.                 {
  306.                     GetKNeighbours(center, number, node.LeftChild, pairs);
  307.                 }
  308.                 else
  309.                 {
  310.                     GetKNeighbours(center, number, node.RightChild, pairs);
  311.                 }
  312.             }
  313.             else
  314.             {
  315.                 GetKNeighbours(center, number, node.LeftChild, pairs);
  316.                 GetKNeighbours(center, number, node.RightChild, pairs);
  317.             }
  318.         }
  319.  
  320.         private void FindRange(Point3d lowerLeft, Point3d upperRight, TreeNode node, Point3dCollection points)
  321.         {
  322.             if (node == null)
  323.                 return;
  324.             Point3d current = node.Value;
  325.             if (ignoreZ)
  326.             {
  327.                 if (current.X >= lowerLeft.X && current.X <= upperRight.X &&
  328.                     current.Y >= lowerLeft.Y && current.Y <= upperRight.Y)
  329.                     points.Add(current);
  330.             }
  331.             else
  332.             {
  333.                 if (current.X >= lowerLeft.X && current.X <= upperRight.X &&
  334.                     current.Y >= lowerLeft.Y && current.Y <= upperRight.Y &&
  335.                     current.Z >= lowerLeft.Z && current.Z <= upperRight.Z)
  336.                     points.Add(current);
  337.             }
  338.             int d = node.Depth % this.dimension;
  339.             if (upperRight[d] < current[d])
  340.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  341.             else if (lowerLeft[d] > current[d])
  342.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  343.             else
  344.             {
  345.                 FindRange(lowerLeft, upperRight, node.LeftChild, points);
  346.                 FindRange(lowerLeft, upperRight, node.RightChild, points);
  347.             }
  348.         }
  349.  
  350.         private void GetConnexions(TreeNode node, double radius, List<Tuple<Point3d, Point3d>> connexions)
  351.         { // Adds to 'connexions' every pair made of a node of this subtree and a neighbour within 'radius' (a squared distance, see ConnectAll).
  352.             if (node == null) return;
  353.             Point3dCollection points = new Point3dCollection();
  354.             Point3d center = node.Value;
  355.             // Ancestors' right subtrees are searched for 2d and 3d trees alike; pairs spanning two subtrees are only found here.
  356.             GetRightParentsNeighbours(center, node, radius, points);
  357.             GetNeighboursAtDistance(center, radius, node.LeftChild, points); // neighbours among the descendants
  358.             GetNeighboursAtDistance(center, radius, node.RightChild, points);
  359.             for (int i = 0; i < points.Count; i++)
  360.             {
  361.                 connexions.Add(new Tuple<Point3d, Point3d>(center, points[i]));
  362.             }
  363.             GetConnexions(node.LeftChild, radius, connexions); // then treat every descendant as a centre in turn
  364.             GetConnexions(node.RightChild, radius, connexions);
  365.         }
  366.  
  367.         private void GetRightParentsNeighbours(Point3d center, TreeNode node, double radius, Point3dCollection points)
  368.         { // Adds to 'points' the neighbours of 'center' found in the right subtrees of the ancestors that hold 'node' on their left side.
  369.             TreeNode parent = GetRightParent(node); // nearest ancestor reached through a left-child link
  370.             if (parent == null) return; // no such ancestor remains: stop climbing
  371.             int d = parent.Depth % this.dimension; // splitting axis of that ancestor
  372.             double dist = center[d] - parent.Value[d];
  373.             if (dist * dist <= radius) // 'radius' is compared with a squared value here
  374.             {
  375.                 GetNeighboursAtDistance(center, radius, parent.RightChild, points);
  376.             }
  377.             GetRightParentsNeighbours(center, parent, radius, points); // continue upward from that ancestor
  378.         }
  379.  
  380.         private TreeNode GetRightParent(TreeNode node)
  381.         { // Climbs from 'node' to the first ancestor entered through a left-child link; returns null if the root is reached first.
  382.             TreeNode parent = node.Parent;
  383.             if (parent == null) return null; // 'node' is the root
  384.             if (node.IsLeft) return parent; // 'node' is the left child of 'parent'
  385.             return GetRightParent(parent); // 'node' is a right child: keep climbing
  386.         }
  387.  
  388.         private double SqrDistance2d(Point3d p1, Point3d p2)
  389.         {
  390.             return (p1.X - p2.X) * (p1.X - p2.X) +
  391.                 (p1.Y - p2.Y) * (p1.Y - p2.Y);
  392.         }
  393.  
  394.         private double SqrDistance3d(Point3d p1, Point3d p2)
  395.         {
  396.             return (p1.X - p2.X) * (p1.X - p2.X) +
  397.                 (p1.Y - p2.Y) * (p1.Y - p2.Y) +
  398.                 (p1.Z - p2.Z) * (p1.Z - p2.Z);
  399.         }
  400.  
  401.         #endregion
  402.     }
  403.  
  404.     static class Extensions
  405.     {
  406.         // Quickselect: partially reorders 'items' in place and returns the element left at index Length / 2. Credit: Tony Tanzillo
  407.         // http://www.theswamp.org/index.php?topic=44312.msg495808#msg495808
  408.         public static T QuickSelectMedian<T>(this T[] items, Comparison<T> compare)
  409.         {
  410.             if (items == null || items.Length == 0) // the null test has to come before any other use of 'items'
  411.                 throw new ArgumentException("array");
  412.             int l = items.Length;
  413.             int k = l / 2; // index of the element that is returned
  414.             int from = 0;
  415.             int to = l - 1;
  416.             while (from < to) // shrink the window [from, to] around k
  417.             {
  418.                 int r = from;
  419.                 int w = to;
  420.                 T current = items[(r + w) / 2]; // pivot value taken from the middle of the window
  421.                 while (r < w)
  422.                 {
  423.                     if (compare(items[r], current) > -1) // items[r] is not smaller than the pivot: send it to the right end
  424.                     {
  425.                         var tmp = items[w];
  426.                         items[w] = items[r];
  427.                         items[r] = tmp;
  428.                         w--;
  429.                     }
  430.                     else
  431.                     {
  432.                         r++;
  433.                     }
  434.                 }
  435.                 if (compare(items[r], current) > 0)
  436.                 {
  437.                     r--;
  438.                 }
  439.                 if (k <= r) // keep the half of the window that contains k
  440.                 {
  441.                     to = r;
  442.                 }
  443.                 else
  444.                 {
  445.                     from = r + 1;
  446.                 }
  447.             }
  448.             return items[k];
  449.         }
  450.     }
  451. }
  452.  
« Last Edit: April 19, 2013, 01:05:53 PM by gile »
Speaking English as a French Frog

TheMaster

  • Guest
Re: Kd Tree
« Reply #27 on: April 14, 2013, 02:59:15 PM »
Hi Tony,

It seems to me he didn't use a tree structure but a 'divide and conquer' algorithm as shown by this picture.
For my part, I tried to implement a more reusable (and maybe useful) structure.

I found another way to avoid duplicated connections with the ConnectAll() method. Rather than storing nodes in a stack, I can get the 'right parent nodes' of the current node using a new property in the TreeNode which indicates whether the node is a left child node or not. This seems to improve ConnectAll() speed a little.


Hi Gile. Very good.

Here's one more suggestion: Depending on a number of factors like point distribution, there can be a very large number of leaf nodes in the tree (leaf nodes are nodes with no child nodes). You can save some memory by not having fields for child nodes on leaf nodes, by using two different classes for leaf and branch nodes, using something like this:

Code - C#: [Select]
  1.  
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6.  
  7. namespace Graph.Nodes
  8. {
  9.  
  10.    /// Use for leaf nodes with no children:
  11.    /// (contains no fields for child nodes)
  12.    
  13.    public class Node<T>
  14.    {
  15.       protected Node( T value )
  16.       {
  17.          this.Value = value;
  18.       }
  19.  
  20.       public T Value
  21.       {
  22.          get;
  23.          set;
  24.       }
  25.  
  26.       public virtual Node<T> Left
  27.       {
  28.          get
  29.          {
  30.             return null;
  31.          }
  32.          protected set
  33.          {
  34.          }
  35.       }
  36.  
  37.       public virtual Node<T> Right
  38.       {
  39.          get
  40.          {
  41.             return null;
  42.          }
  43.          protected set
  44.          {
  45.          }
  46.       }
  47.  
  48.       // Create leaf or branch node depending on
  49.       // if child nodes are provided:
  50.      
  51.       public static Node<T> Create( T value, Node<T> left = null, Node<T> right = null )
  52.       {
  53.          if( left == null && right == null )
  54.             return new Node<T>( value );
  55.          else
  56.             return new BranchNode<T>( value, left, right );
  57.       }
  58.    }
  59.    
  60.    /// Used for branch nodes with 1 or 2 children
  61.  
  62.    public class BranchNode<T> : Node<T>
  63.    {
  64.       Node<T> left = null;
  65.       Node<T> right = null;
  66.  
  67.       protected internal BranchNode( T value, Node<T> left, Node<T> right )
  68.          : base( value )
  69.       {
  70.          this.left = left;
  71.          this.right = right;
  72.       }
  73.  
  74.       public override Node<T> Left
  75.       {
  76.          get
  77.          {
  78.             return left;
  79.          }
  80.          protected set
  81.          {
  82.             left = value;
  83.          }
  84.       }
  85.  
  86.       public override Node<T> Right
  87.       {
  88.          get
  89.          {
  90.             return right;
  91.          }
  92.          protected set
  93.          {
  94.             right = value;
  95.          }
  96.       }
  97.    }
  98.  
  99. }
  100.  
  101.  

gile

  • Gator
  • Posts: 2507
  • Marseille, France
Re: Kd Tree
« Reply #28 on: April 14, 2013, 04:46:21 PM »
Thanks Tony, I'll give a try.
As the tree is balanced, the number of leaf nodes would never exceed half the number of points plus one.
Speaking English as a French Frog

Jerry J

  • Newt
  • Posts: 48
Re: Kd Tree
« Reply #29 on: April 16, 2013, 04:07:34 AM »
Thanks for stretching my peanut of a brain.