Repository for Petra's work at ampli Jan-Feb 2019

clustering.py 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. from util import getQuery, pickleQuery
  2. import numpy as np
  3. import pandas as p
  4. import matplotlib.pyplot as plt
  5. import seaborn as sns
  6. from scipy.spatial.distance import squareform
  7. from scipy.cluster.hierarchy import dendrogram, linkage, cophenet, fcluster
  8. # query = """
  9. # SELECT *, read_date + CONCAT(period / 2, ':', period %% 2 * 30, ':00')::time AS read_time
  10. # FROM public.coup_tall_april WHERE icp_id LIKE (%s) AND read_date = to_date(%s, 'dd/mm/yyyy')
  11. # ORDER BY icp_id, read_time;
  12. # """
  13. #
  14. # qparams = ['%%1117', '20/04/2017']
  15. # query = """
  16. # SELECT read_date, period, AVG(kwh_tot) AS average
  17. # FROM public.coup_tall_april
  18. # GROUP BY read_date, period
  19. # ORDER BY read_date, period;
  20. # """
  21. #
  22. # qparams = []
  23. #
  24. # df = getQuery(query, qparams)
  25. #
  26. # print(df.info())
  27. #
  28. # sns.set()
  29. #
  30. # #sns.lineplot(x = 'read_time', y = 'kwh_tot', hue = 'icp_id', data = df)
  31. # sns.lineplot(x = 'period', y = 'average', hue = 'read_date', data = df)
  32. #
  33. # plt.show()
  34. numclusts = 9
  35. df = p.read_pickle('../data/2016-17-sample.pkl')
  36. dforig = df
  37. # print(df)
  38. print(df.info())
  39. print(df.icp_id.nunique())
  40. print(df.read_time.nunique())
  41. # print(df.groupby('icp_id').read_time.nunique().nunique())
  42. df = df.pivot(index = 'read_time', columns = 'icp_id', values = 'kwh_tot')
  43. print(df.info())
  44. df = df[df.columns[df.max() != df.min()]]
  45. print(df.info())
  46. cmat = df.corr()
  47. print(cmat.info())
  48. lmat = squareform(1 - cmat)
  49. lobj = linkage(lmat, method = 'ward')
  50. print(lobj)
  51. print(cophenet(lobj, lmat))
  52. clabs = [x + 1 for x in range(numclusts)]
  53. cpal = dict(zip(clabs, sns.color_palette("colorblind", numclusts).as_hex()))
  54. clusts = fcluster(lobj, numclusts, criterion='maxclust')
  55. print(clusts)
  56. print(cmat.index.values)
  57. clustdf = p.DataFrame({'icp_id' : cmat.index.values, 'cluster' : clusts})
  58. print(clustdf)
  59. mdf = p.merge(clustdf, dforig, on = 'icp_id', how = 'left')
  60. print(mdf)
  61. print(mdf.info())
  62. qlow = lambda x: x.quantile(0.250)
  63. qhigh = lambda x: x.quantile(0.750)
  64. print(mdf.cluster.describe())
  65. mdagg = mdf.groupby(['read_time', 'cluster']).agg({
  66. 'kwh_tot': ['median', 'mean', ('CI_low', qlow), ('CI_high', qhigh)]
  67. }, q = 0.025)
  68. mdagg.columns = ['_'.join(x) for x in mdagg.columns.ravel()]
  69. mdagg = mdagg.reset_index()
  70. print(mdagg)
  71. print(mdagg.info())
  72. print(mdagg.describe())
  73. # mdf.to_csv('~/windows/Documents/clusters-ward.csv')
  74. print("Saving")
  75. mdf.to_pickle('../data/9-clusters-1617.pkl')
  76. mdagg.to_pickle('../data/9-clusters-1617.agg.pkl')
  77. print("saved")
  78. # Algorithm via
  79. # <https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func>
  80. ldict = {icp_id:cpal[cluster] for icp_id, cluster in zip(clustdf.icp_id, clustdf.cluster)}
  81. link_cols = {}
  82. for i, i12 in enumerate(lobj[:,:2].astype(int)):
  83. c1, c2 = (link_cols[x] if x > len(lobj) else ldict[clustdf.icp_id[x]]
  84. for x in i12)
  85. link_cols[i+1+len(lobj)] = c1 if c1 == c2 else '#000000'
  86. plt.figure(figsize = (25, 10))
  87. plt.title('ICP Clustering Dendrogram')
  88. plt.xlabel('ICP ID/(Number of ICPs)')
  89. plt.ylabel('distance')
  90. dendrogram(
  91. lobj,
  92. labels = cmat.index.values,
  93. leaf_rotation=90,
  94. leaf_font_size=8,
  95. # show_leaf_counts = True,
  96. # truncate_mode = 'lastp',
  97. # p = 50,
  98. # show_contracted = True,
  99. link_color_func = lambda x: link_cols[x],
  100. color_threshold = None
  101. )
  102. plt.show()
  103. # sns.set()
  104. #
  105. # f, axes = plt.subplots(3,3)
  106. #
  107. # for i, c in enumerate(clabs):
  108. # fds = mdagg[mdagg.cluster == c]
  109. # sns.lineplot(x = 'read_time', y = 'kwh_tot_mean', color = cpal[c], ax = axes[i//3][i%3], data = fds)
  110. # axes[i//3][i%3].fill_between(fds.read_time.dt.to_pydatetime(), fds.kwh_tot_CI_low, fds.kwh_tot_CI_high, alpha = 0.1, color = cpal[c])
  111. # plt.show()