Coverage for lasso/dimred/svd/test_subsampling_methods.py: 100%
46 statements
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-28 18:42 +0100
« prev ^ index » next coverage.py v7.2.4, created at 2023-04-28 18:42 +0100
1import os
2import tempfile
3from typing import Tuple
4from unittest import TestCase
6import numpy as np
8from lasso.dimred.svd.subsampling_methods import create_reference_subsample, remap_random_subsample
9from lasso.dimred.test_plot_creator import create_2_fake_plots
class TestSubsampling(TestCase):
    """Tests for ``create_reference_subsample`` and ``remap_random_subsample``."""

    def test_create_reference_sample(self):
        """Tests the creation of a reference sample.

        Builds two fake plots, draws a reference subsample from the first one,
        and checks the returned tuple's element types and the sample's shape.
        Also checks that a string error message is returned when more samples
        are requested than nodes are available, or when non-existent parts are
        requested.
        """

        with tempfile.TemporaryDirectory() as tmp_dir:

            create_2_fake_plots(tmp_dir, 500, 10)
            load_path = os.path.join(tmp_dir, "SVDTestPlot00/plot")
            n_nodes = 200

            result = create_reference_subsample(load_path, parts=[], nr_samples=n_nodes)

            # result should be a tuple of (subsample, total process time, load time)
            self.assertIsInstance(result, tuple)

            ref_sample, t_total, t_load = result

            # check for correct types
            self.assertIsInstance(ref_sample, np.ndarray)
            self.assertIsInstance(t_total, float)
            self.assertIsInstance(t_load, float)

            # the total process time should not be smaller than the load time
            self.assertGreaterEqual(t_total, t_load)

            # check for correct dimensions of ref_sample
            self.assertEqual(ref_sample.shape, (n_nodes, 3))

            # should return a string error message if the desired sample size is
            # greater than the number of available nodes
            too_many_nodes = 5500
            result = create_reference_subsample(
                load_path, parts=[], nr_samples=too_many_nodes
            )
            self.assertIsInstance(result, str)

            # should return a string error message for non-existent parts
            result = create_reference_subsample(load_path, parts=[1], nr_samples=n_nodes)
            self.assertIsInstance(result, str)

    def test_remap_random_subsample(self):
        """Verifies correct subsampling.

        Draws a reference subsample from the first fake plot, remaps it onto the
        second fake plot, and checks the returned tuple's element types, the
        subsample's shape, and that timestep 0 equals the reference sample.
        Also checks that a string error message is returned for non-existent
        parts.
        """

        with tempfile.TemporaryDirectory() as tmp_dir:

            create_2_fake_plots(tmp_dir, 500, 10)
            ref_path = os.path.join(tmp_dir, "SVDTestPlot00/plot")
            sample_path = os.path.join(tmp_dir, "SVDTestPlot01/plot")
            n_nodes = 200

            ref_result = create_reference_subsample(ref_path, parts=[], nr_samples=n_nodes)
            ref_sample = ref_result[0]

            sub_result = remap_random_subsample(
                sample_path, parts=[], reference_subsample=ref_sample
            )

            # sub_result should be a tuple of (subsample, total process time,
            # plot load time)
            self.assertIsInstance(sub_result, tuple)

            subsample, t_total, t_load = sub_result

            # confirm correct types
            self.assertIsInstance(subsample, np.ndarray)
            self.assertIsInstance(t_total, float)
            self.assertIsInstance(t_load, float)

            # the total process time should not be smaller than the load time
            self.assertGreaterEqual(t_total, t_load)

            # correct shape of subsample; the leading axis indexes timesteps.
            # NOTE(review): 5 is assumed to be the number of timesteps in the
            # fake plots produced by create_2_fake_plots — confirm there.
            self.assertEqual(subsample.shape, (5, n_nodes, 3))

            # Entries of the subsample at timestep 0 should be the same as the
            # reference sample. This only holds for the fake test plots and
            # might not be the case with real plots. Compare every element
            # exactly: checking only the maximum of the signed difference would
            # miss entries where the subsample exceeds the reference.
            np.testing.assert_array_equal(subsample[0], ref_sample)

            # should return a string error message for non-existent parts
            err_msg = remap_random_subsample(
                sample_path, parts=[1], reference_subsample=ref_sample
            )
            self.assertIsInstance(err_msg, str)