tests.test_scan_pipeline
"""Smoke test for TilingScanPipeline, wired up with minimal dummy modules."""

import sys

import numpy as np
import pandas as pd

from csi_images.csi_scans import Scan
from csi_images.csi_events import EventArray
from csi_analysis.pipelines import scan_pipeline


class DummyPreprocessor(scan_pipeline.TilePreprocessor):
    """Tile preprocessor that hands the frames back unchanged."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        Must have a logging.Logger as self.log.
        NOTE(review): self.log is not assigned here; confirm where it is set.
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def preprocess(self, frame_images: list[np.ndarray]) -> list[np.ndarray]:
        """
        Return the frames as-is (same list object, same order).
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :return: the same list that was passed in.
        """
        return frame_images


class DummySegmenter(scan_pipeline.TileSegmenter):
    """Tile segmenter that labels one fixed square as a single event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save
        # List of output mask types that this segmenter can output; must exist
        self.mask_types = list(scan_pipeline.MaskType)

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def segment(
        self, frame_images: list[np.ndarray]
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Build a mask shaped like the first frame with one region labeled 1.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :return: a dict holding the mask under MaskType.EVENT.
        """
        # Allocate as uint16 up front instead of casting a float64 array
        mask = np.zeros(frame_images[0].shape, dtype=np.uint16)
        # Slicing clips to the array bounds, so small frames do not raise
        mask[100:200, 100:200] = 1
        return {scan_pipeline.MaskType.EVENT: mask}


class DummyImageFilter(scan_pipeline.ImageFilter):
    """Image filter that passes the masks through untouched."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def filter_images(
        self,
        frame_images: list[np.ndarray],
        masks: dict[scan_pipeline.MaskType, np.ndarray],
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Return the incoming masks without filtering anything.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :param masks: a dict of np.ndarrays, each representing a mask.
        :return: the same dict that was passed in.
        """
        return masks


class DummyFeatureExtractor(scan_pipeline.FeatureExtractor):
    """Feature extractor that reports the mean of the first frame."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def extract_features(
        self,
        frame_images: list[np.ndarray],
        masks: dict[scan_pipeline.MaskType, np.ndarray],
        events: EventArray,
    ) -> pd.DataFrame:
        """
        Produce a one-row DataFrame with a single "mean_intensity" column.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :param masks: a dict of np.ndarrays, each representing a mask (unused).
        :param events: an EventArray (unused).
        :return: a pd.DataFrame holding the mean of the first frame.
        """
        features = pd.DataFrame({"mean_intensity": [np.mean(frame_images[0])]})
        return features


class DummyFeatureFilter(scan_pipeline.FeatureFilter):
    """Feature filter that keeps every event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
        """
        Keep all events and report a default-constructed EventArray as filtered.
        :param events: an EventArray with populated features.
        :return: tuple[remaining, filtered].
        """
        return events, EventArray()


class DummyClassifier(scan_pipeline.EventClassifier):
    """Event classifier that tags every event as "dummy"."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        return f"{self.__class__.__name__}-{self.version}"

    def classify_events(self, events: EventArray) -> EventArray:
        """
        Pass a column of "dummy" labels to events.add_metadata.
        The column name is "model_classification" followed by len(events).
        :param events: an EventArray with populated features.
        :return: the same EventArray that was passed in.
        """
        events.add_metadata(
            pd.DataFrame(
                {f"model_classification{len(events)}": ["dummy"] * len(events)}
            )
        )
        return events


def test_scan_pipeline():
    """Run the tiling pipeline with dummy modules and check the event count."""
    scan = Scan.load_yaml("tests/data")
    log_options = {
        sys.stderr: {"level": "DEBUG", "colorize": True},
    }
    pipeline = scan_pipeline.TilingScanPipeline(
        scan,
        output_path="tests/data",
        preprocessors=[DummyPreprocessor(scan, "2024-10-30")],
        segmenters=[DummySegmenter(scan, "2024-10-30")],
        image_filters=[DummyImageFilter(scan, "2024-10-30")],
        feature_extractors=[DummyFeatureExtractor(scan, "2024-10-30")],
        tile_feature_filters=[DummyFeatureFilter(scan, "2024-10-30")],
        tile_event_classifiers=[DummyClassifier(scan, "2024-10-30")],
        scan_feature_filters=[DummyFeatureFilter(scan, "2024-10-30")],
        scan_event_classifiers=[DummyClassifier(scan, "2024-10-30")],
        max_workers=1,
        log_options=log_options,
    )
    events = pipeline.run()
    # The ROI belongs to the loaded scan, not to the scan_pipeline module
    first_roi = scan.roi[0]
    assert len(events) == first_roi.tile_rows * first_roi.tile_cols
