Source code for prestools.graph

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# Created by Roberto Preste
import itertools
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.cluster.hierarchy as sch
from typing import Union, List


def flatten_image(img: np.ndarray, scale: bool = False) -> np.ndarray:
    """Convert an image array to a single-column vector.

    Args:
        img: input image array, e.g. of shape (l, h, d = 3) for a colour
            image or (l, h) for a greyscale one; every dimension is
            flattened, so arrays of any rank are accepted
        scale: scale resulting vector dividing its values by 255
            (default: False)

    Returns:
        v: reshaped vector of shape (img.size, 1), i.e. (l * h * d, 1) for a
            3-D input
    """
    # reshape(-1, 1) matches the explicit (l * h * d, 1) reshape for 3-D
    # inputs, but does not require a third axis (greyscale images used to
    # raise IndexError on img.shape[2]).
    v = img.reshape(-1, 1)
    if scale:
        return v / 255
    return v
def plot_heatmap_dendrogram(df: pd.DataFrame, cmap: str = "RdBu_r",
                            title: str = "Cluster Heatmap",
                            save: Union[bool, str] = False,
                            method: str = "ward"):
    """Plot a heatmap with hierarchical clustering of a dataframe.

    Create (and optionally save) a heatmap with hierarchical clustering
    created using Seaborn, starting from a given dataframe of correlations.
    The colour scale is fixed to the range [-1, 1] and every cell is
    annotated with its value.

    See Also:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html

    Args:
        df: input dataframe of correlations
        cmap: colormap to use (default: 'RdBu_r')
        title: title for resulting plot (default: 'Cluster Heatmap')
        save: if False, the plot will not be saved, just shown; otherwise it
            is possible to specify the path/filename where the file will be
            saved (default: False)
        method: method to use to cluster the data (default: 'ward')

    Returns:
        False if ``df`` has fewer than two rows or fewer than two columns
        (nothing to cluster); otherwise None.
    """
    # Rows and columns are both clustered, and a linkage needs at least two
    # observations, so anything narrower than 2 x 2 (including empty frames
    # and the 1 x 1 case) cannot be plotted.
    if min(df.shape) < 2:
        return False
    cm = sns.clustermap(df, method=method, figsize=(20, 16), vmin=-1, vmax=1,
                        annot=True, cmap=cmap)
    plt.suptitle(title, fontsize=22)
    if save:
        cm.savefig(save)
    plt.show()
    return
