Coverage for lasso/dimred/svd/subsampling_methods.py: 53%

120 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-28 18:42 +0100

1import os 

2import random 

3import time 

4from typing import List, Sequence, Tuple, Union 

5 

6import numpy as np 

7from sklearn.neighbors import NearestNeighbors 

8 

9from ...dyna import ArrayType, D3plot 

10 

11 

def _mark_dead_eles(node_indexes: np.ndarray, alive_shells: np.ndarray) -> np.ndarray:
    """
    Returns the indexes of all nodes attached to at least one dead element.

    Parameters
    ----------
    node_indexes: ndarray
        Node indexes of the elements, one row per element
    alive_shells: ndarray
        Alive flags of the elements; a value of 0 marks an element as dead.
        The last axis indexes the elements, e.g.
        D3plot.arrays[ArrayType.element_shell_is_alive] or the equivalent
        array for beams, solids or thick shells (2-D, the first axis presumably
        being the states). A 1-D array is treated as a single set of flags.

    Returns
    -------
    dead_nodes: np.ndarray
        Sorted, unique indexes of the nodes that belong to an element which is
        dead in any row of ``alive_shells``. Use ``mask[dead_nodes] = False`` on
        a node mask to exclude them.
    """
    alive_flags = np.asarray(alive_shells)

    # element index is always the last axis; an element is dead as soon as
    # any entry along the other axes is 0
    dead_element_indexes = np.nonzero(alive_flags == 0)[-1]

    dead_element_mask = np.zeros(node_indexes.shape[0], dtype=bool)
    dead_element_mask[dead_element_indexes] = True

    return np.unique(node_indexes[dead_element_mask])

44 

45 

def _extract_shell_parts(
    part_list: Sequence[int], d3plot: D3plot
) -> Union[Tuple[np.ndarray, np.ndarray], str]:
    """
    Extracts the nodes of the given parts out of the given d3plot.

    Despite its name, shell, beam, solid and thick shell elements are all
    considered. Nodes attached to an element flagged as dead are dropped for
    every element type whose ``*_is_alive`` array is present in the d3plot.

    Parameters
    ----------
    part_list: Sequence[int]
        IDs of the parts that shall be extracted. If empty, all alive nodes of
        the d3plot are returned.
    d3plot: D3plot
        D3plot the parts shall be extracted from

    Returns
    -------
    node_coordinates: ndarray
        Node coordinates of the extracted nodes
    node_displacement: ndarray
        Node displacement of the extracted nodes; nodes are selected along the
        second axis, the first axis is kept (presumably the states)
    err_msg: str
        If an error occurs (plot has no part ids, an unknown part is requested,
        or a part has no elements of any type), a string containing the error
        message is returned instead
    """

    # pylint: disable = too-many-locals, too-many-statements

    # convert into list
    part_list = list(part_list)
    arrays = d3plot.arrays

    # (node indexes, part indexes, alive-flag key) per element type; arrays are
    # read in the same order as before, so a missing one raises the same KeyError
    element_types = [
        (arrays[node_key], arrays[part_key], alive_key)
        for node_key, part_key, alive_key in (
            (
                ArrayType.element_shell_node_indexes,
                ArrayType.element_shell_part_indexes,
                ArrayType.element_shell_is_alive,
            ),
            (
                ArrayType.element_beam_node_indexes,
                ArrayType.element_beam_part_indexes,
                ArrayType.element_beam_is_alive,
            ),
            (
                ArrayType.element_solid_node_indexes,
                ArrayType.element_solid_part_indexes,
                ArrayType.element_solid_is_alive,
            ),
            (
                ArrayType.element_tshell_node_indexes,
                ArrayType.element_tshell_part_indexes,
                ArrayType.element_tshell_is_alive,
            ),
        )
    ]

    node_coordinates = arrays[ArrayType.node_coordinates]
    node_displacement = arrays[ArrayType.node_displacement]

    # exclude every node attached to a dead element of any type
    alive_mask = np.full(node_coordinates.shape[0], True)
    for node_indexes, _, alive_key in element_types:
        if alive_key in arrays:
            alive_mask[_mark_dead_eles(node_indexes, arrays[alive_key])] = False

    if not part_list:
        return node_coordinates[alive_mask], node_displacement[:, alive_mask]

    try:
        part_ids = arrays[ArrayType.part_ids]
    except KeyError:
        return "KeyError: Loaded plot has no parts"

    # part id -> position in part_ids; setdefault keeps the first occurrence,
    # matching the list.index semantics of the previous implementation
    part_id_to_index = {}
    for index, part_id in enumerate(part_ids.tolist()):
        part_id_to_index.setdefault(part_id, index)

    # check if parts exist; int() raises ValueError for non-numeric input
    pending = []
    for part in part_list:
        try:
            pending.append((part, part_id_to_index[int(part)]))
        except (KeyError, ValueError):
            return "ValueError: Could not find part: {0}".format(part)

    coord_mask = np.full(node_coordinates.shape[0], False)
    for node_indexes, part_indexes, _ in element_types:
        if not pending:
            break
        present = set(np.unique(part_indexes).tolist())
        found = [index for _, index in pending if index in present]
        if found:
            element_filter = np.isin(part_indexes, found)
            coord_mask[node_indexes[element_filter].ravel()] = True
        # a part is consumed by the first element type that contains it
        pending = [(part, index) for part, index in pending if index not in present]

    # this check may seem redundant, but also verifies that our masking of parts works
    if pending:
        return "Value Error: Could not find parts: " + str([part for part, _ in pending])

    coord_mask &= alive_mask
    return node_coordinates[coord_mask], node_displacement[:, coord_mask]

