Coverage for lasso/dimred/svd/plot_beta_clusters.py: 66%

65 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-28 19:49 +0100

import os
import re
import time
import webbrowser
from pathlib import Path
from typing import Sequence, Union

import numpy as np

from lasso.dimred.svd.html_str_eles import (
    CONST_STRING,
    OVERHEAD_STRING,
    SCRIPT_STRING,
    TRACE_STRING,
)
from lasso.plotting.plot_shell_mesh import _read_file

16 

17 

def timestamp(when: Union[None, time.struct_time] = None) -> str:
    """
    Creates a timestamp string of format yymmdd_hhmmss_

    Parameters
    ----------
    when: Union[None, time.struct_time], default: None
        Time to format. If None, the current local time is used.

    Returns
    -------
    t_str: str
        Timestamp such as "230428_194905_" for 2023-04-28 19:49:05. Every
        field is zero-padded to two digits and the string ends with an
        underscore so it can be prepended directly to a file name.
    """
    if when is None:
        when = time.localtime()
    # %y, %m, %d, %H, %M and %S are all numeric, two-digit, zero-padded fields
    # and do not depend on the current locale.
    return time.strftime("%y%m%d_%H%M%S_", when)

37 

38 

def _rescale_betas(beta_cluster: Sequence, scale_multiplier: float) -> list:
    """
    Rescales all beta clusters by one common factor.

    Every cluster is divided by the largest absolute beta value found in any
    cluster and multiplied by ``scale_multiplier``. The largest magnitude
    across all clusters therefore becomes ``scale_multiplier``.
    """
    clusters = [np.asarray(cluster) for cluster in beta_cluster]
    max_val = max(
        (np.max(np.abs(cluster)) for cluster in clusters if cluster.size),
        default=0,
    )
    # All-zero (or empty) input: dividing by zero would fill the plot with
    # NaNs, so leave the values unscaled instead.
    if max_val == 0:
        return clusters
    return [cluster / max_val * scale_multiplier for cluster in clusters]


def _extract_run_numbers(id_cluster: Union[np.ndarray, Sequence]) -> list:
    """
    Returns, per cluster, the first run of digits found in each sample id.

    Raises IndexError if a sample id contains no digits at all.
    """
    return [[re.findall(r"\d+", str(entry))[0] for entry in group] for group in id_cluster]


def _three_js_script_tag() -> str:
    """
    Returns the bundled three.js library wrapped in an inline <script> tag.
    """
    # three levels up from this file's directory ("svd" -> "dimred" -> "lasso")
    lasso_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
    three_js_path = os.path.join(lasso_root, "plotting", "resources", "three_latest.min.js")
    return f'<script type="text/javascript">{_read_file(three_js_path)}</script>'


def plot_clusters_js(
    beta_cluster: Sequence,
    id_cluster: Union[np.ndarray, Sequence],
    save_path: str,
    img_path: Union[None, str] = None,
    mark_outliers: bool = False,
    mark_timestamp: bool = True,
    filename: str = "3d_beta_plot",
    write: bool = True,
    show_res: bool = True,
) -> Union[None, str]:
    """
    Creates a .html visualization of input data

    Parameters
    ----------
    beta_cluster: Sequence
        Sequence of array-likes, one per cluster, each of shape (n_samples, >=3);
        columns 0, 1 and 2 are plotted as x, y and z. Values are rescaled so the
        largest absolute beta over all clusters becomes 300.
    id_cluster: Union[np.ndarray, Sequence]
        Numpy array or sequence containing the ids of the samples in each cluster.
        Must be of same structure as beta_cluster
    save_path: str
        Directory to save the .html visualization in; created if missing
    img_path: Union[None, str], default: None
        Path to images of samples
    mark_outliers: bool, default: False
        Set to True if the first entry in beta_cluster contains the outliers;
        it is then drawn in black and labelled "outliers"
    mark_timestamp: bool, default: True
        Set to True if name of visualization shall contain time of creation.
        If set to False, visualization will overwrite a previous file of the
        same name
    filename: str, default "3d_beta_plot"
        Name of .html file (without extension)
    write: bool, default: True
        Set to False to not write .html file and return it as string instead
    show_res: bool, default: True
        Set to False to not open the resulting page in a web browser

    Returns
    -------
    html_str_formatted: str
        If **write=False** returns .html file as string, else None
    """

    # pylint: disable = too-many-arguments, too-many-locals

    if not isinstance(img_path, str):
        img_path = ""

    colorlist = [
        "#1f77b4",
        "#ff7f0e",
        "#2ca02c",
        "#d62728",
        "#9467bd",
        "#8c564b",
        "#e377c2",
        "#7f7f7f",
        "#bcbd22",
        "#17becf",
    ]

    # rescaling betas to better fit in viz
    beta_cluster = _rescale_betas(beta_cluster, scale_multiplier=300)
    id_nr = _extract_run_numbers(id_cluster)

    html_parts = [
        OVERHEAD_STRING,
        CONST_STRING.format(
            _three_min_=_three_js_script_tag(), _path_str_=img_path, _runIdEntries_=id_nr
        ),
    ]
    tracelist = []
    for index, cluster in enumerate(beta_cluster):
        if index == 0 and mark_outliers:
            name = "outliers"
            color = "black"
        else:
            name = f"cluster {index}"
            # offset by one so that, when slot 0 holds the outliers,
            # cluster 1 receives the first color of the palette
            color = colorlist[(index - 1) % len(colorlist)]
        trace_name = f"trace{index}"
        html_parts.append(
            TRACE_STRING.format(
                _traceNr_=trace_name,
                _name_=name,
                _color_=color,
                # asarray accepts plain sequences as well as numpy arrays
                _runIDs_=np.asarray(id_cluster[index]).tolist(),
                _x_=np.around(cluster[:, 0], decimals=5).tolist(),
                _y_=np.around(cluster[:, 1], decimals=5).tolist(),
                _z_=np.around(cluster[:, 2], decimals=5).tolist(),
            )
        )
        tracelist.append(trace_name)
    # every entry keeps its trailing ", " so the emitted JS matches the
    # historical output exactly (JS tolerates the trailing comma)
    html_parts.append(" traceList = [" + "".join(f"{trace}, " for trace in tracelist) + "]")
    html_parts.append(SCRIPT_STRING)
    html_str_formatted = "".join(html_parts)

    if not write:
        # only needed for testcases
        return html_str_formatted

    os.makedirs(save_path, exist_ok=True)

    # Timestamp for differentiating different viz / not override previous viz
    stamp = timestamp() if mark_timestamp else ""

    output_filepath = os.path.join(save_path, stamp + filename + ".html")
    with open(output_filepath, "w", encoding="utf-8") as f:
        f.write(html_str_formatted)
    if show_res:
        # as_uri builds a valid file URL on every platform (drive letters,
        # backslashes, spaces), unlike naive "file://" concatenation
        webbrowser.open(Path(os.path.realpath(output_filepath)).as_uri())
    return None