Coverage for lasso/io/files.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.2.4, created at 2023-04-28 18:42 +0100

1import contextlib 

2import glob 

3import os 

4import typing 

5from typing import Iterator, List, Union 

6 

7 

@contextlib.contextmanager
def open_file_or_filepath(
    path_or_file: Union[str, os.PathLike, typing.BinaryIO], mode: str
) -> Iterator[typing.BinaryIO]:
    """Yield a file object for either a filepath or an already open file

    A filepath is opened here and closed again when the context exits, also
    if the body of the ``with`` statement raises. A file object passed in by
    the caller is yielded unchanged and is left open, since the caller owns it.

    Parameters
    ----------
    path_or_file: Union[str, os.PathLike, typing.BinaryIO]
        path to a file (``str`` or path-like such as ``pathlib.Path``)
        or an already open file object
    mode: str
        filemode passed to ``open``, only used if a path is given

    Yields
    ------
    f: file object
        the freshly opened file or the file object which was passed in
    """
    if isinstance(path_or_file, (str, os.PathLike)):
        # We open this file in binary mode anyway so no encoding is needed.
        # pylint: disable = unspecified-encoding
        with open(path_or_file, mode) as opened_file:
            yield opened_file
    else:
        # Not ours to close: the caller created this file object.
        yield path_or_file

37 

38 

def collect_files(
    dirpath: Union[str, List[str]], patterns: Union[str, List[str]], recursive: bool = False
) -> Union[List[str], List[List[str]]]:
    """Collect files from directories

    Parameters
    ----------
    dirpath: Union[str, List[str]]
        path to one or multiple directories to search through
    patterns: Union[str, List[str]]
        glob patterns to search for
    recursive: bool
        whether to also search all subdirs, at any depth

    Returns
    -------
    found_files: Union[List[str], List[List[str]]]
        returns the sorted list of files found for every pattern specified.
        If exactly one pattern is given (as a string or as a list with one
        entry), the list of files for that pattern is returned directly
        instead of a list of lists.

    Examples
    --------
        >>> png_images, jpeg_images = collect_files('./folder', ['*.png', '*.jpeg'])
    """

    if not isinstance(dirpath, (list, tuple)):
        dirpath = [dirpath]
    if not isinstance(patterns, (list, tuple)):
        patterns = [patterns]

    found_files = []
    for pattern in patterns:

        files_with_pattern = []
        for current_dir in dirpath:
            if recursive:
                # "**" only descends through every directory level if glob is
                # told to be recursive, otherwise it behaves like a plain "*".
                # It also matches zero directories, so the files directly in
                # current_dir are already included and need no second glob.
                files_with_pattern += glob.glob(
                    os.path.join(current_dir, "**", pattern), recursive=True
                )
            else:
                # files in root dir only
                files_with_pattern += glob.glob(os.path.join(current_dir, pattern))

        found_files.append(sorted(files_with_pattern))

    # A single pattern yields a flat list so callers need no unpacking.
    if len(found_files) == 1:
        return found_files[0]

    return found_files

83 

84 return found_files