class DummyPreprocessor(scan_pipeline.TilePreprocessor):
    """Tile preprocessor that hands the frames back unchanged."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        Must have a logging.Logger as self.log.
        NOTE(review): self.log is not assigned here; confirm where it is set.
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def preprocess(self, frame_images: list[np.ndarray]) -> list[np.ndarray]:
        """
        Return the frames as-is (same list object, same order).
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :return: the same list that was passed in.
        """
        return frame_images
Abstract class for a tile preprocessor.
def __init__(
    self,
    scan: Scan,
    version: str,
    save: bool = False,
):
    """
    Record the constructor arguments on the instance.

    Must have a logging.Logger as self.log.
    :param scan: scan metadata, which may be used for inferring parameters.
    :param version: a version string, recommended to be an ISO date.
    :param save: whether to save the immediate results of this module.
    """
    self.scan, self.version, self.save = scan, version, save
Must have a logging.Logger as self.log.
Parameters
- scan: scan metadata, which may be used for inferring parameters.
- version: a version string, recommended to be an ISO date.
- save: whether to save the immediate results of this module.
Preprocess the frames of a tile, preferably in-place. Should return the frames in the same order. No coordinate system changes should occur here, as they are handled elsewhere.
Parameters
- frame_images: a list of np.ndarrays, each representing a frame.
Returns
a list of np.ndarrays, each representing a frame.
Inherited Members
class DummySegmenter(scan_pipeline.TileSegmenter):
    """Tile segmenter that labels one fixed square as a single event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save
        # List of output mask types that this segmenter can output; must exist
        self.mask_types = list(scan_pipeline.MaskType)

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def segment(
        self, frame_images: list[np.ndarray]
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Build a mask shaped like the first frame with one region labeled 1.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :return: a dict holding the mask under MaskType.EVENT.
        """
        # Allocate as uint16 up front instead of casting a float64 array
        mask = np.zeros(frame_images[0].shape, dtype=np.uint16)
        # Slicing clips to the array bounds, so small frames do not raise
        mask[100:200, 100:200] = 1
        return {scan_pipeline.MaskType.EVENT: mask}
Abstract class for a tile segmenter.
def segment(
    self, frame_images: list[np.ndarray]
) -> dict[scan_pipeline.MaskType, np.ndarray]:
    """
    Build a mask shaped like the first frame with one region labeled 1.
    :param frame_images: a list of np.ndarrays, each representing a frame.
    :return: a dict holding the mask under MaskType.EVENT.
    """
    # Allocate as uint16 up front instead of casting a float64 array
    mask = np.zeros(frame_images[0].shape, dtype=np.uint16)
    # Slicing clips to the array bounds, so small frames do not raise
    mask[100:200, 100:200] = 1
    return {scan_pipeline.MaskType.EVENT: mask}
Segments the frames of a tile to enumerated mask(s), not modifying images. Mask(s) should be returned in a dict with labeled types.
Parameters
- frame_images: a list of np.ndarrays, each representing a frame.
Returns
a dict of np.ndarrays, each representing a mask.
Inherited Members
class DummyImageFilter(scan_pipeline.ImageFilter):
    """Image filter that passes the masks through untouched."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def filter_images(
        self,
        frame_images: list[np.ndarray],
        masks: dict[scan_pipeline.MaskType, np.ndarray],
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Return the incoming masks without filtering anything.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :param masks: a dict of np.ndarrays, each representing a mask.
        :return: the same dict that was passed in.
        """
        return masks
Abstract class for an image-based event filter.
def filter_images(
    self,
    frame_images: list[np.ndarray],
    masks: dict[scan_pipeline.MaskType, np.ndarray],
) -> dict[scan_pipeline.MaskType, np.ndarray]:
    """
    Pass-through filter: nothing is removed from the masks.
    :param frame_images: a list of np.ndarrays, each representing a frame.
    :param masks: a dict of np.ndarrays, each representing a mask.
    :return: the very same dict object received as ``masks``.
    """
    return masks
Using images and masks, returns new masks that should have filtered out unwanted objects from the existing masks. Should not be in-place, i.e. should not modify images or masks. Returns a dict of masks that will overwrite the existing masks on identical keys.
Parameters
- frame_images: a list of np.ndarrays, each representing a frame.
- masks: a dict of np.ndarrays, each representing a mask.
Returns
a dict of np.ndarrays, each representing a mask; now filtered.
Inherited Members
class DummyFeatureExtractor(scan_pipeline.FeatureExtractor):
    """Feature extractor that reports the mean of the first frame."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def extract_features(
        self,
        frame_images: list[np.ndarray],
        masks: dict[scan_pipeline.MaskType, np.ndarray],
        events: EventArray,
    ) -> pd.DataFrame:
        """
        Produce a one-row DataFrame with a single "mean_intensity" column.
        :param frame_images: a list of np.ndarrays, each representing a frame.
        :param masks: a dict of np.ndarrays, each representing a mask (unused).
        :param events: an EventArray (unused).
        :return: a pd.DataFrame holding the mean of the first frame.
        """
        features = pd.DataFrame({"mean_intensity": [np.mean(frame_images[0])]})
        return features
