Repository for Petra's work at ampli Jan-Feb 2019

clustering.py 3.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from util import getQuery, pickleQuery
  2. import numpy as np
  3. import pandas as p
  4. import matplotlib
  5. matplotlib.use('agg')
  6. import matplotlib.pyplot as plt
  7. import seaborn as sns
  8. from scipy.spatial.distance import squareform
  9. from scipy.cluster.hierarchy import dendrogram, linkage, cophenet, fcluster
  10. from tqdm import tqdm
  11. def tqcorr(df):
  12. cols = df.columns
  13. ncols = len(cols)
  14. cdf = p.DataFrame(index = cols, columns = cols)
  15. for c in tqdm(cols):
  16. cdf.loc[c, c] = 0
  17. comb = combinations(cols, 2)
  18. for c1, c2 in tqdm(comb):
  19. cdf.loc[c1, c2] = 1 - df[c1].corr(df[c2])
  20. return cdf
  21. tqdm.pandas()
  22. Sourcedata = '../data/2017-all-wide.pkl'
  23. lableddata = '../data/9-clusters.pkl'
  24. aggdata = '../data/9-clusters.agg.pkl'
  25. clustertable = '../data/9-clusters-sample-table.pkl'
  26. numclusts = 9
  27. df = p.read_pickle(Sourcedata)
  28. # dforig = df
  29. # print(df)
  30. print(df.info())
  31. tqcorr(df)
  32. # print(df.icp_id.nunique())
  33. # print(df.read_time.nunique())
  34. # print(df.groupby('icp_id').read_time.nunique().nunique())
  35. # df = df.pivot(index = 'read_time', columns = 'icp_id', values = 'kwh_tot')
  36. # print(df.info())
  37. df = df[df.columns[df.max() != df.min()]]
  38. print(df.info())
  39. cmat = tqcorr(df)
  40. print(cmat)
  41. print(cmat.info())
  42. # lmat = squareform(1 - cmat)
  43. # lobj = linkage(lmat, method = 'ward')
  44. # print(lobj)
  45. # print(cophenet(lobj, lmat))
  46. # clabs = [x + 1 for x in range(numclusts)]
  47. # cpal = dict(zip(clabs, sns.color_palette("colorblind", numclusts).as_hex()))
  48. # clusts = fcluster(lobj, numclusts, criterion='maxclust')
  49. # print(clusts)
  50. # print(cmat.index.values)
  51. # clustdf = p.DataFrame({'icp_id' : cmat.index.values, 'cluster' : clusts})
  52. # print(clustdf)
  53. # clustdf.to_pickle(clustertable)
  54. # mdf = p.merge(clustdf, dforig, on = 'icp_id', how = 'left')
  55. # print(mdf)
  56. # print(mdf.info())
  57. # qlow = lambda x: x.quantile(0.250)
  58. # qhigh = lambda x: x.quantile(0.750)
  59. # print(mdf.cluster.describe())
  60. # mdagg = mdf.groupby(['read_time', 'cluster']).agg({
  61. # 'kwh_tot': ['median', 'mean', ('CI_low', qlow), ('CI_high', qhigh)]
  62. # }, q = 0.025)
  63. # mdagg.columns = ['_'.join(x) for x in mdagg.columns.ravel()]
  64. # mdagg = mdagg.reset_index()
  65. # print(mdagg)
  66. # print(mdagg.info())
  67. # print(mdagg.describe())
  68. # # mdf.to_csv('~/windows/Documents/clusters-ward.csv')
  69. # print("Saving")
  70. # mdf.to_pickle(lableddata)
  71. # mdagg.to_pickle(aggdata)
  72. # print("saved")
  73. # # Algorithm via
  74. # # <https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func>
  75. # ldict = {icp_id:cpal[cluster] for icp_id, cluster in zip(clustdf.icp_id, clustdf.cluster)}
  76. # link_cols = {}
  77. # for i, i12 in enumerate(lobj[:,:2].astype(int)):
  78. # c1, c2 = (link_cols[x] if x > len(lobj) else ldict[clustdf.icp_id[x]]
  79. # for x in i12)
  80. # link_cols[i+1+len(lobj)] = c1 if c1 == c2 else '#000000'
  81. # plt.figure(figsize = (25, 10))
  82. # plt.title('ICP Clustering Dendrogram')
  83. # plt.xlabel('ICP ID/(Number of ICPs)')
  84. # plt.ylabel('distance')
  85. # dendrogram(
  86. # lobj,
  87. # labels = cmat.index.values,
  88. # leaf_rotation=90,
  89. # leaf_font_size=8,
  90. # # show_leaf_counts = True,
  91. # # truncate_mode = 'lastp',
  92. # # p = 50,
  93. # # show_contracted = True,
  94. # link_color_func = lambda x: link_cols[x],
  95. # color_threshold = None
  96. # )
  97. # # plt.show()
  98. # plt.savefig("../img/sample-9-dendro.png")
  99. # sns.set()
  100. # f, axes = plt.subplots(3,3)
  101. # for i, c in enumerate(clabs):
  102. # fds = mdagg[mdagg.cluster == c]
  103. # sns.lineplot(x = 'read_time', y = 'kwh_tot_mean', color = cpal[c], ax = axes[i//3][i%3], data = fds)
  104. # axes[i//3][i%3].fill_between(fds.read_time.dt.to_pydatetime(), fds.kwh_tot_CI_low, fds.kwh_tot_CI_high, alpha = 0.1, color = cpal[c])
  105. # # plt.show()
  106. # plt.savefig("../img/sample-9-panedtrends.png")