Coverage for lasso/dimred/svd/test_subsampling_methods.py: 100%

46 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-28 18:42 +0100

1import os 

2import tempfile 

3from typing import Tuple 

4from unittest import TestCase 

5 

6import numpy as np 

7 

8from lasso.dimred.svd.subsampling_methods import create_reference_subsample, remap_random_subsample 

9from lasso.dimred.test_plot_creator import create_2_fake_plots 

10 

11 

class TestSubsampling(TestCase):
    """Tests for ``create_reference_subsample`` and ``remap_random_subsample``."""

    def test_create_reference_sample(self):
        """Tests the creation of a reference sample and its error-message returns."""

        with tempfile.TemporaryDirectory() as tmp_dir:

            create_2_fake_plots(tmp_dir, 500, 10)
            load_path = os.path.join(tmp_dir, "SVDTestPlot00/plot")
            n_nodes = 200

            result = create_reference_subsample(load_path, parts=[], nr_samples=n_nodes)

            # result should be a tuple of (subsample, total process time, load time)
            self.assertIsInstance(result, tuple)
            self.assertEqual(len(result), 3)

            ref_sample, t_total, t_load = result

            # check for correct types
            self.assertIsInstance(ref_sample, np.ndarray)
            self.assertIsInstance(t_total, float)
            self.assertIsInstance(t_load, float)

            # t_total should be greater than or equal to t_load
            self.assertGreaterEqual(t_total, t_load)

            # one row of 3 values per sampled node
            self.assertEqual(ref_sample.shape, (n_nodes, 3))

            # should return a string error message if the desired sample size is
            # greater than the number of available nodes
            too_many_nodes = 5500
            result = create_reference_subsample(
                load_path, parts=[], nr_samples=too_many_nodes
            )
            self.assertIsInstance(result, str)

            # should return a string error message for nonexistent parts
            result = create_reference_subsample(load_path, parts=[1], nr_samples=n_nodes)
            self.assertIsInstance(result, str)

    def test_remap_random_subsample(self):
        """Verifies that a plot is correctly remapped onto a reference subsample."""

        with tempfile.TemporaryDirectory() as tmp_dir:

            create_2_fake_plots(tmp_dir, 500, 10)
            ref_path = os.path.join(tmp_dir, "SVDTestPlot00/plot")
            sample_path = os.path.join(tmp_dir, "SVDTestPlot01/plot")
            n_nodes = 200

            ref_result = create_reference_subsample(ref_path, parts=[], nr_samples=n_nodes)
            ref_sample = ref_result[0]

            sub_result = remap_random_subsample(
                sample_path, parts=[], reference_subsample=ref_sample
            )

            # sub_result should be a tuple of (subsample, total process time, load time)
            self.assertIsInstance(sub_result, tuple)
            self.assertEqual(len(sub_result), 3)

            subsample, t_total, t_load = sub_result

            # confirm correct types
            self.assertIsInstance(subsample, np.ndarray)
            self.assertIsInstance(t_total, float)
            self.assertIsInstance(t_load, float)

            # t_total should be greater than or equal to t_load
            self.assertGreaterEqual(t_total, t_load)

            # correct shape of subsample: (timesteps, n_nodes, 3)
            # NOTE(review): 5 is presumably the timestep count produced by
            # create_2_fake_plots - confirm against the fixture helper.
            self.assertEqual(subsample.shape, (5, n_nodes, 3))

            # Entries of the subsample at timestep 0 should be identical to the
            # reference sample. This only holds for the fake test plots; it might
            # not be the case for real plots. Compare element-wise: checking
            # ``(a - b).max() == 0`` would also pass when some differences are negative.
            np.testing.assert_array_equal(subsample[0], ref_sample)

            # should return a string error message for nonexistent parts
            err_msg = remap_random_subsample(
                sample_path, parts=[1], reference_subsample=ref_sample
            )
            self.assertIsInstance(err_msg, str)
96 self.assertTrue(isinstance(err_msg, str))