Abstract class for a feature extractor.
def extract_features(
    self,
    frame_images: list[np.ndarray],
    masks: dict[scan_pipeline.MaskType, np.ndarray],
    events: EventArray,
) -> pd.DataFrame:
    """
    Summarize the first frame as a one-row feature table.
    :param frame_images: a list of np.ndarrays, each representing a frame.
    :param masks: a dict of np.ndarrays, each representing a mask (unused).
    :param events: an EventArray (unused).
    :return: a pd.DataFrame with one "mean_intensity" column.
    """
    first_frame_mean = np.mean(frame_images[0])
    return pd.DataFrame({"mean_intensity": [first_frame_mean]})
Using images, masks, and events, returns new features as a pd.DataFrame.
Parameters
- frame_images: a list of np.ndarrays, each representing a frame.
- masks: a dict of np.ndarrays, each representing a mask.
- events: an EventArray, potentially with populated feature data.
Returns
a pd.DataFrame containing the newly extracted feature data.
Inherited Members
class DummyFeatureFilter(scan_pipeline.FeatureFilter):
    """Feature filter that keeps every event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
        """
        Keep all events and report a default-constructed EventArray as filtered.
        :param events: an EventArray with populated features.
        :return: tuple[remaining, filtered].
        """
        return events, EventArray()
Abstract class for a feature-based event filter.
def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
    """
    Filter nothing: every incoming event is kept.
    :param events: an EventArray with populated features.
    :return: tuple[remaining, filtered]; filtered is a default-constructed EventArray.
    """
    remaining = events
    filtered = EventArray()
    return remaining, filtered
Removes events from an event array based on feature values.
Parameters
- events: an EventArray with populated features.
Returns
two EventArray objects: tuple[remaining, filtered]
Inherited Members
class DummyClassifier(scan_pipeline.EventClassifier):
    """Event classifier that tags every event as "dummy"."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        """Return "<ClassName>-<version>"."""
        # The previous format string ended with an unmatched ")"
        return f"{self.__class__.__name__}-{self.version}"

    def classify_events(self, events: EventArray) -> EventArray:
        """
        Pass a column of "dummy" labels to events.add_metadata.
        The column name is "model_classification" followed by len(events).
        :param events: an EventArray with populated features.
        :return: the same EventArray that was passed in.
        """
        events.add_metadata(
            pd.DataFrame(
                {f"model_classification{len(events)}": ["dummy"] * len(events)}
            )
        )
        return events
Abstract class for an event classifier.
def classify_events(self, events: EventArray) -> EventArray:
    """
    Label every event "dummy" and hand the labels to events.add_metadata.
    The column name is "model_classification" followed by len(events).
    :param events: an EventArray with populated features.
    :return: the same EventArray that was passed in.
    """
    n_events = len(events)
    column_name = f"model_classification{n_events}"
    labels = pd.DataFrame({column_name: ["dummy"] * n_events})
    events.add_metadata(labels)
    return events
Classifies events based on features, then populates the metadata.
Parameters
- events: an EventArray with populated features.
Returns
an EventArray with populated metadata.
Inherited Members
def test_scan_pipeline():
    """
    Run the tiling pipeline with dummy modules and check the event count.

    The scan is loaded from "tests/data", which is also the output path.
    The final assertion expects one event per tile of the scan's first ROI.
    """
    scan = Scan.load_yaml("tests/data")
    log_options = {
        sys.stderr: {"level": "DEBUG", "colorize": True},
    }
    pipeline = scan_pipeline.TilingScanPipeline(
        scan,
        output_path="tests/data",
        preprocessors=[DummyPreprocessor(scan, "2024-10-30")],
        segmenters=[DummySegmenter(scan, "2024-10-30")],
        image_filters=[DummyImageFilter(scan, "2024-10-30")],
        feature_extractors=[DummyFeatureExtractor(scan, "2024-10-30")],
        tile_feature_filters=[DummyFeatureFilter(scan, "2024-10-30")],
        tile_event_classifiers=[DummyClassifier(scan, "2024-10-30")],
        scan_feature_filters=[DummyFeatureFilter(scan, "2024-10-30")],
        scan_event_classifiers=[DummyClassifier(scan, "2024-10-30")],
        max_workers=1,
        log_options=log_options,
    )
    events = pipeline.run()
    # The ROI belongs to the loaded scan, not to the scan_pipeline module
    first_roi = scan.roi[0]
    assert len(events) == first_roi.tile_rows * first_roi.tile_cols