Issue
I’m using matplotlib. I have the following linear regression model, with a fitted surface plane and a training data set.
I need to draw the orthogonal distance from each data point to the surface plane, so that the plot looks similar to this:
Here is the code snippet that I have:
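import numpy as np
import matplotlib.pyplot as plt

# X1, X2, Y (training data) and normal_equation(x, y) (the fitted model) are defined elsewhere
nx, ny = (100, 100)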
x1 = np.linspace(-3, 10.0, nx)
x2 = np.linspace(0, 15.0, ny)
x_plane, y_plane = np.meshgrid(x1, x2)
XY = np.stack((x_plane.ravel(), y_plane.ravel()), axis=1)
z_plane = np.array([normal_equation(x,y) for x,y in XY]).reshape(x_plane.shape)
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') no longer works in recent Matplotlib releases
ax.scatter(X2, X1, Y, color='r')
ax.plot_surface(x_plane, y_plane, z_plane, color='b', alpha=0.4)
ax.set_xlabel('x1')
ax.set_ylabel('x2')
ax.set_zlabel('y')
ax.set_zlim(-10, 5)
Any help would be greatly appreciated.
Solution
We can use two simple mathematical facts to solve this problem:
- The cross product of two vectors on the plane is a vector perpendicular to the plane.
- The dot product of a vector with a unit vector gives the signed length of that vector's component along the unit vector's direction.
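As a quick numerical check of these two facts, here is a minimal sketch. The plane z = 2x + 3y and the vectors u, v, w are made up for illustration and are not taken from the question.

```python
import numpy as np

# Two hypothetical vectors lying in the plane z = 2x + 3y (assumed for illustration).
u = np.array([1.0, 0.0, 2.0])
v = np.array([0.0, 1.0, 3.0])

# Fact 1: the cross product is perpendicular to both in-plane vectors.
n = np.cross(u, v)
unit_n = n / np.linalg.norm(n)

# Fact 2: the dot product with a unit vector is the signed length of the
# component along that unit vector. For a point w and a plane through the
# origin, this is the signed distance from w to the plane.
w = np.array([0.0, 0.0, 5.0])
signed_distance = np.dot(w, unit_n)

print(n, signed_distance)
```

The sign of signed_distance tells you which side of the plane the point is on; its absolute value is the orthogonal distance.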
First, we can find a vector perpendicular to the plane using the following code:
perpendicular = np.cross(
    (0, 1, normal_equation(0, 1) - normal_equation(0, 0)),
    (1, 0, normal_equation(1, 0) - normal_equation(0, 0))
)
normal = perpendicular / np.linalg.norm(perpendicular)
(Note: this assumes the plane is not vertical, which it should not be in a linear regression.)
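If the fitted model has the form z = a*x + b*y + c, the cross product above works out to (a, b, -1), so the normal can also be read directly off the coefficients. The sketch below checks this with a hypothetical stand-in for normal_equation; the coefficients a, b, c are assumptions for illustration, not values from the question.

```python
import numpy as np

# Hypothetical coefficients standing in for the fitted regression model.
a, b, c = 0.5, -0.25, 1.0

def normal_equation(x, y):
    """Stand-in for the asker's model: height of the plane at (x, y)."""
    return a * x + b * y + c

# Normal from the cross product of two in-plane vectors, as in the answer.
perpendicular = np.cross(
    (0, 1, normal_equation(0, 1) - normal_equation(0, 0)),
    (1, 0, normal_equation(1, 0) - normal_equation(0, 0)),
)
normal = perpendicular / np.linalg.norm(perpendicular)

# Normal read directly off the coefficients of z = a*x + b*y + c.
closed_form = np.array([a, b, -1.0])
closed_form = closed_form / np.linalg.norm(closed_form)

print(normal, closed_form)
```

The cross-product version has the advantage that it only needs to call the model, so it works even when the coefficients are not directly accessible.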
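Second, we need to trace each point back along this normal vector to the plane. The dot product of (point - plane_point) with the unit normal is the signed distance from the point to the plane, so subtracting that many multiples of the normal from the point gives the closest point on the plane.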
plane_point = np.array([0, 0, normal_equation(0, 0)])
dot_prods = [
    np.dot(np.array(u) - plane_point, normal)
    for u in zip(X2, X1, Y)
]
closest_points = [
    np.array([X2[i], X1[i], Y[i]]) - normal * dot_prods[i]
    for i in range(len(Y))
]
Finally, we can draw a line segment from each data point to its closest point on the plane.
for i in range(len(Y)):
    ax.plot(
        [closest_points[i][0], X2[i]],
        [closest_points[i][1], X1[i]],
        [closest_points[i][2], Y[i]],
        'k-'
    )
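One caveat: the segments are orthogonal to the plane in data coordinates, but they will only look perpendicular on screen if the three axes are drawn at the same scale. The sketch below is a self-contained version using synthetic data (the coefficients, the random points, and the Agg backend are all assumptions for illustration) that sets the box aspect from the axis ranges with set_box_aspect, which needs Matplotlib 3.3 or later.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-ins for the asker's model and training data (assumed values).
rng = np.random.default_rng(0)
a, b, c = 0.5, -0.25, 1.0
x = rng.uniform(-3, 10, 20)
y = rng.uniform(0, 15, 20)
z = a * x + b * y + c + rng.normal(0, 2, 20)

# Unit normal of z = a*x + b*y + c, and the foot of each perpendicular.
normal = np.array([a, b, -1.0])
normal = normal / np.linalg.norm(normal)
points = np.column_stack((x, y, z))
plane_point = np.array([0.0, 0.0, c])
signed = (points - plane_point) @ normal
feet = points - signed[:, None] * normal

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(projection="3d")
ax.scatter(x, y, z, color="r")
gx, gy = np.meshgrid(np.linspace(-3, 10, 20), np.linspace(0, 15, 20))
ax.plot_surface(gx, gy, a * gx + b * gy + c, color="b", alpha=0.4)
for p, f in zip(points, feet):
    ax.plot([f[0], p[0]], [f[1], p[1]], [f[2], p[2]], "k-")

# Draw all three axes at the same scale so right angles look like right angles.
limits = [ax.get_xlim(), ax.get_ylim(), ax.get_zlim()]
ax.set_box_aspect([hi - lo for lo, hi in limits])
fig.savefig("orthogonal_distances.png")
```

Calling ax.set_zlim after this point would change the z range, so the box aspect should be set last.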
I hope this helps!
Answered By – Matthew Miller