Point on an Ellipsoid of Minimum Distance to Another Point in Space
Posted by cheshirekow in Math on September 4, 2009
This is a response to a question posted on the irrlicht forums.
PDF file:
Closest point on an ellipsoid to a point in space (to come)
If one wants to know the closest point on an ellipsoid to another point in space, it can be found using constrained nonlinear optimization. It was stated in the forum thread that there is no closed-form solution to this problem, though I find that rather surprising. I will update this post if I can confirm or refute that claim.
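In any case, given a point in space $p = (p_x, p_y, p_z)$, we want to find a point $q = (x, y, z)$ that minimizes the Euclidean distance to $p$, subject to the constraint that it is on an ellipsoid with the following functional description

$$\frac{x^2}{a^2} + \frac{y^2}{b^2} + \frac{z^2}{c^2} = 1$$

This corresponds to the minimization of the following cost function (the “cost” is the distance to the target point)

$$J(x, y, z) = \sqrt{(p_x - x)^2 + (p_y - y)^2 + (p_z - z)^2}$$

subject to the following constraint (which just says that it must be on the ellipsoid)

$$\frac{x^2}{a^2} + \frac{y^2}{b^2} + \frac{z^2}{c^2} - 1 = 0$$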
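To make the two ingredients concrete, here is a minimal C++ sketch of the cost and the constraint. The `Vec3` and `Ellipsoid` types and the function names are hypothetical helpers made up for this illustration; the only things taken from the post are the semi-axes a, b, c and the target point.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper types for illustration; not part of the original post.
struct Vec3 { double x, y, z; };
struct Ellipsoid { double a, b, c; };  // semi-axes along x, y, z

// Cost J: Euclidean distance from the candidate point q to the target p.
double cost(const Vec3& p, const Vec3& q) {
    const double dx = p.x - q.x, dy = p.y - q.y, dz = p.z - q.z;
    return std::sqrt(dx * dx + dy * dy + dz * dz);
}

// Constraint residual: zero exactly when q lies on the ellipsoid surface,
// negative inside, positive outside.
double constraint(const Ellipsoid& e, const Vec3& q) {
    assert(e.a > 0 && e.b > 0 && e.c > 0);
    return q.x * q.x / (e.a * e.a) + q.y * q.y / (e.b * e.b)
         + q.z * q.z / (e.c * e.c) - 1.0;
}
```

Any solver for this problem is then looking for the q that makes `cost(p, q)` as small as possible while keeping `constraint(e, q)` at zero.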
Now there are many ways to solve this problem. I will outline two of them here, known as steepest descent (or gradient search) and the Lagrange method.
Steepest Descent
First comes steepest descent, because it is the easiest to understand graphically. The goal here is that, starting at some point on the ellipsoid, we will look at how the value of the cost function changes depending on the direction we move, and choose to move in the direction that most improves the cost. To begin with, let us redefine the problem in terms of two independent variables, and let the third be dependent. In the following formulation, I will choose $z$ to be the dependent variable, but you should note that this is not always the best choice. A good choice of dependent variable is to look at the coordinates of the target point and pick the one that is furthest from the origin (to avoid numerically ill-conditioned situations).
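Using $z$ as the dependent variable we can reformulate the constraint as the following

$$z = \pm c \sqrt{1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}}$$

Given this, we examine the effect that changing each of our independent variables has on our cost function, by calculating the partial derivatives. First though, we note that the minimum of a square root expression will occur at the same point as the minimum of the argument, so we will throw away the square root part and use

$$J^2 = (p_x - x)^2 + (p_y - y)^2 + (p_z - z)^2$$

for the derivatives instead.

$$\frac{\partial J^2}{\partial x} = -2(p_x - x) - 2(p_z - z)\frac{\partial z}{\partial x}, \qquad \frac{\partial J^2}{\partial y} = -2(p_y - y) - 2(p_z - z)\frac{\partial z}{\partial y}$$

And given the equation above for $z$ we can determine the following

$$\frac{\partial z}{\partial x} = \pm \frac{c}{2} \left(1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}\right)^{-1/2} \left(-\frac{2x}{a^2}\right)$$

which simplifies to

$$\frac{\partial z}{\partial x} = \mp \frac{c}{a^2} \frac{x}{\sqrt{1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}}}$$

and similarly

$$\frac{\partial z}{\partial y} = \mp \frac{c}{b^2} \frac{y}{\sqrt{1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}}}$$

Note that the sign in front of these derivatives is opposite the sign of $z$. Therefore we see the following

$$\frac{\partial J^2}{\partial x} = -2(p_x - x) + 2\,\mathrm{sgn}(z)\,(p_z - z)\,\frac{c}{a^2} \frac{x}{\sqrt{1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}}}$$

$$\frac{\partial J^2}{\partial y} = -2(p_y - y) + 2\,\mathrm{sgn}(z)\,(p_z - z)\,\frac{c}{b^2} \frac{y}{\sqrt{1 - \frac{x^2}{a^2} - \frac{y^2}{b^2}}}$$

where $\mathrm{sgn}$ is the signum function, which returns $-1$ if the argument is less than zero and $+1$ if it is greater than zero.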
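Now we solve for the minimum point using the following algorithm
1. start at some $x$ and $y$
2. calculate $z = \pm c \sqrt{1 - x^2/a^2 - y^2/b^2}$
3. pick whichever of the two is closer by picking the one with the same sign as $p_z$
4. calculate the cost $J$
5. if $J$ of the current iteration is within some tolerance of $J$ at the previous iteration, stop here
6. calculate $\partial J^2 / \partial x$ and $\partial J^2 / \partial y$
7. calculate $\Delta x$ and $\Delta y$ from $J$, $\partial J^2 / \partial x$, and $\partial J^2 / \partial y$
8. update $x$ and $y$ using $\Delta x$ and $\Delta y$
9. return to step 2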
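For readers who would rather see the loop in C++, here is a rough sketch of the same idea. It is not a translation of the script in this post: it assumes a plain gradient step on the squared distance with a backtracking (Armijo) line search, where the script uses its own step rule. The `Ellipsoid` and `Vec3` types and every function name are made up for this example, and it assumes the target is not in the xy-plane (so z stays away from zero).

```cpp
#include <cassert>
#include <cmath>

// Hypothetical types for this sketch; they do not appear in the original post.
struct Ellipsoid { double a, b, c; };   // semi-axes along x, y, z
struct Vec3 { double x, y, z; };

// Squared distance from p to the surface point over (x, y) on the half of the
// ellipsoid selected by sign. Returns false when (x, y) is off the ellipsoid.
static bool squaredCost(const Ellipsoid& e, const Vec3& p, double sign,
                        double x, double y, double& z, double& j2) {
    const double r2 = 1.0 - x * x / (e.a * e.a) - y * y / (e.b * e.b);
    if (r2 <= 0.0) return false;
    z = sign * e.c * std::sqrt(r2);
    j2 = (p.x - x) * (p.x - x) + (p.y - y) * (p.y - y) + (p.z - z) * (p.z - z);
    return true;
}

Vec3 closestPointByDescent(const Ellipsoid& e, const Vec3& p) {
    assert(e.a > 0 && e.b > 0 && e.c > 0);
    const double sign = (p.z >= 0.0) ? 1.0 : -1.0;   // half that faces the target

    // Initial guess: scale p onto the surface along the ray from the center.
    const double s2 = p.x * p.x / (e.a * e.a) + p.y * p.y / (e.b * e.b)
                    + p.z * p.z / (e.c * e.c);
    assert(s2 > 0.0);
    const double k = 1.0 / std::sqrt(s2);
    double x = k * p.x, y = k * p.y, z = 0.0, j2 = 0.0;
    const bool ok = squaredCost(e, p, sign, x, y, z, j2);
    assert(ok);   // fails if p.z == 0, a case this sketch does not handle
    (void)ok;

    for (int iter = 0; iter < 10000; ++iter) {
        // Gradient of the squared distance with z eliminated, using
        // dz/dx = -c^2 x / (a^2 z) and dz/dy = -c^2 y / (b^2 z).
        const double gx = -2.0 * (p.x - x)
                        + 2.0 * (p.z - z) * (e.c * e.c) * x / (e.a * e.a * z);
        const double gy = -2.0 * (p.y - y)
                        + 2.0 * (p.z - z) * (e.c * e.c) * y / (e.b * e.b * z);
        const double g2 = gx * gx + gy * gy;
        if (g2 < 1e-20) break;            // gradient is numerically zero

        // Backtracking line search along the negative gradient.
        double t = 1.0;
        bool accepted = false;
        while (t > 1e-14) {
            double nz = 0.0, nj2 = 0.0;
            if (squaredCost(e, p, sign, x - t * gx, y - t * gy, nz, nj2) &&
                nj2 <= j2 - 1e-4 * t * g2) {
                x -= t * gx; y -= t * gy; z = nz; j2 = nj2;
                accepted = true;
                break;
            }
            t *= 0.5;                     // step was too long, or left the surface
        }
        if (!accepted) break;             // no step reduces the cost any further
    }
    return Vec3{x, y, z};
}
```

Rejecting any step that leaves the ellipsoid or fails to reduce the cost plays the same role as the step-size reduction in the script. Like any local search, it only finds the minimum nearest to where it starts.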
Here is MATLAB code that demonstrates this algorithm. Note that in this code, for the initial guess I use the intersection of the line segment between the center of the ellipsoid and the target point with the surface of the ellipsoid.
    clc;
    clear;

    % define the ellipsoid parameters
    a = 3;
    b = 7;
    c = 9;

    % create a 50 x 50 point mesh for displaying the ellipsoid
    x = linspace( -a, a, 50 );
    y = linspace( -b, b, 50 );

    % calculate the z coordinates
    [X,Y] = meshgrid(x,y);
    Zp = c*sqrt(1 - X.*X/(a*a) - Y.*Y/(b*b) );
    Zm = -Zp;

    % create the point we're going to calculate the closest point to
    px = 3;
    py = 4;
    pz = 8;

    % create an initial guess by finding the intersection of the ray from the
    % centroid to the point with the ellipsoid surface
    k = sqrt( 1 / ( (px^2)/(a^2) + (py^2)/(b^2) + (pz^2)/(c^2) ) );
    qx = k*px;
    qy = k*py;
    qz = k*pz;

    % this is the scale factor for our movement along the ellipsoid, this can
    % start off quite large because it will be reduced as needed
    s = 1;

    % this is the main loop, you'll want to put some kind of tolerance based
    % terminal condition, i.e. when the cost function (distance) is only
    % bouncing back and forth between a tolerance or something more tailored to
    % your application, I will stop when the distance is +/- 0.1%
    i = 1;
    while 1
        % calculate the z for our given x and y
        qz_plus  = c * sqrt( 1 - qx*qx/(a*a) - qy*qy/(b*b) );
        qz_minus = - qz_plus;

        % we want the one that gets us closest to the target so
        if( abs(pz - qz_plus) < abs(pz - qz_minus) )
            qz = qz_plus;
        else
            qz = qz_minus;
        end

        % calculate the current value of the cost function
        J = sqrt( (px - qx)^2 + (py - qy)^2 + (pz - qz)^2 );

        % store the current values for the plots
        qxplot(i) = qx;
        qyplot(i) = qy;
        qzplot(i) = qz;
        Jplot(i)  = J;

        % check to see if we overshot the goal or jumped off the ellipsoid
        if( i > 1 )
            % if we jumped off the ellipsoid or overshot the minimal cost
            if( imag(qz) ~= 0 || J > Jplot(i-1) )
                % then go back to the previous position and use a finer
                % step size
                qx = qxplot(i-1);
                qy = qyplot(i-1);
                qz = qzplot(i-1);
                J  = Jplot(i-1);
                s  = s/10;
                i  = i-1;

                % if we did just jump over the actual minimum, let's check to
                % see how confident we are at this point, if we're within
                % 0.1% there's no need to continue
                if( i > 3 )
                    if( abs( (Jplot(i-1) - Jplot(i-3)) / Jplot(i-3) ) < 0.001 )
                        break;
                    end
                end
            end
        end

        % calculate the gradient of the cost with respect to our control
        % variables; since we divide by qz in the second term we first need
        % to check whether or not qz is too close to zero; if it is, we know
        % that the second term should be zero
        %
        % note: sqrt(1 - qx^2/a^2 - qy^2/b^2) equals abs(qz)/c, so the exact
        % z term would be (c*c/(a*a))*qx/qz; the version kept here is smaller
        % by a factor of c and is what produced the output shown below
        if( abs(qz) < 1e-10 )
            dJdx = -2*(px - qx);
            dJdy = -2*(py - qy);
        else
            dJdx = -2*(px - qx) + 2*sign(qz)*(pz - qz)*( c/(a*a) * qx/qz );
            dJdy = -2*(py - qy) + 2*sign(qz)*(pz - qz)*( c/(b*b) * qy/qz );
        end

        % calculate the update vector that we will move along
        dx = -J/dJdx;
        dy = -J/dJdy;

        % calculate the magnitude of that update vector
        magnitude = sqrt( dx*dx + dy*dy );

        % normalize our update vector so that we don't shoot off at an
        % uncontrollable rate
        dx = s * (1/magnitude) * dx;
        dy = s * (1/magnitude) * dy;

        % update the current position
        qx = qx + dx;
        qy = qy + dy;

        % increment the index
        i = i+1;
    end

    index = 1:i+1;

    % generate a little report
    message = sprintf( ...
        ['The closest point was found at [%f, %f, %f]', ...
         'with a distance of %f and a confidence of +/- %f %%'], ...
        qx, qy, qz, J, 100*abs((Jplot(i+1)-Jplot(i-1)) / Jplot(i-1)) );
    disp(message);

    % now we'll verify that the one we found is the closest
    Jmap_plus  = sqrt( (px - X).*(px - X) + (py - Y).*(py - Y) + (pz - Zp).*(pz - Zp) );
    Jmap_minus = sqrt( (px - X).*(px - X) + (py - Y).*(py - Y) + (pz - Zm).*(pz - Zm) );

    [Jmin_plus_row, imin_plus_row] = min( Jmap_plus );
    [Jmin_plus_col, imin_plus_col] = min( Jmin_plus_row );

    imin = imin_plus_row(imin_plus_col);
    jmin = imin_plus_col;

    xmin = X(imin, jmin);
    ymin = Y(imin, jmin);
    zmin = Zp(imin, jmin);

    Jmin  = Jmin_plus_col;
    Jmin2 = Jmap_plus(imin,jmin);

    message = sprintf( ...
        ['The closest point by searching the mesh was at [%f, %f, %f]' ...
         'and has a distance of %f (%f)'], xmin, ymin, zmin, Jmin, Jmin2 );
    disp(message);

    % and draw some pretty pictures
    close(figure(1));
    figure(1);
    grid on;
    hold on;
    mesh(X,Y,real(Zp), 'FaceAlpha', 0.5);
    mesh(X,Y,real(Zm), 'FaceAlpha', 0.5);
    plot3(px,py,pz,'ko');
    plot3(qxplot,qyplot,qzplot,'ko');
    plot3(qxplot,qyplot,qzplot,'k-');
    plot3([px qx], [py qy], [pz qz], 'b-');
    xlabel('X');
    ylabel('Y');
    zlabel('Z');
    title('steepest descent method');
    hold off;

    close(figure(2));
    figure(2);
    plot(index, Jplot, 'g-');
    xlabel( 'iteration' );
    ylabel( 'distance (cost)' );

    close(figure(3));
    figure(3);
    mesh(X,Y,real(Jmap_plus), 'FaceAlpha', 0.5);
    xlabel('X');
    ylabel('Y');
    zlabel('Z');
    title('distance to target');
The output of this script is
The closest point was found at [1.321810, 3.187960, 6.962411]with a distance of 2.133617 and a confidence of +/- 0.000000 %
The closest point by searching the mesh was at [1.285714, 3.285714, 6.948103]and has a distance of 2.134354 (2.134354)
You can see the work done by the gradient search in the first figure, which is shown below:
I know this doesn’t look like the closest point, but, as you can see, the script actually searches the mesh for the lowest value point and comes up with essentially the same point. To demonstrate, here is a 3d plot of the distance of each point on the upper half of the ellipsoid to the target point, with the minimum point labeled. This is the output of the script to figure 3. The reason it doesn’t look right is because of how MATLAB squashes everything to fit the window.
And finally here is a plot of the distance calculated at each iteration of the loop. As you can see, it doesn’t take many to find a value that is pretty damn good.
Steepest descent works pretty well, right? Well, not always. The initial guess used in the code above will save your ass, but steepest descent can suffer from problems of local minima. Look what happens when I try an initial guess on the other side of the ellipsoid.
The algorithm gets caught in a local minimum, and converges to a very precise, very incorrect solution.
Lagrange Method
This method is a little more elegant, but still ends up with a search algorithm. However, in this case we’re only doing a one-dimensional search to find the roots of a polynomial, which is a well-documented problem.
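In this method, we note that, since the constraint must be satisfied everywhere, if we know the minimal cost is achieved at a point $q^* = (x^*, y^*, z^*)$, then we know that

$$g(q^*) = \frac{{x^*}^2}{a^2} + \frac{{y^*}^2}{b^2} + \frac{{z^*}^2}{c^2} - 1 = 0$$

and we also know that if the augmented cost function is defined as

$$L(x, y, z, \lambda) = J(x, y, z) + \lambda \, g(x, y, z)$$

which expands to

$$L = \sqrt{(p_x - x)^2 + (p_y - y)^2 + (p_z - z)^2} + \lambda \left( \frac{x^2}{a^2} + \frac{y^2}{b^2} + \frac{z^2}{c^2} - 1 \right)$$

then $L = J$ on the ellipsoid and the minimum of $L$ is also at $q^*$, since the latter term must be zero if the constraint is to be satisfied. We can then minimize the cost function as we usually do, by taking the derivative with respect to the independent variables, setting them to zero, and solving the system for the unknowns.
First though, we note that the minimum of a square root expression will occur at the same point as the minimum of the argument, so we will throw away the square root part and solve instead for the minimum of

$$L = (p_x - x)^2 + (p_y - y)^2 + (p_z - z)^2 + \lambda \left( \frac{x^2}{a^2} + \frac{y^2}{b^2} + \frac{z^2}{c^2} - 1 \right)$$

We take the partial derivative of the augmented cost with respect to the four unknowns and see that

$$\frac{\partial L}{\partial x} = -2(p_x - x) + \frac{2 \lambda x}{a^2}$$

$$\frac{\partial L}{\partial y} = -2(p_y - y) + \frac{2 \lambda y}{b^2}$$

$$\frac{\partial L}{\partial z} = -2(p_z - z) + \frac{2 \lambda z}{c^2}$$

$$\frac{\partial L}{\partial \lambda} = \frac{x^2}{a^2} + \frac{y^2}{b^2} + \frac{z^2}{c^2} - 1$$

Setting these all to zero we can solve the first three for the following

$$x = \frac{a^2 p_x}{\lambda + a^2}$$

$$y = \frac{b^2 p_y}{\lambda + b^2}$$

$$z = \frac{c^2 p_z}{\lambda + c^2}$$

Substituting back into the fourth of the differentials we get that

$$\frac{a^2 p_x^2}{(\lambda + a^2)^2} + \frac{b^2 p_y^2}{(\lambda + b^2)^2} + \frac{c^2 p_z^2}{(\lambda + c^2)^2} - 1 = 0$$

which simplifies to

$$a^2 p_x^2 (\lambda + b^2)^2 (\lambda + c^2)^2 + b^2 p_y^2 (\lambda + a^2)^2 (\lambda + c^2)^2 + c^2 p_z^2 (\lambda + a^2)^2 (\lambda + b^2)^2 - (\lambda + a^2)^2 (\lambda + b^2)^2 (\lambda + c^2)^2 = 0$$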
Note that this is a sixth-order polynomial in $\lambda$. Use your favorite zero-finding algorithm (same as a root-finding algorithm) to solve for the real zeros of this polynomial. For a target point like the one in the example there are only two, one for the closest point and one for the furthest point on the ellipsoid, but a sixth-order polynomial can have up to six real zeros, so do not count on exactly two in general (for instance when the target point is along one of the axes). Once the solutions for $\lambda$ are known, plug them into the equations above to solve for $x$, $y$, and $z$. Then plug those into the cost function, and pick the point with the smallest cost.
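As a rough C++ illustration of the one-dimensional search, the sketch below bisects on the constraint residual written as a function of the multiplier, which is the polynomial divided by the product of the three squared terms and so has the same sign for non-negative multipliers. It assumes the target lies strictly outside the ellipsoid, in which case the residual is positive at zero, decreases as the multiplier grows, and crosses zero exactly once at the closest point. The types and function names are invented for this example.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical types for this sketch; they do not appear in the original post.
struct Ellipsoid { double a, b, c; };   // semi-axes along x, y, z
struct Vec3 { double x, y, z; };

// Point given by the stationarity conditions for a particular multiplier.
Vec3 pointForLambda(const Ellipsoid& e, const Vec3& p, double lambda) {
    return Vec3{ e.a * e.a * p.x / (lambda + e.a * e.a),
                 e.b * e.b * p.y / (lambda + e.b * e.b),
                 e.c * e.c * p.z / (lambda + e.c * e.c) };
}

// Constraint residual at that point: x^2/a^2 + y^2/b^2 + z^2/c^2 - 1.
double residual(const Ellipsoid& e, const Vec3& p, double lambda) {
    const Vec3 q = pointForLambda(e, p, lambda);
    return q.x * q.x / (e.a * e.a) + q.y * q.y / (e.b * e.b)
         + q.z * q.z / (e.c * e.c) - 1.0;
}

// Multiplier of the closest point, for a target strictly outside the ellipsoid.
double solveLambdaOutside(const Ellipsoid& e, const Vec3& p) {
    assert(residual(e, p, 0.0) > 0.0);    // lambda = 0 gives q = p, which is outside
    double lo = 0.0, hi = 1.0;
    while (residual(e, p, hi) > 0.0) hi *= 2.0;   // grow until the sign flips
    for (int i = 0; i < 200; ++i) {               // plain bisection
        const double mid = 0.5 * (lo + hi);
        if (residual(e, p, mid) > 0.0) lo = mid; else hi = mid;
    }
    return 0.5 * (lo + hi);
}
```

Calling `pointForLambda(e, p, solveLambdaOutside(e, p))` then gives the closest point. A target inside the ellipsoid, or a search for the furthest point, needs a bracket on the negative side and is not covered by this sketch.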
For the example problem used in the steepest descent code above, the polynomial looks like this:
Here is a detail view:
As you can see, there are two solutions. Below is MATLAB code that generates the above graphs, and solves for the minimum point using a bisection search. Note that the search assumes that one of the values of $\lambda$ that makes the polynomial zero is to the left of the origin, and the other is to the right. That is true for this example but may not be in general (I didn’t take the time to work out why that may be the case), so you may have to modify the search to look for both roots on both sides of the origin.
    clc;
    clear;

    % define the ellipsoid parameters
    a = 3;
    b = 7;
    c = 9;

    % create a 50 x 50 point mesh for displaying the ellipsoid
    x = linspace( -a, a, 50 );
    y = linspace( -b, b, 50 );

    % calculate the z coordinates
    [X,Y] = meshgrid(x,y);
    Zp = c*sqrt(1 - X.*X/(a*a) - Y.*Y/(b*b) );
    Zm = -Zp;

    % create the point we're going to calculate the closest point to
    px = 3;
    py = 4;
    pz = 8;

    lambda = -160:0.01:20;

    a2 = a*a;
    b2 = b*b;
    c2 = c*c;

    px2 = px*px;
    py2 = py*py;
    pz2 = pz*pz;

    % the sixth-order polynomial in lambda, written once so that every
    % evaluation below uses the same expression
    polyP = @(lam) ...
        a2*px2*(lam+b2).*(lam+b2).*(lam+c2).*(lam+c2) + ...
        b2*py2*(lam+a2).*(lam+a2).*(lam+c2).*(lam+c2) + ...
        c2*pz2*(lam+a2).*(lam+a2).*(lam+b2).*(lam+b2) - ...
        (lam+a2).*(lam+a2).*(lam+b2).*(lam+b2).*(lam+c2).*(lam+c2);

    P = polyP(lambda);

    x = a2*px./(lambda + a2);
    y = b2*py./(lambda + b2);
    z = c2*pz./(lambda + c2);

    J = sqrt( (px-x).*(px-x) + (py-y).*(py-y) + (pz-z).*(pz-z) );

    close(figure(1));
    figure(1);
    plot( lambda, P, 'b-', 'linewidth', 2 );
    xlabel('$\lambda$', 'Interpreter', 'latex');
    ylabel('$\frac{\partial L}{ \partial \lambda }$', 'Interpreter', 'latex');
    title('value of the polynomial');

    close(figure(2));
    figure(2);
    plot( lambda, P, 'b-', 'linewidth', 2 );
    axis([-160,20,0, 1e10]);
    xlabel('$\lambda$', 'Interpreter', 'latex');
    ylabel('$\frac{\partial L}{ \partial \lambda }$', 'Interpreter', 'latex');
    title('detail of value of the polynomial');

    % start search by looking at the value of the polynomial at zero; we need
    % to store the start sign so that we know what sign we're looking for in
    % order to determine a zero crossing
    startSign = sign( polyP(0) );

    % start our search in one direction
    dlambda   = 1;
    lhslambda = 0;

    % start the forward search, we're looking for the first value we can
    % find that has a sign opposite that at the zero point
    while dlambda < realmax/10
        rhslambda = dlambda;

        % if we've discovered a sign change we can stop searching
        if( sign( polyP(rhslambda) ) ~= startSign )
            lhslambda = dlambda/10;
            break;
        % if not, we need to grow the increment
        else
            dlambda = dlambda * 10;
        end
    end

    % now we can start a bisection search using the lhs and rhs that we've
    % just determined, which are exactly one order of magnitude apart
    relErr = (rhslambda - lhslambda)/lhslambda;
    while( abs(relErr) > 1e-2 )
        lambda = (lhslambda + rhslambda)/2;

        if( sign( polyP(lambda) ) ~= startSign )
            rhslambda = lambda;
        else
            lhslambda = lambda;
        end

        relErr = (lhslambda - rhslambda)/lhslambda;
    end

    % store the found value
    lambda1 = (lhslambda + rhslambda)/2;

    % now we search in the other direction
    dlambda   = 1;
    rhslambda = 0;

    % start the backward search, we're looking for the first value we can
    % find that has a sign opposite that at the zero point
    while dlambda < realmax/10
        lhslambda = -dlambda;

        % if we've discovered a sign change we can stop searching
        if( sign( polyP(lhslambda) ) ~= startSign )
            rhslambda = -dlambda/10;
            break;
        % if not, we need to grow the increment
        else
            dlambda = dlambda * 10;
        end
    end

    % now we can start a bisection search using the lhs and rhs that we've
    % just determined, which are exactly one order of magnitude apart
    relErr = (rhslambda - lhslambda)/lhslambda;
    while( abs(relErr) > 1e-2 )
        lambda = (lhslambda + rhslambda)/2;

        if( sign( polyP(lambda) ) ~= startSign )
            lhslambda = lambda;
        else
            rhslambda = lambda;
        end

        relErr = (lhslambda - rhslambda)/lhslambda;
    end

    % store the found value
    lambda2 = (lhslambda + rhslambda)/2;

    x1 = a2*px/(lambda1+a2);
    y1 = b2*py/(lambda1+b2);
    z1 = c2*pz/(lambda1+c2);

    x2 = a2*px/(lambda2+a2);
    y2 = b2*py/(lambda2+b2);
    z2 = c2*pz/(lambda2+c2);

    % distance from the target to each candidate point
    J1 = sqrt( (px-x1)*(px-x1) + (py-y1)*(py-y1) + (pz-z1)*(pz-z1) );
    J2 = sqrt( (px-x2)*(px-x2) + (py-y2)*(py-y2) + (pz-z2)*(pz-z2) );

    % print a little report
    message = sprintf( ['found zero crossings at lambda = %f and %f', ...
                        'which corresponds to the points [%f, %f, %f]', ...
                        'and [%f, %f, %f] with costs of %f and %f'], ...
                        lambda1, lambda2, x1, y1, z1, x2, y2, z2, J1, J2 );
    disp(message);

    if( J1 < J2 )
        message = sprintf( 'the point of minimum distance is [%f, %f, %f]', ...
                           x1, y1, z1 );
    else
        message = sprintf( 'the point of minimum distance is [%f, %f, %f]', ...
                           x2, y2, z2 );
    end

    disp(message);

The output of this script is
found zero crossings at lambda = 11.801758 and -155.810547which corresponds to the points [1.297967, 3.223591, 6.982626]and [-0.183910, -1.835025, -8.661880] with costs of 2.025472 and 8.844903
the point of minimum distance is [1.297967, 3.223591, 6.982626]
The point matches the one from the previous method. Sweet. The two costs in the printout are another matter: they came from a run in which the third term of J1 and J2 used the y difference a second time in place of the z difference. The listing above has the z difference, and for the two printed points that gives distances of roughly 2.13 and 17.94.
As a brief note, if the target point is inside the ellipsoid and along one of its axes, there may be an infinite number of solutions (all points on a ring around the ellipsoid located at the coordinate of the target point will be the same distance away). If this is a possibility in your application, make sure you watch out for it and decide what the “right” solution is.
I hope this helps. Sorry I didn’t write the code in C++ but it took long enough to generate the results without doing that. Hopefully you can figure out from the MATLAB code exactly what you need to. MATLAB script isn’t too different from C++.
-Cheshirekow
Instantiate Objects of Unknown Type from Their Parent Interface
Posted by cheshirekow in C++ Discoveries and Notes, Programming on August 14, 2009
This is based on my previous two posts on Static Interfaces in C++ and Keep Track of and Enumerate All Sub-classes of a Particular Interface. The idea is that I want my code to be extensible in the future without requiring any rewriting of the current code base. The code base operates on generic objects via their interfaces, so as long as newly-coded classes properly extend those interfaces, the program should know how to handle them. The problem is, how can we write the program in such a manner that a user interface can enumerate available options for implementations of a particular interface, and how can we instantiate those objects?
In Keep Track of and Enumerate All Sub-classes of a Particular Interface I showed how to maintain a registry of classes deriving from a given interface, which handles the first problem, but there is a limitation in that all of these classes must provide a factory method that takes no parameters (void input). I decided that, for my project, this was not acceptable and I needed a way to define the creation parameters as part of the factory methods, where the creation parameters may be different for particular interfaces.
In Static Interfaces in C++ I showed how we can enforce the requirement of a static method in derived classes with a particular signature using a template interface.
In this post I will combine the two so that we can create a registry of classes that inherit from a particular interface, and provide a static factory method for creating objects of that interface, using a particular creation method signature unique to that interface. The registry will pair class names with function pointers that match the specific signature of the interface the class is being registered for.
Disclaimer: I do not claim this is the “best” way to handle this issue. This is just what I came up with. It happens to be pretty involved and overly indirect, which means it’s probably bad design. It is, however, an extremely interesting exercise in generic programming.
Preliminaries: the code will need these later, so here they are:

    /**
     *  file    RegistryTest.cpp
     *  date:   Aug 14, 2009
     *  brief:
     *
     *  detail:
     */

    #include <set>
    #include <map>
    #include <string>
    #include <iostream>

    using namespace std;

Ok, so let’s begin. First let’s define a couple of interfaces that we’re interested in.

    class InterfaceA{};
    class InterfaceB{};
    class InterfaceC{};
Now we create a template class whose sole purpose is to create a per-interface typedef of the function signature that is necessary for instantiating an object of that class. Is it really possible that all sub-objects can be instantiated with the same parameters? If that’s the case, shouldn’t they all be combined into a single class that just contains that information as private members? Probably, but in my case these parameters are more like a “bare minimum” for instantiation, and then many more parameters are set by the user. It makes sense to me, I promise. If it doesn’t to you, you don’t have to use this.

    template< typename InterfaceType >
    class Factory
    {
        public:
            typedef InterfaceType*(*Creator)(void);
    };
Creator is now a typedef that aliases a function pointer that takes no parameters. Wait, isn’t that what we had before? Yes, but now we make a couple of template specializations to define the different signatures for our specific interfaces. These specializations would normally be in the file that contained the interface declaration.

    /// specializations can define other creators, this one requires an int
    template<> class Factory<InterfaceB>
    {
        public:
            typedef InterfaceB*(*Creator)(int);
    };

    /// specializations can define other creators, this one requires an int, a
    /// bool, and a char
    template<> class Factory<InterfaceC>
    {
        public:
            typedef InterfaceC*(*Creator)(int,bool,char);
    };
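To see that the primary template and a specialization really do yield different function pointer types, here is a small self-contained check. It assumes a C++11 compiler for `static_assert` and `std::is_same`, which the original code does not need, and `makeB` is a hypothetical stand-in for a real factory function.

```cpp
#include <cassert>
#include <type_traits>

class InterfaceA {};
class InterfaceB {};

// Primary template: creators take no arguments.
template <typename InterfaceType>
class Factory {
public:
    typedef InterfaceType* (*Creator)();
};

// Specialization: creators for InterfaceB take an int.
template <>
class Factory<InterfaceB> {
public:
    typedef InterfaceB* (*Creator)(int);
};

// Compile-time confirmation of what each typedef aliases.
static_assert(std::is_same<Factory<InterfaceA>::Creator, InterfaceA* (*)()>::value,
              "primary template: no-argument creator");
static_assert(std::is_same<Factory<InterfaceB>::Creator, InterfaceB* (*)(int)>::value,
              "InterfaceB specialization: creator takes an int");

// Hypothetical function whose signature matches Factory<InterfaceB>::Creator.
InterfaceB* makeB(int) {
    static InterfaceB instance;
    return &instance;
}
```

A line such as `Factory<InterfaceB>::Creator creator = makeB;` compiles, while assigning `makeB` to a `Factory<InterfaceA>::Creator` does not, which is the property the rest of the post relies on.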
Cool. Now we create a static interface that forces its derived classes to contain a static method called createNew which can be used to instantiate a new object of that interface. We can use the typedef we just created to make the function signature generic for this template (or specific to individual instantiations of it).

    template<typename InterfaceType, typename ClassType>
    class IStaticFactory
    {
        public:
            IStaticFactory()
            {
                // this assignment only compiles if ClassType::createNew has
                // the signature required for InterfaceType
                typename Factory<InterfaceType>::Creator check = ClassType::createNew;
                (void)check;    // silence the unused-variable warning
            }
    };
Still following? Good. Now we define the registry class template, which maps the class name of a derived class to a function pointer with an interface-specific signature that serves as a static factory for objects of the derived class, returning a pointer to that object of the type of the interface. See my previous post for details on this class.
    template <typename InterfaceType>
    class Registry
    {
        private:
            std::map< std::string,
                      typename Factory<InterfaceType>::Creator > m_creatorMap;

            Registry(){}

        public:
            static Registry& getInstance();

            bool registerClass( const std::string& name,
                                typename Factory<InterfaceType>::Creator creator );

            std::set<std::string> getClassNames();

            typename Factory<InterfaceType>::Creator
                getCreator( std::string className );
    };

    // A convenient macro to compact the registration of a class
    #define RegisterWithInterface( CLASS, INTERFACE )               \
    namespace                                                       \
    {                                                               \
        bool dummy_ ## CLASS =                                      \
            Registry<INTERFACE>::getInstance().registerClass(       \
                #CLASS, CLASS::createNew );                         \
    }

    template <typename InterfaceType >
    Registry<InterfaceType>& Registry<InterfaceType>::getInstance()
    {
        static Registry<InterfaceType> registry;
        return registry;
    }

    template <typename InterfaceType >
    bool Registry<InterfaceType>::registerClass( const std::string& name,
                        typename Factory<InterfaceType>::Creator creator )
    {
        m_creatorMap[name] = creator;
        return true;
    }

    template <typename InterfaceType >
    std::set<std::string> Registry<InterfaceType>::getClassNames()
    {
        std::set<std::string> keys;

        // the iterator must use the interface-specific Creator type so that
        // this also compiles for the specialized factories
        typename std::map< std::string,
            typename Factory<InterfaceType>::Creator >::iterator pair;

        for( pair = m_creatorMap.begin(); pair != m_creatorMap.end(); pair++)
            keys.insert( pair->first );

        return keys;
    }

    template <typename InterfaceType >
    typename Factory<InterfaceType>::Creator
        Registry<InterfaceType>::getCreator( std::string className )
    {
        return m_creatorMap[className];
    }
The difference between this and the Registry in my previous post is that this time the registry uses the generic Factory<InterfaceType>::Creator typedef to define the function pointer. This way, that pointer is forced to have the specific signature. Sweet!
Now let’s write some derived classes of those interfaces.

    class DerivedA : public InterfaceA,
                     public IStaticFactory<InterfaceA, DerivedA>
    {
        public:
            static InterfaceA* createNew(){ return (InterfaceA*)1; }
    };

    RegisterWithInterface(DerivedA, InterfaceA);


    class DerivedB : public InterfaceB,
                     public IStaticFactory<InterfaceB, DerivedB>
    {
        public:
            static InterfaceB* createNew(int a){ return (InterfaceB*)2; }
    };

    RegisterWithInterface(DerivedB, InterfaceB);


    class DerivedC : public InterfaceC,
                     public IStaticFactory<InterfaceC, DerivedC>
    {
        public:
            static InterfaceC* createNew(int a, bool b, char c){ return (InterfaceC*)3; }
    };

    RegisterWithInterface(DerivedC, InterfaceC);
These classes are basically dummies, but by inheriting from IStaticFactory... the compiler will enforce that they contain the static method createNew with the proper signature. Notice that InterfaceA uses the default template so the static factory in DerivedA takes no parameters, while InterfaceB and InterfaceC have specializations so the static factories in DerivedB and DerivedC have their respective parameters. Since this is just an example, the methods don’t actually create new objects, they just return pointers, but in reality this is where we would use new DerivedA(...) and so on.
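For completeness, here is roughly what a non-dummy createNew might look like. This is only a sketch: the `size()` method and the `m_size` member are invented so that there is something to construct and inspect, and the interface gets a virtual destructor so the object can be deleted through the interface pointer.

```cpp
#include <cassert>
#include <memory>

class InterfaceB {
public:
    virtual ~InterfaceB() {}
    virtual int size() const = 0;   // hypothetical method, for illustration only
};

class DerivedB : public InterfaceB {
public:
    explicit DerivedB(int size) : m_size(size) {}
    virtual int size() const { return m_size; }

    // Same shape as the InterfaceB creator in the post: InterfaceB* (*)(int).
    // The caller owns the returned object.
    static InterfaceB* createNew(int size) { return new DerivedB(size); }

private:
    int m_size;
};
```

Since the factory hands back a raw owning pointer, the caller has to delete it or wrap it right away, for example `std::unique_ptr<InterfaceB> object(DerivedB::createNew(5));`.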
Well that’s it. Pretty cool huh? The compiler will enforce all this stuff for us so we can actually say to ourselves when we write new implementations months from now “If it compiles, it will be compatible.”
Lastly, here’s a little test case to run
    int main()
    {
        DerivedA a;
        DerivedB b;
        DerivedC c;

        InterfaceA* pA;
        InterfaceB* pB;
        InterfaceC* pC;

        Factory<InterfaceA>::Creator makesObjectOfA =
            Registry<InterfaceA>::getInstance().getCreator("DerivedA");
        pA = (*makesObjectOfA)();

        Factory<InterfaceB>::Creator makesObjectOfB =
            Registry<InterfaceB>::getInstance().getCreator("DerivedB");
        pB = (*makesObjectOfB)(1);

        Factory<InterfaceC>::Creator makesObjectOfC =
            Registry<InterfaceC>::getInstance().getCreator("DerivedC");
        pC = (*makesObjectOfC)(1,false,'a');

        cout << "pA: " << pA << "\n";
        cout << "pB: " << pB << "\n";
        cout << "pC: " << pC << "\n";

        return 0;
    }
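The test above looks creators up by a name it already knows. The user-interface case from the start of the post, listing whatever is registered and creating one by name, could look roughly like the sketch below. It is a cut-down, non-template registry written just for this example: `Shape`, `Circle`, `Square` and `ShapeRegistry` are made-up names, and `nullptr` assumes C++11. One deliberate difference: `getCreator` uses `find` where the post's version uses `operator[]`, so asking for an unknown name returns a null pointer without adding an empty entry to the map.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>

// Hypothetical interface and creator signature for this sketch.
class Shape {
public:
    virtual ~Shape() {}
    virtual std::string name() const = 0;
};
typedef Shape* (*ShapeCreator)();

class ShapeRegistry {
public:
    static ShapeRegistry& getInstance() {
        static ShapeRegistry registry;   // created on first use
        return registry;
    }
    bool registerClass(const std::string& name, ShapeCreator creator) {
        m_creatorMap[name] = creator;
        return true;
    }
    std::set<std::string> getClassNames() const {
        std::set<std::string> keys;
        for (std::map<std::string, ShapeCreator>::const_iterator it = m_creatorMap.begin();
             it != m_creatorMap.end(); ++it)
            keys.insert(it->first);
        return keys;
    }
    // Null for unknown names; the map is left untouched.
    ShapeCreator getCreator(const std::string& name) const {
        std::map<std::string, ShapeCreator>::const_iterator it = m_creatorMap.find(name);
        return it == m_creatorMap.end() ? nullptr : it->second;
    }
private:
    ShapeRegistry() {}
    std::map<std::string, ShapeCreator> m_creatorMap;
};

class Circle : public Shape {
public:
    virtual std::string name() const { return "Circle"; }
    static Shape* createNew() { return new Circle; }
};
class Square : public Shape {
public:
    virtual std::string name() const { return "Square"; }
    static Shape* createNew() { return new Square; }
};

// Registration runs during static initialization, before main starts.
bool circleRegistered = ShapeRegistry::getInstance().registerClass("Circle", Circle::createNew);
bool squareRegistered = ShapeRegistry::getInstance().registerClass("Square", Square::createNew);
```

A menu can then be filled from `getClassNames()`, and the selected string passed to `getCreator()` to build the object, checking for a null creator first.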
Static Interfaces in C++
Posted by cheshirekow in C++ Discoveries and Notes, Programming on August 13, 2009
I remember looking around a few weeks ago for how to make a “static interface” in C++. Basically, I wanted a way to use the compiler to enforce that a class had certain static functions. Almost all of the internet resources I found basically said “Why would you ever want to do that; you don’t really want to do that; you probably have bad design” and so on… continuously begging the question. Of course, they were right: the design was bad and that wasn’t really what I wanted to do. Nevertheless, I still managed to think of a way to create a sort of static interface using a template class.
The strategy is to define a template class that uses the static methods of the template parameter class. That way, as long as the template is instantiated, the compiler will complain unless we have provided those static functions. We can ensure that the template is instantiated and enforce the inheritance idea by making the derived class extend from the template class we wrote to enforce those static methods.
Here is an example. We can create the static interface by declaring a class template that uses the functions we want to enforce as part of the interface.
    template < typename T >
    class StaticInterface
    {
        public:
            StaticInterface()
            {
                int(*fooCheck)(int)   = T::foo;
                bool(*barCheck)(bool) = T::bar;
            }
    };

By assigning T::foo and T::bar to function pointers, we are saying, implicitly, that whatever class is provided as a parameter to this template must have a static method called foo and a static method called bar and, furthermore, that those static methods must have the same signature as the function pointers we stuff them into.
By putting this code inside the constructor of the class, we know that this method of the template will be instantiated, even if we don’t explicitly use it later in the code, as long as we derive from this class somewhere. So then, the last question is, where can we derive from it?
Well, in the class that we want to inherit the interface of course!
    class DerivedClass : public StaticInterface<DerivedClass>
    {
        public:
            static int foo(int param){ return 10; }
            static bool bar(bool param){ return true; }
    };

The DerivedClass constructor implicitly calls the StaticInterface constructor, which assigns the function pointers fooCheck and barCheck to the address of the functions DerivedClass::foo and DerivedClass::bar. As a result, if we forget the bar function in the DerivedClass the compiler will choke with an error. g++ says the following:
src/poc/test/StaticInterfaceTest.cpp: In constructor `StaticInterface
src/poc/test/StaticInterfaceTest.cpp:41: instantiated from here
src/poc/test/StaticInterfaceTest.cpp:20: error: `bar' is not a member of `DerivedClass'
Pretty cool huh?
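If a C++11 compiler is available, one possible variation (not something from the original post) is to state the required signatures with static_assert, so that a method with the wrong signature produces a message naming what was expected. The `Widget` class and the message strings below are invented for this sketch; a missing method still fails with the compiler's own "not a member" error.

```cpp
#include <cassert>
#include <type_traits>

template <typename T>
class StaticInterface {
public:
    StaticInterface() {
        // T is a complete type by the time this constructor body is
        // instantiated, so its static members can be inspected here.
        static_assert(std::is_same<decltype(&T::foo), int (*)(int)>::value,
                      "T must provide: static int foo(int)");
        static_assert(std::is_same<decltype(&T::bar), bool (*)(bool)>::value,
                      "T must provide: static bool bar(bool)");
    }
};

// Hypothetical class that satisfies the static interface.
class Widget : public StaticInterface<Widget> {
public:
    static int foo(int param) { return 2 * param; }
    static bool bar(bool param) { return !param; }
};
```

This check is stricter than the function pointer assignment: it demands exactly these types and does not cope with overloaded foo or bar, so treat it as an alternative to weigh, not a drop-in replacement.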
As a final note, please consider this “an interesting observation” and not necessarily a “great design choice”. As I said, I decided against actually trying to utilize this idea in my project, and I urge you to think carefully about yours before trying to use it yourself.