Tuesday, September 6, 2011

Python to visualise vectors

I've been working on a simple example SVM problem and wanted to visualise its vectors in 3D. Python with Matplotlib's mplot3d toolkit handles this nicely.

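from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
import matplotlib.pyplot as plt


fig = plt.figure()
# Axes3D(fig) no longer attaches itself to the figure in recent Matplotlib, which gives a blank plot
ax = fig.add_subplot(111, projection='3d')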


# Draw each vector as a line from the origin; the +1/-1 in each label is the vector's SVM class
ax.plot3D([0,-1], [0,1], [0,1], label='x1 +1')
ax.plot3D([0,0], [0,1], [0,0], label='x2 -1')
ax.plot3D([0,1], [0,1], [0,1], label='x3 +1')
ax.plot3D([0,0], [0,0], [0,0], label='x4 +1')


# Mark the origin and the tip of each vector
ax.scatter([0,-1], [0,1], [0,1], label='x1')
ax.scatter([0,0], [0,1], [0,0], label='x2')
ax.scatter([0,1], [0,1], [0,1], label='x3')
ax.scatter([0,0], [0,0], [0,0], label='x4')


ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')


ax.legend()
plt.show()
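A variation on the same plot, sketched under a couple of assumptions: each vector is drawn as an arrow from the origin with ax.quiver instead of a plain line, and the +1/-1 in the labels above is treated as the SVM class and used to pick the colour. The dictionary layout and the plot_vectors helper are my own illustration, not something Matplotlib provides.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs headless; remove it and call plt.show() for a window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)

# Vectors and their +1/-1 labels, copied from the plot above (assumed to be SVM classes).
vectors = {
    "x1": ((-1, 1, 1), +1),
    "x2": ((0, 1, 0), -1),
    "x3": ((1, 1, 1), +1),
    "x4": ((0, 0, 0), +1),
}


def plot_vectors(ax, vectors):
    """Hypothetical helper: draw each vector as an arrow from the origin, coloured by class.

    Returns the names of the vectors that were drawn as arrows.
    """
    arrows = []
    for name, ((x, y, z), label) in vectors.items():
        colour = "tab:blue" if label > 0 else "tab:red"
        if (x, y, z) != (0, 0, 0):
            # quiver(X, Y, Z, U, V, W): arrow starting at (X, Y, Z) with components (U, V, W)
            ax.quiver([0], [0], [0], [x], [y], [z], color=colour, arrow_length_ratio=0.1)
            arrows.append(name)
        # Mark the tip of every vector; for the zero vector this point is all that is drawn
        ax.scatter([x], [y], [z], color=colour, label=f"{name} ({label:+d})")
    return arrows


fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
arrows = plot_vectors(ax, vectors)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
ax.legend()
```

The arrow heads make the direction of each vector clearer than plain lines, and putting the class in the colour means the legend only needs one entry per vector.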



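To add an equation to the graph, for example y = 2x^2, put these lines before plt.show(). The third argument to plot3D fixes the curve in the z = 0 plane:

import numpy as np
x = np.linspace(-1, 1, 50)  # 50 evenly spaced points from -1 to 1
y = 2*x**2
ax.plot3D(x, y, 0)  # zs=0 draws the curve in the z = 0 plane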
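The same idea works for any equation of the form y = f(x). Here is a small sketch of a helper (plot_curve_in_plane, a hypothetical name of my own) that samples f with np.linspace and draws the curve in a chosen plane z = constant. It builds its own figure so it runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove it and call plt.show() for a window
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)


def plot_curve_in_plane(ax, f, x_min=-1.0, x_max=1.0, z=0.0, n=50, **kwargs):
    """Hypothetical helper: plot y = f(x) as a curve lying in the plane z = constant."""
    x = np.linspace(x_min, x_max, n)  # n evenly spaced sample points
    y = f(x)                          # f must accept a NumPy array
    zs = np.full_like(x, z)           # the same z for every point keeps the curve in one plane
    ax.plot3D(x, y, zs, **kwargs)
    return x, y, zs


fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
# The curve from the post, drawn in the z = 0 plane
x, y, zs = plot_curve_in_plane(ax, lambda x: 2 * x**2, label="y = 2x^2")
ax.legend()
```

Passing z=1.0, for example, would draw the same parabola one unit higher, which makes it easy to stack several curves at different heights.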
