#
# pca_dataset.py
#

import cv2
import numpy as np
import argparse

from sklearn import decomposition
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import time
 
from sklearn import datasets
 

# Only one data-source option should be set (i.e. non-empty).
# --dataset = mnist, wine, iris (built-in scikit-learn datasets)
 




def ParseArguments(argv=None):
    """Parse the command line and return the name of the dataset to use.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            ``sys.argv[1:]`` as before; passing a list lets the function be
            called programmatically.

    Returns:
        The lower-cased dataset name: "mnist", "wine" or "iris".
        "wines" is accepted as an alias and returned as "wine".

    Exits (argparse raises SystemExit) when the value is not a known dataset,
    instead of letting the script fail later with an undefined variable.
    """
    parser = argparse.ArgumentParser(description="Project ")
    parser.add_argument(
        '--dataset',
        default="wine",
        required=False,
        type=str.lower,  # accept any capitalisation, e.g. "MNIST" or "Iris"
        choices=["mnist", "wine", "wines", "iris"],
        help='mnist, wine, iris  (default: %(default)s)',
    )

    args = parser.parse_args(argv)

    # "wines" was the spelling shown in older help text; the loaders use "wine".
    return "wine" if args.dataset == "wines" else args.dataset


# main program: read the requested dataset name from the command line
dataset = ParseArguments()

# argparse always yields a string here, so truthiness is the same as != ""
if dataset:
    print("Using built-in dataset: ", dataset)

 
 


 
# Load the requested built-in scikit-learn dataset.
# Each accepted --dataset value maps to (loader, maximum number of samples).
# NOTE: "mnist" is served by sklearn's load_digits(), not the full MNIST set,
# and only its first `ile` samples are used.
ile = 1000
_DATASET_LOADERS = {
    "mnist": (datasets.load_digits, ile),
    "wine": (datasets.load_wine, None),   # None -> keep every sample
    "wines": (datasets.load_wine, None),  # alias for "wine"
    "iris": (datasets.load_iris, None),
}

# Case-insensitive lookup. An unknown name used to fall through every branch
# and crash later with NameError on `points`; fail here with a clear message.
try:
    _loader, _max_samples = _DATASET_LOADERS[dataset.lower()]
except KeyError:
    raise SystemExit(
        "Unknown dataset %r; choose one of: mnist, wine, iris" % dataset
    ) from None

our_data = _loader()
points = our_data.data[:_max_samples]          # [:None] keeps all rows
data_classes = our_data.target[:_max_samples]  # labels aligned with points
classes = our_data.target_names
classes_names = classes

     

# Project the samples onto their first three principal components so they
# can be drawn in a 3-D scatter plot, and report how long that took.
print("Calculating PCA...", end="", flush=True)
start_time = time.perf_counter()  # monotonic, high-resolution timer meant for durations
pca = decomposition.PCA(n_components=3)
# Fit on `points` and project the same data in one call (equivalent to fit + transform).
points_pca_reduced = pca.fit_transform(points)
print("\t\t took %s seconds " % round((time.perf_counter() - start_time), 5))
    

# Draw the PCA-reduced points in 3-D, one scatter series per class label so
# each class gets its own colour and legend entry.
fig_pca = plt.figure(1)
ax_pca = fig_pca.add_subplot(111, projection='3d')
ax_pca.set_title(dataset + ": PCA")

n_labels = data_classes.max() + 1
for label in range(n_labels):
    mask = data_classes == label
    selected = points_pca_reduced[mask]
    xs, ys, zs = selected.T  # the three principal-component coordinates
    ax_pca.scatter(xs, ys, zs, label=classes_names[label])

ax_pca.legend()
plt.show()
 
