Principal Component Analysis (PCA) implemented with PyTorch

What is PCA?

PCA is an algorithm for finding patterns in data; it is mainly used to reduce the dimension of the data, that is, the number of features describing each sample.

Let X be a matrix of size (m, n): m samples, each described by n features. We want to find an encoding function f such that f(X) = C, where C is a matrix of size (m, l) with l < n, and a decoding function g that approximately reconstructs X, such that g(C) ≈ X.

C is a representation of X in a lower dimension; we want to find f (and g) so that the loss of information is minimal.
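The article leaves f and g abstract. One standard way to make them concrete, assumed here, is to pick a matrix D of size (n, l) with orthonormal columns and set f(X) = XD and g(C) = CD^T (after centering X). PCA takes the columns of D to be the top l principal directions, which minimizes the squared reconstruction error. The sketch below uses a hypothetical helper, pca_encoder_decoder, on random data and shows the error shrinking as l grows and vanishing at l = n:

```python
import torch


def pca_encoder_decoder(X: torch.Tensor, l: int):
    """Hypothetical helper: return an encoder f and a decoder g built from the top-l principal directions."""
    mean = X.mean(dim=0)
    # X - mean = U diag(S) Vh; the rows of Vh are the principal directions
    _, _, Vh = torch.linalg.svd(X - mean, full_matrices=False)
    D = Vh[:l].T  # shape (n, l), orthonormal columns

    def f(A: torch.Tensor) -> torch.Tensor:
        # encode: center, then project onto the l directions -> shape (m, l)
        return (A - mean) @ D

    def g(C: torch.Tensor) -> torch.Tensor:
        # decode: map back to n dimensions and undo the centering -> shape (m, n)
        return C @ D.T + mean

    return f, g


torch.manual_seed(0)
X = torch.randn(100, 5, dtype=torch.float64)  # arbitrary example data: m=100, n=5

errors = []
for l in range(1, 6):
    f, g = pca_encoder_decoder(X, l)
    C = f(X)
    # squared Frobenius norm of the reconstruction error
    errors.append(float(torch.linalg.norm(X - g(C)) ** 2))
    print(f"l={l}: C has shape {tuple(C.shape)}, squared error {errors[-1]:.4f}")
```

With l = n the decoder reproduces X exactly (up to rounding), which is why PCA only reduces dimension when l < n.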

PCA implementation steps

Understanding each step requires some familiarity with the singular value decomposition (SVD) and eigendecomposition. If you are not familiar with them, you can still read on and use the implementation!
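As a quick refresher on how the two decompositions relate: for a matrix X, the right singular vectors of X are eigenvectors of X^T X, and the eigenvalues of X^T X are the squared singular values. The check below, on random data with an arbitrary size and seed, confirms this numerically with torch.linalg.svd and torch.linalg.eigh:

```python
import torch

torch.manual_seed(0)
X = torch.randn(50, 4, dtype=torch.float64)
X = X - X.mean(dim=0)  # center the columns, as PCA does

# SVD: X = U diag(S) Vh, singular values in decreasing order
U, S, Vh = torch.linalg.svd(X, full_matrices=False)

# Eigendecomposition of the symmetric matrix X^T X (eigh returns ascending eigenvalues)
evals, evecs = torch.linalg.eigh(X.T @ X)
evals, evecs = evals.flip(0), evecs.flip(1)  # reorder to decreasing

print(torch.allclose(evals, S ** 2))
# eigenvectors are only defined up to sign, so compare absolute values
print(torch.allclose(evecs.abs(), Vh.T.abs()))
```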

Data preprocessing

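The data comes from the Iris dataset shipped with scikit-learn: iris.data is a NumPy array of shape (150, 4), i.e. 150 samples with 4 features each. k is the number of components we want to keep after the transformation.

import torch
from sklearn.datasets import load_iris

iris = load_iris()
k = 3  # number of principal components to keep
X = torch.from_numpy(iris.data)  # (150, 4) float64 tensor

We need to center the data, i.e. subtract the mean of each column:

X_mean = torch.mean(X, 0)       # mean of each feature, shape (4,)
X = X - X_mean.expand_as(X)     # every column now has zero mean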
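The code above subtracts the column means, i.e. it centers the data. Full standardization would also divide each column by its standard deviation, a common extra step (not used in this article) when features are measured in different units, since otherwise the largest-scale feature dominates the variance. A small sketch of both on random data with deliberately mismatched scales:

```python
import torch

torch.manual_seed(0)
# arbitrary example: three independent features on very different scales
scales = torch.tensor([1.0, 10.0, 100.0], dtype=torch.float64)
X = torch.randn(200, 3, dtype=torch.float64) * scales

X_centered = X - X.mean(dim=0)              # centering only, as in the article
X_standardized = X_centered / X.std(dim=0)  # optional extra step: unit variance per column

print(X_centered.mean(dim=0))     # close to 0 for every column
print(X_standardized.std(dim=0))  # close to 1 for every column

# Without scaling, the first principal direction is dominated by the largest-scale feature
_, _, Vh = torch.linalg.svd(X_centered, full_matrices=False)
print(Vh[0])
```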

Perform Singular Value Decomposition

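torch.svd() computes the singular value decomposition of the transposed data, X^T = U diag(S) V^T. The columns of U are the eigenvectors of X^T X, which, since X is centered, is the covariance matrix up to a factor 1/(m-1). S contains the singular values in decreasing order, and their squares are the corresponding eigenvalues of X^T X. So U[:, :k] holds the eigenvectors associated with the k largest eigenvalues, i.e. the first k principal directions, and projecting X onto them gives C, of size (m, k).

U, S, V = torch.svd(torch.t(X))  # X^T = U diag(S) V^T
C = torch.mm(X, U[:, :k])        # project onto the first k principal directions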
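In recent PyTorch releases torch.svd is deprecated in favor of torch.linalg.svd, which returns Vh (the transpose of V for real inputs) instead of V. Here is a sketch of the same projection with the newer API, decomposing X directly rather than X^T, and checking it against the article's call on random stand-in data:

```python
import torch

torch.manual_seed(0)
X = torch.randn(150, 4, dtype=torch.float64)  # stand-in for the iris.data tensor
X = X - X.mean(dim=0)                         # centered, as in the article
k = 3

# Newer API: X = U diag(S) Vh; the rows of Vh are the principal directions
U, S, Vh = torch.linalg.svd(X, full_matrices=False)
C = X @ Vh[:k].T  # shape (150, k)

# The article's call: SVD of X^T, whose U holds the same directions as columns
U_old, S_old, V_old = torch.svd(X.t())
C_old = torch.mm(X, U_old[:, :k])

print(torch.allclose(S, S_old))
# singular vectors are defined only up to sign, so compare absolute values
print(torch.allclose(C.abs(), C_old.abs()))
```

The sign ambiguity means the two versions may produce mirror-image scatter plots; both are valid PCA projections.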

Visualization

We gather the previous steps into a PCA function:

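import torch

def my_PCA(data, k=2):
    # preprocess the data: convert to a tensor and center each column
    X = torch.from_numpy(data)
    X_mean = torch.mean(X, 0)
    X = X - X_mean.expand_as(X)

    # svd of X^T: the columns of U are the principal directions
    U, S, V = torch.svd(torch.t(X))

    # project onto the first k principal directions, return a NumPy array for plotting
    C = torch.mm(X, U[:, :k])
    return C.numpy()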
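The PCA function above recomputes the mean and the SVD from whatever data it receives. To project new samples with directions learned on a training set, one option is to store the mean and the principal directions. The class below, TorchPCA with fit, transform and inverse_transform methods, is a hypothetical helper whose method names loosely mirror scikit-learn's convention; the data is random:

```python
import torch


class TorchPCA:
    """Hypothetical helper: learn principal directions once, reuse them on new data."""

    def __init__(self, k: int = 2):
        self.k = k

    def fit(self, X: torch.Tensor) -> "TorchPCA":
        self.mean_ = X.mean(dim=0)
        # rows of Vh are the principal directions, sorted by decreasing singular value
        _, S, Vh = torch.linalg.svd(X - self.mean_, full_matrices=False)
        self.components_ = Vh[: self.k]  # shape (k, n)
        self.singular_values_ = S[: self.k]
        return self

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        # center with the training mean, then project -> shape (m, k)
        return (X - self.mean_) @ self.components_.T

    def inverse_transform(self, C: torch.Tensor) -> torch.Tensor:
        # map back to the original n features -> shape (m, n)
        return C @ self.components_ + self.mean_


torch.manual_seed(0)
X_train = torch.randn(120, 4, dtype=torch.float64)
X_test = torch.randn(30, 4, dtype=torch.float64)

pca = TorchPCA(k=2).fit(X_train)
C_test = pca.transform(X_test)
X_test_approx = pca.inverse_transform(C_test)
print(C_test.shape, X_test_approx.shape)
```

Centering new data with the training mean, rather than its own mean, keeps the test projections in the same coordinate system as the training projections.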

Now we visualize the PCA of the Iris dataset from scikit-learn, keeping the first two components:

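import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
y = iris.target
X_PCA = my_PCA(X)  # shape (150, 2): first two principal components

plt.figure()

for i, target_name in enumerate(iris.target_names):
    # plot the samples of class i in the plane of the first two components
    plt.scatter(X_PCA[y == i, 0], X_PCA[y == i, 1], label=target_name)

plt.legend()
plt.title('PCA of IRIS dataset')
plt.show()

PCA lets us visualize the Iris dataset in two dimensions and find combinations of the attributes that help distinguish each type of iris.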
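To judge how much of the data a two-dimensional plot like this actually shows, one can look at the share of total variance carried by each component, which is proportional to the squared singular values. A minimal sketch with a hypothetical helper, explained_variance_ratio, run on random correlated data; passing torch.from_numpy(iris.data) instead would give the figures for the Iris plot:

```python
import torch


def explained_variance_ratio(X: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: share of the total variance captured by each principal component."""
    Xc = X - X.mean(dim=0)
    S = torch.linalg.svdvals(Xc)  # singular values, in decreasing order
    var = S ** 2                  # proportional to the eigenvalues of the covariance matrix
    return var / var.sum()


torch.manual_seed(0)
# random correlated stand-in data with 4 features
X = torch.randn(150, 4, dtype=torch.float64) @ torch.randn(4, 4, dtype=torch.float64)
ratios = explained_variance_ratio(X)
print(ratios)            # one entry per component, decreasing
print(ratios[:2].sum())  # fraction of the variance visible in a 2D PCA plot
```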
