Source code for cachai._core.chord

# Basic imports
import numpy as np
import pandas as pd
import seaborn as sns
import cachai.utilities as chu
import cachai.gadgets as chg
# Matplotlib imports
from   matplotlib import pyplot as plt
from   matplotlib.patches import Arc, Circle, PathPatch
from   matplotlib.path import Path
import matplotlib.colors as mtpl_colors
from   matplotlib.text import Text

class ChordDiagram():
    """Chord diagram of a correlation matrix drawn on a Matplotlib axes.

    Each variable of the matrix becomes a node (an arc on a circle) and each
    correlation whose absolute value reaches ``threshold`` becomes a chord
    between two nodes. The diagram is generated on construction.
    """

    def __init__(self, corr_matrix, **kwargs):
        """
        Initialize a ChordDiagram instance.

        Parameters
        ----------
        corr_matrix : numpy.ndarray or pandas.DataFrame
            Square, symmetric matrix of float correlations.
        **kwargs
            Configuration copied verbatim onto the instance. No defaults are
            applied here, so every attribute read by this class must be
            supplied: ``names``, ``colors``, ``font``, ``fontsize``, ``ax``,
            ``position``, ``radius``, ``node_gap``, ``node_linewidth``,
            ``node_labelpad``, ``threshold``, ``show_diag``, ``filter``,
            ``optimize``, ``blend``, ``blend_resolution``, ``bezier_n``,
            ``chord_alpha``, ``chord_linewidth``, ``off_alpha``,
            ``positive_hatch``, ``negative_hatch``, ``positive_label``,
            ``negative_label``, ``legend``, ``scale``, ``max_rho``,
            ``max_rho_radius``, ``min_dist``, ``rasterized`` and
            ``show_axis``.

        Raises
        ------
        TypeError, ValueError
            If ``corr_matrix`` fails validation, or if filtering removes
            every node.
        """
        # Correlation matrix error handling
        self.corr_matrix = corr_matrix
        self._validate_corr_matrix()

        # Initialize additional parameters
        self.__dict__.update(kwargs)
        if isinstance(self.corr_matrix, pd.DataFrame):
            if self.names is None:
                self.names = self.corr_matrix.columns.tolist()
            self.corr_matrix = self.corr_matrix.to_numpy()
        n_nodes = len(self.corr_matrix)
        if self.names is None:
            self.names = [f'N{i+1}' for i in range(n_nodes)]
        if self.colors is None:
            self.colors = sns.hls_palette(n_nodes)
        self.nodes = dict()
        self.order = list(range(n_nodes))
        self.global_indexes = []
        if self.font is None:
            self.font = {'size': self.fontsize}

        # Initialize collection lists
        self.node_patches = []
        self.node_labels = []
        self.node_labels_params = []
        self.chord_patches = [[] for _ in range(n_nodes)]
        self.chord_blends = [[] for _ in range(n_nodes)]
        self.bezier_curves = [[] for _ in range(n_nodes)]
        self.__ports_refs = []
        self.__highlighted_ports = []

        # Generate the diagram
        self.__generate_diagram()

    # Util methods
    def _validate_corr_matrix(self):
        """
        Validate that the correlation matrix meets the required specifications:
            - Input is a numpy.ndarray or pandas.DataFrame
            - Matrix is 2-dimensional and square
            - Matrix is not empty
            - Matrix has a floating-point dtype (integer matrices are rejected)
            - Matrix contains no NaN values
            - Matrix is symmetric

        Raises
        ------
        TypeError
            If the container type or the dtype is not accepted.
        ValueError
            If the shape, emptiness, NaN or symmetry checks fail.
        """
        temp_corr_matrix = self.corr_matrix
        if not isinstance(temp_corr_matrix, (np.ndarray, pd.DataFrame)):
            raise TypeError('Your correlation matrix must be a numpy.ndarray or pandas.DataFrame')
        # The remaining checks need array semantics, so a DataFrame is
        # converted locally; self.corr_matrix itself is left untouched here.
        if isinstance(temp_corr_matrix, pd.DataFrame):
            temp_corr_matrix = self.corr_matrix.to_numpy()
        if temp_corr_matrix.ndim != 2:
            raise ValueError('Your correlation matrix must be a 2-dimensional array')
        if temp_corr_matrix.shape[0] != temp_corr_matrix.shape[1]:
            raise ValueError('Your correlation matrix must be a square matrix.')
        if len(temp_corr_matrix) == 0:
            raise ValueError('Your correlation matrix cannot be empty')
        if not np.issubdtype(temp_corr_matrix.dtype, np.floating):
            raise TypeError('Your correlation matrix must contain float values')
        # np.allclose treats NaN as unequal, so without this check a NaN
        # would be reported as a (misleading) symmetry failure.
        if np.isnan(temp_corr_matrix).any():
            raise ValueError('Your correlation matrix must not contain NaN values')
        if not np.allclose(temp_corr_matrix, temp_corr_matrix.T):
            raise ValueError('Your correlation matrix must be symmetric')

    def _optimize_nodes(self):
        """Optimize node order using Prim's algorithm.

        Correlations are turned into distances (``1 - |corr|``) and the nodes
        are visited greedily, always adding the unvisited node closest to any
        visited one. The resulting order is stored in ``self.order`` and
        applied to the matrix, names and colors.
        """
        n_nodes = self.corr_matrix.shape[0]

        # We convert the correlations to distances
        # The stronger the correlation, the shorter the distance
        distance_matrix = 1 - np.abs(self.corr_matrix)

        # Start from the node with the smallest total distance to the others.
        # This must be computed while the diagonal is still excluded with
        # zeros: once it is filled with infinity every column sum is infinite
        # and argmin would always return node 0.
        off_diagonal = distance_matrix.copy()
        np.fill_diagonal(off_diagonal, 0)
        start_node = int(np.argmin(np.sum(off_diagonal, axis=0)))

        # In order to ignore the diagonal during the search, we fill it with
        # infinity values
        np.fill_diagonal(distance_matrix, np.inf)

        # Prim's algorithm
        visited = {start_node}
        order = [start_node]
        while len(visited) < n_nodes:
            # Closest non-visited node to any visited node
            min_dist = np.inf
            next_node = -1
            for node in visited:
                for neighbor in range(n_nodes):
                    if neighbor in visited:
                        continue
                    if distance_matrix[node, neighbor] < min_dist:
                        min_dist = distance_matrix[node, neighbor]
                        next_node = neighbor
            if next_node == -1:
                break  # Just in case somehow we have disconnected nodes
            visited.add(next_node)
            order.append(next_node)

        # Apply new order
        self.order = order
        self.__order_nodes()

    def _radius_rule(self, dist):
        """Rule to set the radius of a single chord.

        Returns ``max_rho_radius`` for angular distances up to ``min_dist``
        and decreases linearly to zero at a distance of pi.
        """
        if dist <= self.min_dist:
            return self.max_rho_radius
        return self.max_rho_radius * (1 - (dist - self.min_dist) / (np.pi - self.min_dist))

    def _scale_rho(self, rho):
        """Scale rho (link thickness).

        Raises
        ------
        ValueError
            If ``self.scale`` is neither ``'linear'`` nor ``'log'``.
        """
        if self.scale == 'linear':
            rho_lin = np.abs(rho) * self.max_rho
            return np.clip(rho_lin, 0, 1)  # Clip to avoid numerical issues
        if self.scale == 'log':
            rho_log = (1 - np.log10(10 - 9*np.abs(rho))) * self.max_rho
            return np.clip(rho_log, 0, 1)  # Clip to avoid numerical issues
        raise ValueError(f'Unknown scale type {self.scale}')

    # Main generation methods
    def __generate_diagram(self):
        """Generate the complete chord diagram and add it to ``self.ax``."""
        # Named colors are converted to RGB tuples. A new list is built so
        # that a sequence passed in by the caller is never modified in place
        # (and so that tuples are accepted).
        self.colors = [mtpl_colors.to_rgb(color) if isinstance(color, str) else color
                       for color in self.colors]

        if self.filter:
            self.__filter_nodes()
        if len(self.corr_matrix) == 0:
            raise ValueError(f'No nodes remaining after threshold filtering: '
                             f'all correlations were below the threshold = {self.threshold}.')

        if self.optimize:
            self._optimize_nodes()
        self.__generate_nodes()
        self.__generate_chords()

        # Add patches to axes
        for node_patch, node_label in zip(self.node_patches, self.node_labels_params):
            self.ax.add_patch(node_patch)
            label = chg.PolarText(
                node_label['r'],
                np.rad2deg(node_label['theta']),
                text=node_label['label'],
                center=self.position,
                pad=self.node_labelpad,
                rotation=node_label['rot'],
                ha='center',
                va='center',
                clip_on=True,
                rasterized=self.rasterized,
            )
            label.set_font(self.font)
            self.ax.add_artist(label)
            self.node_labels.append(label)

        flat_chord_patches = [p for plist in self.chord_patches for p in plist]
        flat_bezier_curves = [c for clist in self.bezier_curves for c in clist]
        for k, (chord_patch, bezier_curve) in enumerate(zip(flat_chord_patches,
                                                            flat_bezier_curves)):
            self.ax.add_patch(chord_patch)
            if self.blend:
                self.__add_chord_blend(chord_patch, bezier_curve, self.global_indexes[k])

        self.__adjust_ax()
        self.__generate_legend()
        self.__generate_port_refs()

    # Components generation methods
    def __generate_nodes(self):
        """Generate nodes: angular extent, ports, arc patches and label params."""
        n_nodes = len(self.corr_matrix)

        # Initial variables
        # Minus 1 from each diagonal of A to A
        relevance = np.sum(np.abs(self.corr_matrix), axis=1) - 1
        total_relevance = np.sum(relevance)
        if total_relevance == 0:
            # No off-diagonal correlation at all: give every node the same
            # arc instead of dividing by zero (which would yield NaN angles).
            relevance_norm = np.full(n_nodes, 1.0 / n_nodes)
        else:
            relevance_norm = relevance / total_relevance
        start_angles = np.cumsum([0] + list(2*np.pi*relevance_norm[:-1]))
        gap_angle = (2*np.pi/n_nodes)*self.node_gap

        # Base patch
        self.ax.add_patch(Circle(self.position, self.radius, lw=0, zorder=2,
                                 fc='w', ec='none', rasterized=self.rasterized))

        lw = 2*self.node_linewidth
        for node in range(n_nodes):
            node_data = dict()
            theta_i = start_angles[node]
            theta_f = theta_i + 2*np.pi*relevance_norm[node]
            # -- Gap correction -----
            # NOTE(review): the cap uses the end angle theta_f, not the arc
            # length; kept as in the original -- confirm the intent.
            theta_i = theta_i + np.min([gap_angle, theta_f])
            # -----------------------
            theta_m = (theta_i + theta_f)/2
            theta_arc = chu.angdist(theta_i, theta_f)
            node_data['theta_i'] = theta_i
            node_data['theta_f'] = theta_f
            node_data['theta_m'] = theta_m
            node_data['theta_arc'] = theta_arc

            rhos = dict()    # Correlations
            ports = dict()   # Ports of the node
            states = dict()  # Ports states (1 or -1)

            # The node's own entry is duplicated so that a self-chord has two
            # ports: position `node` and position `node + 1`.
            corr = np.insert(self.corr_matrix[node], node, 1)
            real_corr = corr.copy()

            # Control of the allowed ports using the correlation factor
            #    1 = Allowed
            #   -1 = Forbidden
            node_ports_state = [1] * len(corr)
            for p, r in enumerate(corr):
                if np.abs(r) < self.threshold:
                    node_ports_state[p] = -1
                    real_corr[p] = 0
            if not self.show_diag:
                for own_port in (node, node + 1):
                    node_ports_state[own_port] = -1
                    real_corr[own_port] = 0

            # Share of the node arc given to each port. The absolute values
            # are tested: positive and negative correlations can cancel in a
            # plain sum, which would otherwise leave signed (negative) sizes.
            abs_corr = np.abs(real_corr)
            abs_total = np.sum(abs_corr)
            corr_norm = abs_corr if abs_total == 0 else abs_corr / abs_total
            # Exclusive prefix sum: fraction of the arc before each port
            port_offsets = np.concatenate(([0.0], np.cumsum(corr_norm)[:-1]))

            for j, (rho, port_size, port_state) in enumerate(zip(corr, corr_norm,
                                                                 node_ports_state)):
                if j == node:
                    port_id = node
                elif j == (node + 1):
                    port_id = f'{node}*'
                elif j < node:
                    port_id = j
                else:
                    port_id = j - 1

                port_i = theta_i + theta_arc*port_offsets[j]
                port_f = port_i + theta_arc*port_size
                if port_state < 0:
                    port_i = 0
                    port_f = 0

                rhos[port_id] = rho
                ports[port_id] = {'i': port_i, 'f': port_f}
                states[port_id] = port_state

            node_data['rhos'] = rhos
            node_data['ports'] = ports
            node_data['ports_state'] = states
            self.nodes[node] = node_data

        # Node
        for n in self.nodes:
            node = self.nodes[n]
            # Patch
            self.node_patches.append(
                Arc(self.position,
                    width=2*self.radius,
                    height=2*self.radius,
                    theta1=np.rad2deg(node['theta_i']),
                    theta2=np.rad2deg(node['theta_f']),
                    lw=lw, zorder=1,
                    rasterized=self.rasterized,
                    color=self.colors[n])
            )
            # Label
            params = dict()
            params['label'] = self.names[n]
            params['r'] = self.radius
            params['theta'] = node['theta_m']
            params['x'] = params['r'] * np.cos(params['theta']) + self.position[0]
            params['y'] = params['r'] * np.sin(params['theta']) + self.position[1]
            params['rot'] = np.rad2deg(
                node['theta_m'] - np.sign(params['y'] - self.position[1])*np.pi/2) % 360
            self.node_labels_params.append(params)

    def __build_chord_patch(self, points, codes, facecolor, edgecolor, hatch):
        """Build the PathPatch of one chord with the diagram-wide chord style."""
        return PathPatch(Path(points, codes),
                         facecolor=facecolor,
                         edgecolor=edgecolor,
                         alpha=self.chord_alpha,
                         hatch=hatch,
                         lw=self.chord_linewidth,
                         rasterized=self.rasterized,
                         zorder=4)

    def __generate_chords(self):
        """Generate chords (patches and middle bezier curves) for every node."""
        for n in self.nodes:
            node = self.nodes[n]
            chord_color = self.colors[n]
            chord_edge = chu.mod_color(self.colors[n], light=0.5)
            if self.blend:
                # The fill is drawn later by the color-blended patch
                chord_color = 'none'
                chord_edge = '#3D3D3D'

            # Links
            if self.show_diag:
                own_port = node['ports'][n]
                twin_port = node['ports'][f'{n}*']
                points, codes, curve = self.__compute_bezier_curves(
                    (own_port['i'], own_port['f']),
                    (twin_port['i'], twin_port['f']),
                    self._scale_rho(1)
                )
                self.chord_patches[n].append(
                    self.__build_chord_patch(points, codes, chord_color, chord_edge,
                                             self.positive_hatch)
                )
                curve['c1'] = self.colors[n]
                curve['c2'] = self.colors[n]
                self.bezier_curves[n].append(curve)
                self.global_indexes.append(n)

            for m in self.nodes:
                # Each pair is drawn once, from the lower-indexed node
                if m <= n or node['ports_state'][m] <= 0:
                    continue
                # Best effort: a chord that cannot be built is reported and
                # skipped so that the rest of the diagram is still drawn.
                try:
                    target = self.nodes[m]
                    this_rho = node['rhos'][m]
                    vis_rho = self._scale_rho(this_rho)
                    hatch = self.negative_hatch if this_rho < 0 else self.positive_hatch
                    points, codes, curve = self.__compute_bezier_curves(
                        (node['ports'][m]['i'], node['ports'][m]['f']),
                        (target['ports'][n]['i'], target['ports'][n]['f']),
                        vis_rho
                    )
                    self.chord_patches[n].append(
                        self.__build_chord_patch(points, codes, chord_color, chord_edge,
                                                 hatch)
                    )
                    curve['c1'] = self.colors[n]
                    curve['c2'] = self.colors[m]
                    self.bezier_curves[n].append(curve)
                    self.global_indexes.append(n)
                except Exception as e:
                    print(chu.strcol(rf'ChordError: Problem creating chord from {self.names[n]} to {self.names[m]}.', c='red'))
                    print(chu.strcol(f' details: {e}', c='red'))

    def __generate_legend(self):
        """Add dummy artists so the hatches can be shown in a legend."""
        if self.legend is True:
            if self.positive_label is None:
                self.positive_label = 'Positive\ncorrelation'
            if self.negative_label is None:
                self.negative_label = 'Negative\ncorrelation'

        # Dummies: one labelled square per correlation sign that has a label
        legend_entries = (
            (self.positive_label, self.positive_hatch),
            (self.negative_label, self.negative_hatch),
        )
        for label, hatch in legend_entries:
            if label is not None:
                self.ax.scatter(*self.position, marker='s', s=200,
                                c='lightgray', ec='k', hatch=hatch,
                                label=label, zorder=0, rasterized=True)

    def __generate_port_refs(self):
        """Store, for every node, the (node, chord) references of its chords."""
        for n in self.nodes:
            self.__ports_refs.append(self.__get_node_ports_references(n))

    # Helper methods
    def __filter_nodes(self):
        """Remove nodes with no correlation (0 chords)."""
        # A node is dropped when every off-diagonal entry is below threshold
        below_threshold = np.abs(self.corr_matrix) < self.threshold
        on_diagonal = np.eye(self.corr_matrix.shape[0], dtype=bool)
        mask = np.all(below_threshold | on_diagonal, axis=1)
        indexes = np.where(~mask)[0]
        self.corr_matrix = self.corr_matrix[indexes][:, indexes]
        self.names = [self.names[i] for i in indexes]
        self.colors = [self.colors[i] for i in indexes]

    def __order_nodes(self):
        """Order nodes (matrix), names and colors following ``self.order``."""
        self.corr_matrix = self.corr_matrix[np.ix_(self.order, self.order)]
        self.names = [self.names[i] for i in self.order]
        self.colors = [self.colors[i] for i in self.order]

    def __compute_bezier_curves(self, alpha, beta, rho):
        """Compute the bezier curves that model a chord.

        Parameters
        ----------
        alpha, beta : tuple
            ``(initial, final)`` angles of the two ports joined by the chord.
        rho : float
            Scaled thickness of the chord, as a fraction of the radius.

        Returns
        -------
        points : numpy.ndarray
            Vertices of the chord outline.
        codes : list
            Matplotlib path codes matching ``points``.
        mid_bezier : dict
            Control points ``P0``, ``P1``, ``P2`` of the curve running through
            the middle of the chord.
        """
        # Polar
        alpha_i, alpha_f = alpha
        alpha_m = np.mean([alpha_f, alpha_i])
        alphas = chu._angspace(alpha_i, alpha_f)
        if len(alphas) == 0:
            alphas = np.array([alpha_i, alpha_f])  # Case: angdist too short

        beta_i, beta_f = beta
        beta_m = np.mean([beta_f, beta_i])
        betas = chu._angspace(beta_i, beta_f)
        if len(betas) == 0:
            betas = np.array([beta_i, beta_f])  # Case: angdist too short

        dist = chu.angdist(alpha_m, beta_m)
        r_rho = self._radius_rule(dist) * self.radius

        dist_ai_bf = chu.angdist(alpha_i, beta_f)
        dist_af_bi = chu.angdist(alpha_f, beta_i)
        dist_inex = np.min([dist_ai_bf, dist_af_bi])
        if dist_ai_bf < dist_af_bi:
            # Convex case
            theta_rho = beta_f + dist_inex / 2
            r_AB = r_rho
            r_BA = r_rho + rho * self.radius
        else:
            # Concave case
            theta_rho = alpha_f + dist_inex / 2
            r_AB = r_rho + rho * self.radius
            r_BA = r_rho

        # Cartesian
        points_A = np.column_stack([np.cos(alphas) * self.radius,
                                    np.sin(alphas) * self.radius])
        points_B = np.column_stack([np.cos(betas) * self.radius,
                                    np.sin(betas) * self.radius])
        direction = np.array([np.cos(theta_rho), np.sin(theta_rho)])
        inner_mid = (points_A[-1] + points_B[0]) / 2

        # A to B
        point_AB = r_AB * direction
        control_AB = [2 * point_AB - inner_mid]

        # B to A
        point_BA = r_BA * direction
        control_BA = [2 * point_BA - (points_A[0] + points_B[-1]) / 2]

        # Bezier curve in the middle
        mid_bezier = dict()
        r_mid = (r_AB + r_BA) / 2
        point_mid = r_mid * direction
        control_mid = 2 * point_mid - inner_mid
        mid_bezier['P0'] = points_A[-1] + self.position
        mid_bezier['P1'] = control_mid + self.position
        mid_bezier['P2'] = points_B[0] + self.position

        # Points
        points = np.vstack((points_A, control_AB, points_B, control_BA, points_A[0])) \
                 + self.position

        # Codes: each CURVE3 pair is (control point, end point)
        codes = [Path.MOVETO] + \
                [Path.LINETO] * (len(points_A) - 1) + \
                [Path.CURVE3] * 2 + \
                [Path.LINETO] * (len(points_B) - 1) + \
                [Path.CURVE3] * 2

        return points, codes, mid_bezier

    def __add_chord_blend(self, patch, curve, n):
        """Add a color mapped patch using the initial and final colors."""
        # Patch vertices
        vertices = patch.get_path().vertices
        xmin, ymin = np.min(vertices, axis=0)
        xmax, ymax = np.max(vertices, axis=0)

        # Bezier curve
        bezier = chu.get_bezier_curve([curve['P0'], curve['P1'], curve['P2']],
                                      n=self.bezier_n)
        bezier_equidistant = chu.equidistant(bezier)

        # Color map
        c1 = curve['c1']  # Color 1
        c2 = curve['c2']  # Color 2
        chord_cmap = sns.blend_palette([c1, c1, c2, c2], as_cmap=True)
        cmap_matrix = chu.map_from_curve(bezier_equidistant,
                                         xlim=(xmin, xmax), ylim=(ymin, ymax),
                                         resolution=self.blend_resolution)
        self.chord_blends[n].append(
            chu.colormapped_patch(
                patch,
                cmap_matrix,
                ax=self.ax,
                colormap=chord_cmap,
                zorder=2,
                alpha=self.chord_alpha,
                rasterized=self.rasterized)
        )

    def __adjust_ax(self):
        """Adjust scale, limits, and visibility of the axis."""
        # Only axes still under autoscale are touched
        adjust_x = self.ax.get_autoscalex_on()
        adjust_y = self.ax.get_autoscaley_on()
        half_span = self.radius*1.5
        if adjust_x:
            self.ax.set_xlim(self.position[0] - half_span, self.position[0] + half_span)
        if adjust_y:
            self.ax.set_ylim(self.position[1] - half_span, self.position[1] + half_span)
        if adjust_x and adjust_y:
            self.ax.set_aspect('equal')
        if not self.show_axis:
            self.ax.axis('off')

    def __get_node_ports(self, n):
        """Return the occupied ports of the n-th node that point to later nodes."""
        this_items = list(self.nodes[n]['ports_state'].items())
        # Drop the duplicated self port ('n*'), stored right after port n
        this_items.pop(n + 1)
        return [p for p, s in this_items if s > 0 and p > n]

    def __get_node_ports_references(self, n):
        """
        Return the reference of the chords of the n-th node as (n,c), where:
            - n: index of the n-th node in the resulting diagram
            - c: index of the c-th chord in the n-th node
        Always anti-clockwise. When show_diag=True, the self-referencing chord is (n,0).
        """
        node = self.nodes[n]
        this_ports = self.__get_node_ports(n)
        refs = []
        for port, state in node['ports_state'].items():
            if '*' in str(port) or state <= 0:
                continue
            if port > n:
                # Chord stored in this node
                refs.append((n, this_ports.index(port)))
            elif port < n:
                # Chord stored in the lower-indexed node it shares with
                target_ports = self.__get_node_ports(port)
                refs.append((port, target_ports.index(n)))

        if self.show_diag:
            # Every node stores its self chord first, shifting the others by one
            diag_ref_position = None
            refs_modified = []
            for i, (x, y) in enumerate(refs):
                refs_modified.append((x, y + 1))
                if diag_ref_position is None and x == n:
                    diag_ref_position = i
            if diag_ref_position is None:
                diag_ref_position = len(refs_modified)
            refs_modified.insert(diag_ref_position, (n, 0))
            refs = refs_modified

        return refs

    def __update_highlights(self):
        """Dim every chord that is not currently highlighted."""
        for n in self.nodes:
            for c, chord_patch in enumerate(self.chord_patches[n]):
                if (n, c) in self.__highlighted_ports:
                    continue
                chord_patch.set_alpha(self.off_alpha)
                if self.blend:
                    self.chord_blends[n][c].set_alpha(self.off_alpha)

    # Customization methods
    def highlight_node(self, node, chords=None, alpha=None):
        """:meta private:
        Highlights a specific node. This affects the node and all its chords. If you want to
        highlight only some chords in the node, you can indicate this with ``chords``.

        Nodes are indexed based on their circular arrangement around the origin (center of the
        Chord Diagram). Index ``0`` corresponds to the first node in the first quadrant, with
        numbering proceeding counterclockwise around the diagram. Chords within each node are
        also indexed counterclockwise, starting from the outermost chord.

        Parameters
        ----------
        node : :class:`int`
            Index of the node to highlight (starting in 0).
        chords : :class:`list` or :class:`array-like`, optional
            List of chord indices to highlight (default: all chords).
        alpha : :class:`float`, optional
            Transparency level for highlighting.

        Raises
        ------
        IndexError
            If ``node`` or any chord index is out of range.
        ValueError
            If ``node`` or any chord index is negative.
        """
        if node >= len(self.nodes):
            raise IndexError('Node is out of range. '
                             f'This Chord Diagram has only {len(self.nodes)} nodes.')
        if node < 0:
            raise ValueError('The input node must be positive or zero.')

        if chords is None:
            chords = list(range(len(self.__ports_refs[node])))
        for chord in chords:
            self.highlight_chord(node, chord, alpha)

    def highlight_chord(self, node, chord, alpha=None):
        """:meta private:
        Highlights a specific chord connected to a particular node.

        Parameters
        ----------
        node : :class:`int`
            Index of the node where the chord originates.
        chord : :class:`int`
            Index of the chord to highlight (starting in 0).
        alpha : :class:`float`, optional
            Transparency level for highlighting. Defaults to ``chord_alpha``;
            a value not greater than ``off_alpha`` is replaced by ``0.8``.

        Raises
        ------
        IndexError
            If ``node`` or ``chord`` is out of range.
        ValueError
            If ``node`` or ``chord`` is negative.
        """
        if node >= len(self.nodes):
            raise IndexError('Node is out of range. '
                             f'This Chord Diagram has only {len(self.nodes)} nodes.')
        if node < 0 or chord < 0:
            raise ValueError('The input node and chord must be positive or zero.')

        if alpha is None:
            alpha = self.chord_alpha
        if alpha <= self.off_alpha:
            alpha = 0.8

        # Resolve the reference before touching any artist, so that an invalid
        # chord index leaves the diagram unchanged.
        try:
            n, c = self.__ports_refs[node][chord]
        except IndexError:
            raise IndexError(f'Chord {chord} is out of range. '
                             f'Node {node} has only {len(self.__ports_refs[node])} chords.') from None

        self.__update_highlights()
        self.chord_patches[n][c].set_alpha(alpha)
        if self.blend:
            self.chord_blends[n][c].set_alpha(alpha)
        if (n, c) not in self.__highlighted_ports:
            self.__highlighted_ports.append((n, c))

    def set_chord_alpha(self, alpha):
        """:meta private:
        Sets the transparency level for all chords in the diagram.

        Parameters
        ----------
        alpha : :class:`float`
            Transparency value applied to all chords.
        """
        for n in self.nodes:
            for cp in self.chord_patches[n]:
                cp.set_alpha(alpha)
            if self.blend:
                for cb in self.chord_blends[n]:
                    cb.set_alpha(alpha)

    # Special methods
    def __str__(self):
        """Return a plain-text dump of every node and its stored data."""
        parts = []
        for n in self.nodes:
            parts.append(f'node {n} "{self.names[n]}"\n' + '-'*50)
            for key, value in self.nodes[n].items():
                if key == 'ports':
                    parts.append(f'\n{key:<10} :')
                    parts.extend(f'\n\t\t{p:<10} : {value[p]}' for p in value)
                else:
                    parts.append(f'\n{key:<10} : {value}')
            parts.append('\n\n\n')
        return ''.join(parts)