Coverage for lasso/dimred/svd/subsampling_methods.py: 53%
120 statements
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-28 18:42 +0100
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-28 18:42 +0100
1import os
2import random
3import time
4from typing import List, Sequence, Tuple, Union
6import numpy as np
7from sklearn.neighbors import NearestNeighbors
9from ...dyna import ArrayType, D3plot
def _mark_dead_eles(node_indexes: np.ndarray, alive_shells: np.ndarray) -> np.ndarray:
    """
    Collects the node indexes of all elements which are flagged as dead

    Parameters
    ----------
    node_indexes: ndarray
        Array containing the node indexes of the elements.
        The first axis is used as element axis.
    alive_shells: ndarray
        Array whose zero entries flag dead elements, the second axis is
        used as element index.
        Expected for D3plot.arrays[ArrayType.element_shell_is_alive] or equivalent for beams etc

    Returns
    -------
    dead_nodes: np.ndarray
        Sorted array of the unique node indexes referenced by at least one dead element.
        Use alive_mask[dead_nodes] = False to remove these nodes from a boolean node mask.
    """

    # an element counts as dead if it is flagged with 0 anywhere along the first axis
    dead_element_indexes = np.unique(np.where(alive_shells == 0)[1])

    # boolean selector over the element axis of node_indexes
    is_dead_element = np.zeros(node_indexes.shape[0], dtype=bool)
    is_dead_element[dead_element_indexes] = True

    return np.unique(node_indexes[is_dead_element])
def _extract_shell_parts(
    part_list: Sequence[int], d3plot: D3plot
) -> Union[Tuple[np.ndarray, np.ndarray], str]:
    """
    Extracts the nodes of the parts defined by their part IDs out of the given d3plot.
    Shell, beam, solid and thick shell elements are searched for the parts.
    Nodes which belong to a dead element are removed from the result.
    If no part is requested, all alive nodes of the plot are returned.

    Parameters
    ----------
    part_list: list
        List of part IDs of the parts that shall be extracted
    d3plot: D3plot
        D3plot the parts shall be extracted from

    Returns
    -------
    node_coordinates: ndarray
        Numpy array containing the node coordinates of the extracted parts
    node_displacement: ndarray
        Numpy array containing the node displacement of the extracted parts,
        filtered along its second axis
    err_msg: str
        If an error occurs, a string containing the error msg is returned instead
    """

    # pylint: disable = too-many-locals, too-many-statements

    # work on a copy, parts are removed from it as soon as they have been located
    open_parts = list(part_list)

    arrays = d3plot.arrays
    shell_nodes = arrays[ArrayType.element_shell_node_indexes]
    shell_parts = arrays[ArrayType.element_shell_part_indexes]
    beam_nodes = arrays[ArrayType.element_beam_node_indexes]
    beam_parts = arrays[ArrayType.element_beam_part_indexes]
    solid_nodes = arrays[ArrayType.element_solid_node_indexes]
    solid_parts = arrays[ArrayType.element_solid_part_indexes]
    tshell_nodes = arrays[ArrayType.element_tshell_node_indexes]
    tshell_parts = arrays[ArrayType.element_tshell_part_indexes]

    node_coordinates = arrays[ArrayType.node_coordinates]
    node_displacement = arrays[ArrayType.node_displacement]
    n_nodes = node_coordinates.shape[0]

    # nodes of dead elements are dropped for every element type that provides alive flags
    alive_mask = np.full(n_nodes, True)
    alive_flag_sources = (
        (ArrayType.element_shell_is_alive, shell_nodes),
        (ArrayType.element_beam_is_alive, beam_nodes),
        (ArrayType.element_solid_is_alive, solid_nodes),
        (ArrayType.element_tshell_is_alive, tshell_nodes),
    )
    for alive_key, element_nodes in alive_flag_sources:
        if alive_key in arrays:
            alive_mask[_mark_dead_eles(element_nodes, arrays[alive_key])] = False

    # no part requested: every alive node is returned
    if not open_parts:
        return node_coordinates[alive_mask], node_displacement[:, alive_mask]

    try:
        part_ids = arrays[ArrayType.part_ids]
    except KeyError:
        return "KeyError: Loaded plot has no parts"
    known_part_ids = part_ids.tolist()

    # check if parts exist
    for part in open_parts:
        try:
            known_part_ids.index(int(part))
        except ValueError:
            return "ValueError: Could not find part: {0}".format(part)

    # shells, beams, solids, tshells: a part located in one element type is
    # removed from open_parts and therefore not searched in the following types
    element_tables = (
        (shell_parts, shell_nodes),
        (beam_parts, beam_nodes),
        (solid_parts, solid_nodes),
        (tshell_parts, tshell_nodes),
    )
    nodes_per_element_type = []
    for element_part_index, element_node_index in element_tables:
        belongs_to_part = np.full(element_part_index.shape, False)
        located_parts = []

        for pid in open_parts:
            # elements reference their part by its position within the part id array
            part_index = known_part_ids.index(int(pid))
            hits = np.where(element_part_index == part_index)[0]
            if hits.size > 0:
                located_parts.append(pid)
                belongs_to_part[hits] = True

        for pid in located_parts:
            open_parts.remove(pid)

        nodes_per_element_type.append(np.unique(element_node_index[belongs_to_part]))

    # this check may seem redundant, but also verifies that our masking of parts works
    if open_parts:
        return "Value Error: Could not find parts: " + str(open_parts)

    # New coordinate mask: nodes of the requested parts minus nodes of dead elements
    coord_mask = np.full(n_nodes, False)
    for unique_node_indexes in nodes_per_element_type:
        coord_mask[unique_node_indexes] = True
    coord_mask[~alive_mask] = False

    return node_coordinates[coord_mask], node_displacement[:, coord_mask]
