Repository for Petra's work at ampli Jan-Feb 2019

clustering.py 3.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from util import getQuery, pickleQuery
  2. import numpy as np
  3. import pandas as p
  4. import matplotlib.pyplot as plt
  5. import seaborn as sns
  6. from scipy.spatial.distance import squareform
  7. from scipy.cluster.hierarchy import dendrogram, linkage, cophenet, fcluster
  8. # query = """
  9. # SELECT *, read_date + CONCAT(period / 2, ':', period %% 2 * 30, ':00')::time AS read_time
  10. # FROM public.coup_tall_april WHERE icp_id LIKE (%s) AND read_date = to_date(%s, 'dd/mm/yyyy')
  11. # ORDER BY icp_id, read_time;
  12. # """
  13. #
  14. # qparams = ['%%1117', '20/04/2017']
  15. # query = """
  16. # SELECT read_date, period, AVG(kwh_tot) AS average
  17. # FROM public.coup_tall_april
  18. # GROUP BY read_date, period
  19. # ORDER BY read_date, period;
  20. # """
  21. #
  22. # qparams = []
  23. #
  24. # df = getQuery(query, qparams)
  25. #
  26. # print(df.info())
  27. #
  28. # sns.set()
  29. #
  30. # #sns.lineplot(x = 'read_time', y = 'kwh_tot', hue = 'icp_id', data = df)
  31. # sns.lineplot(x = 'period', y = 'average', hue = 'read_date', data = df)
  32. #
  33. # plt.show()
  # Number of flat clusters requested when the hierarchy is cut further down.
  numclusts = 9

  # Load the pickled long-format sample. The columns used in this script are
  # icp_id, read_time and kwh_tot.
  # NOTE(review): read_pickle executes arbitrary pickle payloads -- only point
  # this at a trusted, locally produced file.
  dforig = p.read_pickle('../data/2017-sample.pkl')
  df = dforig
  print(df.info())
  print(df['icp_id'].nunique())
  print(df['read_time'].nunique())
  # print(df.groupby('icp_id').read_time.nunique().nunique())

  # Reshape to wide form: one row per read_time, one column per icp_id,
  # cells holding kwh_tot.
  df = df.pivot(index='read_time', columns='icp_id', values='kwh_tot')

  # Keep only the columns whose maximum differs from their minimum
  # (presumably because a constant series has no defined correlation, which
  # the later df.corr() step would turn into NaN -- confirm).
  varies = df.max() != df.min()
  df = df.loc[:, varies]
  print(df.info())
  # Pairwise correlation between the columns of the wide frame.
  cmat = df.corr()
  print(cmat.info())

  # Use (1 - correlation) as the dissimilarity and let squareform() convert
  # the square matrix into the condensed vector that linkage() and
  # cophenet() are given below.
  lmat = squareform(1 - cmat)

  # Agglomerative clustering with Ward's method on that condensed vector.
  lobj = linkage(
      lmat,
      method='ward',
  )
  print(lobj)

  # Cophenetic correlation (and cophenetic distances) of the hierarchy
  # against the input dissimilarities.
  print(cophenet(lobj, lmat))
  # Cluster labels 1..numclusts, each paired with a hex colour from
  # seaborn's "colorblind" palette.
  clabs = list(range(1, numclusts + 1))
  hex_colours = sns.color_palette("colorblind", numclusts).as_hex()
  cpal = {label: colour for label, colour in zip(clabs, hex_colours)}

  # Cut the hierarchy into at most numclusts flat clusters.
  clusts = fcluster(lobj, numclusts, criterion='maxclust')
  print(clusts)

  # The correlation matrix index supplies the ICP ids, in the same order as
  # the cluster assignments.
  icp_ids = cmat.index.values
  print(icp_ids)

  # Lookup table: one row per ICP with its cluster label.
  clustdf = p.DataFrame({'icp_id': icp_ids, 'cluster': clusts})
  print(clustdf)
  # Attach each ICP's cluster label to its original long-format readings.
  # A left merge keeps every ICP that received a cluster label.
  mdf = p.merge(clustdf, dforig, on='icp_id', how='left')
  print(mdf)
  print(mdf.info())

  # Lower and upper quartile of a group's values.  The resulting columns are
  # labelled CI_low / CI_high because the plotting code further down reads
  # kwh_tot_CI_low / kwh_tot_CI_high; note that they are the 25th and 75th
  # percentiles (an interquartile band), not a confidence interval.
  qlow = lambda x: x.quantile(0.250)
  qhigh = lambda x: x.quantile(0.750)
  print(mdf.cluster.describe())

  # Summarise kwh_tot per (read_time, cluster).  No extra keyword arguments
  # are passed to agg(): the quantile levels are fixed inside qlow / qhigh,
  # and none of the aggregations accept any.
  mdagg = mdf.groupby(['read_time', 'cluster']).agg({
      'kwh_tot': ['median', 'mean', ('CI_low', qlow), ('CI_high', qhigh)]
  })

  # Flatten the two-level column index, e.g. ('kwh_tot', 'mean') becomes
  # 'kwh_tot_mean'.  Iterating the MultiIndex yields the label tuples.
  mdagg.columns = ['_'.join(x) for x in mdagg.columns]

  # Turn the read_time / cluster group keys back into ordinary columns.
  mdagg = mdagg.reset_index()
  print(mdagg)
  print(mdagg.info())
  print(mdagg.describe())
  71. # mdf.to_csv('~/windows/Documents/clusters-ward.csv')
  # Algorithm via
  # <https://stackoverflow.com/questions/38153829/custom-cluster-colors-of-scipy-dendrogram-in-python-link-color-func>
  #
  # Give every dendrogram link the colour of its cluster when both of its
  # children carry the same colour, and black otherwise.

  # Leaf colour per ICP id, taken from the cluster palette.
  ldict = dict(zip(clustdf.icp_id, (cpal[cluster] for cluster in clustdf.cluster)))

  n_links = len(lobj)
  link_cols = {}
  # The first two columns of each linkage row are the ids of the two nodes
  # being merged.  Ids greater than len(lobj) are looked up among the links
  # coloured so far; the rest are treated as leaf positions into clustdf.
  # The link created by row i is stored under id i + 1 + len(lobj).
  for node_id, children in enumerate(lobj[:, :2].astype(int), start=n_links + 1):
      child_cols = [
          link_cols[k] if k > n_links else ldict[clustdf.icp_id[k]]
          for k in children
      ]
      if child_cols[0] == child_cols[1]:
          link_cols[node_id] = child_cols[0]
      else:
          link_cols[node_id] = '#000000'
  # Draw the full dendrogram, with links coloured from link_cols.
  plt.figure(figsize=(25, 10))
  plt.title('ICP Clustering Dendrogram')
  plt.xlabel('ICP ID/(Number of ICPs)')
  plt.ylabel('distance')

  # Truncation options previously tried and left disabled:
  #   show_leaf_counts = True, truncate_mode = 'lastp', p = 50,
  #   show_contracted = True
  dendro_opts = {
      'labels': cmat.index.values,
      'leaf_rotation': 90,
      'leaf_font_size': 8,
      # Look each link id up in the precomputed colour table.
      'link_color_func': lambda node_id: link_cols[node_id],
      'color_threshold': None,
  }
  dendrogram(lobj, **dendro_opts)
  plt.show()
  # Mean kwh_tot over read_time, one line per cluster, coloured from cpal.
  sns.set()
  ax = sns.lineplot(
      x='read_time',
      y='kwh_tot_mean',
      hue='cluster',
      data=mdagg,
      palette=cpal,
  )

  # Shade the band between the kwh_tot_CI_low and kwh_tot_CI_high columns
  # for each cluster label, in that cluster's colour.
  for label in clabs:
      fds = mdagg.loc[mdagg['cluster'] == label]
      ax.fill_between(
          fds['read_time'].dt.to_pydatetime(),
          fds['kwh_tot_CI_low'],
          fds['kwh_tot_CI_high'],
          alpha=0.1,
          color=cpal[label],
      )
  plt.show()