#
# color2gray.py
#

import numpy as np
import sys
import argparse
import glob, os
import  scipy.misc   as misc

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn import decomposition
import time
#from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
#from metric_learn import NCA
#from sklearn.manifold import TSNE

 
 

def ParseArguments(argv=None):
    """Parse the command line and return the path of the input image.

    Args:
        argv: Argument list to parse, without the program name. When None
            (the default) argparse reads ``sys.argv[1:]``, which is how the
            script uses it.

    Returns:
        The string passed to ``--image``.

    argparse prints a usage message and raises SystemExit when ``--image``
    is missing.
    """
    parser = argparse.ArgumentParser(
        description="Convert a colour image to grayscale by luminance and by PCA.")
    # The option is required, so a default could never take effect; the old
    # default="" only made --help advertise a misleading "(default: )".
    parser.add_argument('--image', required=True,
                        help='Path to the colour image to convert')
    return parser.parse_args(argv).image


def rgb2gray(rgb):
    """Return the weighted luminance of RGB(A) colour data.

    Computes 0.2989 * R + 0.5870 * G + 0.1140 * B (Rec. 601-style luma
    weights, the same constants MATLAB's rgb2gray uses).

    Colour channels are taken from the LAST axis; any channel after the
    third (e.g. alpha) is ignored. Any number of leading axes is accepted,
    so a single pixel of shape (C,), an image of shape (H, W, C) or a stack
    of images all work. The result has the input shape minus its last axis;
    integer inputs are promoted to float by the multiplication.
    """
    red, green, blue = rgb[..., 0], rgb[..., 1], rgb[..., 2]
    return 0.2989 * red + 0.5870 * green + 0.1140 * blue
    
def _show(fig_num, data, title, cmap=None):
    """Draw *data* with imshow in its own numbered figure and set its title."""
    fig = plt.figure(fig_num)
    ax = fig.add_subplot(111)
    ax.imshow(data, cmap=cmap)
    ax.set_title(title)
    return ax


def main():
    """Load the image named by --image and display its grayscale versions.

    Opens five figures: the original image, the luminance grayscale from
    rgb2gray, and one grayscale image for each of the three principal
    components of the pixel colours.
    """
    image_file = ParseArguments()
    image = plt.imread(image_file)

    # Only (H, W, 3) RGB or (H, W, 4) RGBA images can be converted; reject
    # anything else up front instead of crashing on image.shape[2].
    if image.ndim != 3 or image.shape[2] < 3:
        sys.exit("Expected a colour (RGB or RGBA) image, got array of shape %s"
                 % (image.shape,))

    # Drop the alpha channel (if present) so every pixel is a 3-D colour point.
    image = image[:, :, :3]
    image_height, image_width = image.shape[:2]

    image_misc_bw = rgb2gray(image)

    # One row per pixel, one column per colour channel.
    pixels = image.reshape(-1, 3)

    print("Calculating PCA...", end="", flush=True)
    start_time = time.perf_counter()
    pca = decomposition.PCA(n_components=3)
    pca.fit(pixels)
    points_pca_reduced = pca.transform(pixels)
    print("\t\t took %s seconds " % round(time.perf_counter() - start_time, 5))

    print("points_pca_reduced.shape = ", points_pca_reduced.shape)

    _show(1, image, "Original image")
    _show(2, image_misc_bw, "Grayscale", cmap='gray')
    # Each principal-component score, folded back to the image grid, is
    # itself a single-channel image.
    for k in range(3):
        component = points_pca_reduced[:, k].reshape(image_height, image_width)
        _show(3 + k, component, "BW: PC %d" % (k + 1), cmap='gray')

    plt.show()


if __name__ == "__main__":
    main()