def create_reference_subsample(
    load_path: str, parts: Sequence[int], nr_samples: int = 2000
) -> Union[Tuple[np.ndarray, float, float], str]:
    """
    Loads the D3plot at load_path, extracts the node coordinates of the requested
    parts and returns a reproducible random subsample of these nodes

    Parameters
    ----------
    load_path: str
        Filepath of the D3plot
    parts: Sequence[int]
        List of parts to be extracted
    nr_samples: int
        How many nodes are subsampled

    Returns
    -------
    reference_sample: np.array
        Numpy array containing the coordinates of the sampled nodes
    t_total: float
        Total time required for loading and subsampling
    t_load: float
        Time required to load plot
    err_msg: str
        If an error occurs, a string containing the error is returned instead
    """
    t_null = time.time()
    try:
        plot = D3plot(
            load_path,
            state_array_filter=[ArrayType.node_displacement, ArrayType.element_shell_is_alive],
        )
    except Exception:
        # any failure while loading is reported as message, matching the
        # error-string convention of the other return paths
        err_msg = (
            f"Failed to load {load_path}! Please make sure it is a D3plot file. "
            f"This might be due to {os.path.split(load_path)[1]} being a timestep of a plot"
        )
        return err_msg

    t_load = time.time() - t_null
    result = _extract_shell_parts(parts, plot)
    if isinstance(result, str):
        return result

    coordinates = result[0]
    if coordinates.shape[0] < nr_samples:
        err_msg = "Number of nodes is lower than desired samplesize"
        return err_msg

    # A dedicated generator with the fixed seed draws the same sample on every
    # call, without resetting the state of the module-level `random` generator
    # that other code of the process may rely on
    rng = random.Random("seed")
    samples = rng.sample(range(len(coordinates)), nr_samples)

    reference_sample = coordinates[samples]
    t_total = time.time() - t_null
    return reference_sample, t_total, t_load
def remap_random_subsample(
    load_path: str, parts: list, reference_subsample: np.ndarray
) -> Union[Tuple[np.ndarray, float, float], str]:
    """
    Remaps the reference subsample onto the mesh of the D3plot at load_path.
    For every reference node the nearest node of the extracted parts is
    searched (knn matching with one neighbor) and its displacement is returned.

    Parameters
    ----------
    load_path: str
        Filepath of the desired D3plot
    parts: list of int
        Which parts shall be extracted from the D3plot
    reference_subsample: np.array
        Numpy array containing the coordinates of the reference nodes

    Returns
    -------
    subsampled_displacement: np.ndarray
        Displacement of the matched nodes, selected along the second axis
        of the displacement array
    t_total: float
        Total time required to perform subsampling
    t_load: float
        Time required to load D3plot
    err_msg: str
        If an error occurred, a string is returned instead containing the error
    """
    t_start = time.time()
    try:
        plot = D3plot(
            load_path,
            state_array_filter=[ArrayType.node_displacement, ArrayType.element_shell_is_alive],
        )
    except Exception:
        return (
            f"Failed to load {load_path}! Please make sure it is a D3plot file. "
            f"This might be due to {os.path.split(load_path)[1]} being a timestep of a plot"
        )
    t_load = time.time() - t_start

    extracted = _extract_shell_parts(parts, plot)
    if isinstance(extracted, str):
        # extraction failed, pass the error message on to the caller
        return extracted
    coordinates, displacement = extracted

    # match every reference node with its closest node of the loaded plot
    neighbor_search = NearestNeighbors(n_neighbors=1, n_jobs=4)
    neighbor_search.fit(coordinates)
    _, neighbor_indexes = neighbor_search.kneighbors(reference_subsample)
    matched_nodes = neighbor_indexes[:, 0]

    subsampled_displacement = displacement[:, matched_nodes]

    return subsampled_displacement, time.time() - t_start, t_load