def plot_dendrogram(df: Union[pd.DataFrame, np.ndarray],
                    cut_off: Union[bool, float] = False,
                    title: str = "Dendrogram",
                    save: Union[bool, str] = False,
                    method: str = "ward"):
    """Plot a dendrogram plot from a dataframe.

    Create (and optionally save) a dendrogram plot starting from a given
    dataframe of correlations. It is also possible to add a cut-off line
    given a distance to use for separating clusters.

    See Also:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html

    Args:
        df: input dataframe (or 2-D array) of correlations; its rows are
            the observations being clustered
        cut_off: if not False, a vertical line will be added at this
            distance to better identify clusters (default: False)
        title: title for resulting plot (default: 'Dendrogram')
        save: if False, the plot will not be saved, just shown; otherwise it
            is possible to specify the path/filename where the file will be
            saved (default: False)
        method: method to use to cluster the data (default: 'ward')

    Returns:
        False if ``df`` is empty or has fewer than two rows (nothing to
        cluster); otherwise None.
    """
    # linkage needs at least two observations (rows).
    if df.size == 0 or df.shape[0] < 2:
        return False
    Z = sch.linkage(df, method=method)
    # Leaves correspond to the rows of df. A plain NumPy array has no names,
    # so scipy numbers the leaves itself. For a square (correlation) frame
    # the column names label the leaves, as before. For any other frame the
    # row index is the label set whose length matches the number of leaves.
    if isinstance(df, pd.DataFrame):
        labels = df.columns if df.shape[0] == df.shape[1] else df.index
    else:
        labels = None
    plt.figure(figsize=(20, 16))
    sch.dendrogram(Z, leaf_font_size=16, labels=labels, orientation="left")
    if cut_off:
        plt.axvline(x=cut_off, linewidth=4.0, linestyle="--")
    plt.title(title, fontsize=22)
    plt.xlabel("distance", fontsize=14)
    plt.ylabel("feature", fontsize=14)
    plt.yticks(fontsize=14)
    plt.xticks(fontsize=14)
    if save:
        plt.savefig(save)
    plt.show()
    return
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str],
                          title: str = "Confusion Matrix", cmap: str = "Reds",
                          normalize: bool = False,
                          save: Union[bool, str] = False):
    """Create a plot from a confusion matrix array.

    Rows of ``cm`` are drawn as true labels and columns as predicted labels.
    Accuracy and misclassification rate are always computed from the raw
    counts and shown under the x axis.

    Args:
        cm: input confusion matrix array (any array-like accepted by
            ``np.asarray``)
        class_names: class names to use for both axes
        title: title for resulting plot (default: 'Confusion Matrix')
        cmap: colormap to use (default: 'Reds')
        normalize: use classes ratios instead of raw numbers, i.e. divide
            each row by its total; rows with no samples are shown as 0
            (default: False)
        save: if False, the plot will not be saved, just shown; otherwise it
            is possible to specify the path/filename where the file will be
            saved (default: False)

    Returns:
        False if ``cm`` sums to zero (nothing to plot); otherwise None.

    See Also:
        http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    cm = np.asarray(cm)
    if np.sum(cm) == 0:
        return False
    # Accuracy is taken from the raw counts, before any normalization.
    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if normalize:
        # Normalize *before* drawing, so that cell colours, the colour bar
        # and the annotations all show the same ratios (and the text-colour
        # threshold refers to what is actually drawn). Rows without samples
        # would give 0 / 0 = NaN; leave them at 0 instead.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        cbar_label = "ratio of samples\n"
        value_fmt = "{:0.4f}"
        thresh = cm.max() / 1.5
    else:
        cbar_label = "# of samples\n"
        value_fmt = "{:,}"
        thresh = cm.max() / 2

    plt.figure(figsize=(20, 16))
    plt.imshow(cm, interpolation="nearest", cmap=cmap)
    plt.title("{}\n".format(title), fontsize=22)
    cbar = plt.colorbar()
    cbar.set_label(cbar_label, rotation=270, size=14, labelpad=18)
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, fontsize=14)
    plt.yticks(tick_marks, class_names, fontsize=14)
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # White text on dark cells, black text on light cells.
        plt.text(j, i, value_fmt.format(cm[i, j]),
                 horizontalalignment="center", fontsize=16, weight="bold",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel("True label", fontsize=14)
    plt.xlabel("Predicted label\n\nAccuracy={:0.4f}; Misclass={:0.4f}".format(accuracy, misclass),
               fontsize=14)
    if save:
        plt.savefig(save)
    plt.show()
    return
def reduce_xaxis_ticks(ax: plt.Axes, step: int):
    """Show every ith x axis tick.

    Hide all x tick labels of ``ax``, then make the labels at positions
    ``step - 1, 2 * step - 1, ...`` (the step-th, 2*step-th, ... label)
    visible again.

    Args:
        ax: axis to be adjusted
        step: factor to reduce the number of x axis ticks by; must be >= 1

    Raises:
        ValueError: if ``step`` is smaller than 1 (checked before the axis is
            modified).

    Examples:
        >>> fig, ax = plt.subplots()
        >>> reduce_xaxis_ticks(ax, 5)
    """
    if step < 1:
        raise ValueError("step must be a positive integer, got {!r}".format(step))
    # Fetch the labels only once. In recent Matplotlib versions
    # get_ticklabels() returns only the currently visible labels, so a
    # second call after hiding them all would return an empty list and no
    # label would ever be shown again.
    labels = ax.xaxis.get_ticklabels()
    plt.setp(labels, visible=False)
    for label in labels[step - 1::step]:
        label.set_visible(True)
    return
def reduce_yaxis_ticks(ax: plt.Axes, step: int):
    """Show every ith y axis tick.

    Hide all y tick labels of ``ax``, then make the labels at positions
    ``step - 1, 2 * step - 1, ...`` (the step-th, 2*step-th, ... label)
    visible again.

    Args:
        ax: axis to be adjusted
        step: factor to reduce the number of y axis ticks by; must be >= 1

    Raises:
        ValueError: if ``step`` is smaller than 1 (checked before the axis is
            modified).

    Examples:
        >>> fig, ax = plt.subplots()
        >>> reduce_yaxis_ticks(ax, 5)
    """
    if step < 1:
        raise ValueError("step must be a positive integer, got {!r}".format(step))
    # Fetch the labels only once. In recent Matplotlib versions
    # get_ticklabels() returns only the currently visible labels, so a
    # second call after hiding them all would return an empty list and no
    # label would ever be shown again.
    labels = ax.yaxis.get_ticklabels()
    plt.setp(labels, visible=False)
    for label in labels[step - 1::step]:
        label.set_visible(True)
    return