180 

181 

def create_reference_subsample(
    load_path: str, parts: Sequence[int], nr_samples=2000
) -> Union[Tuple[np.ndarray, float, float], str]:
    """
    Loads the D3plot at load_path, extracts the alive node coordinates of the
    given parts and returns a reproducible random subsample of these nodes.

    Parameters
    ----------
    load_path: str
        Filepath of the D3plot
    parts: Sequence[int]
        List of parts to be extracted; empty means all parts
    nr_samples: int
        How many nodes are subsampled

    Returns
    -------
    reference_sample: np.array
        Numpy array containing the coordinates of the sampled nodes
    t_total: float
        Total time required for subsampling, in seconds
    t_load: float
        Time required to load plot, in seconds
    err_msg: str
        If an error occurs, a string containing the error is returned instead
    """
    t_null = time.time()
    try:
        plot = D3plot(
            load_path,
            state_array_filter=[ArrayType.node_displacement, ArrayType.element_shell_is_alive],
        )
    except Exception:
        err_msg = (
            f"Failed to load {load_path}! Please make sure it is a D3plot file. "
            f"This might be due to {os.path.split(load_path)[1]} being a timestep of a plot"
        )
        return err_msg

    t_load = time.time() - t_null
    result = _extract_shell_parts(parts, plot)
    if isinstance(result, str):
        return result

    coordinates = result[0]
    if coordinates.shape[0] < nr_samples:
        err_msg = "Number of nodes is lower than desired samplesize"
        return err_msg

    # A private generator seeded identically yields the same sample as the
    # former random.seed("seed") + random.sample, without resetting the
    # process-wide state of the ``random`` module for every other caller.
    rng = random.Random("seed")
    samples = rng.sample(range(len(coordinates)), nr_samples)

    reference_sample = coordinates[samples]
    t_total = time.time() - t_null
    return reference_sample, t_total, t_load

238 

239 

240def remap_random_subsample( 

241 load_path: str, parts: list, reference_subsample: np.ndarray 

242) -> Union[Tuple[np.ndarray, float, float], str]: 

243 """ 

244 Remaps the specified sample onto a new mesh provided by reference subsampl, using knn matching 

245 

246 Parameters 

247 ---------- 

248 load_path: str 

249 Filepath of the desired D3plot 

250 parts: list of int 

251 Which parts shall be extracted from the D3plot 

252 reference_subsample: np.array 

253 Numpy array containing the reference nodes 

254 

255 Returns 

256 ------- 

257 subsampled_displacement: np.ndarray 

258 Subsampled displacement of provided sample 

259 t_total: float 

260 Total time required to perform subsampling 

261 t_load: float 

262 Time required to load D3plot 

263 err_msg: str 

264 If an error occured, a string is returned instead containing the error 

265 """ 

266 t_null = time.time() 

267 try: 

268 plot = D3plot( 

269 load_path, 

270 state_array_filter=[ArrayType.node_displacement, ArrayType.element_shell_is_alive], 

271 ) 

272 except Exception: 

273 err_msg = ( 

274 f"Failed to load {load_path}! Please make sure it is a D3plot file. " 

275 f"This might be due to {os.path.split(load_path)[1]} being a timestep of a plot" 

276 ) 

277 return err_msg 

278 

279 t_load = time.time() - t_null 

280 result = _extract_shell_parts(parts, plot) 

281 if isinstance(result, str): 

282 return result 

283 

284 coordinates, displacement = result[0], result[1] 

285 

286 quarantine_zone = NearestNeighbors(n_neighbors=1, n_jobs=4).fit(coordinates) 

287 _, quarantined_index = quarantine_zone.kneighbors(reference_subsample) 

288 

289 subsampled_displacement = displacement[:, quarantined_index[:, 0]] 

290 

291 return subsampled_displacement, time.time() - t_null, t_load