tests.test_scan_pipeline

  1import sys
  2
  3import numpy as np
  4import pandas as pd
  5
  6from csi_images.csi_scans import Scan
  7from csi_images.csi_events import EventArray
  8from csi_analysis.pipelines import scan_pipeline
  9
 10
 11class DummyPreprocessor(scan_pipeline.TilePreprocessor):
 12    def __init__(
 13        self,
 14        scan: Scan,
 15        version: str,
 16        save: bool = False,
 17    ):
 18        """
 19        Must have a logging.Logger as self.log.
 20        :param scan: scan metadata, which may be used for inferring parameters.
 21        :param version: a version string, recommended to be an ISO date.
 22        :param save: whether to save the immediate results of this module.
 23        """
 24        self.scan = scan
 25        self.version = version
 26        self.save = save
 27
 28    def __repr__(self):
 29        return f"{self.__class__.__name__}-{self.version})"
 30
 31    def preprocess(self, frame_images: list[np.ndarray]) -> list[np.ndarray]:
 32        return frame_images
 33
 34
 35class DummySegmenter(scan_pipeline.TileSegmenter):
 36    def __init__(
 37        self,
 38        scan: Scan,
 39        version: str,
 40        save: bool = False,
 41    ):
 42        self.scan = scan
 43        self.version = version
 44        self.save = save
 45        # List of output mask types that this segmenter can output; must exist
 46        self.mask_types = [mask_type for mask_type in scan_pipeline.MaskType]
 47
 48    def __repr__(self):
 49        return f"{self.__class__.__name__}-{self.version})"
 50
 51    def segment(
 52        self, frame_images: list[np.ndarray]
 53    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
 54        mask = np.zeros(frame_images[0].shape).astype(np.uint16)
 55        mask[100:200, 100:200] = 1
 56        return {scan_pipeline.MaskType.EVENT: mask}
 57
 58
 59class DummyImageFilter(scan_pipeline.ImageFilter):
 60    def __init__(
 61        self,
 62        scan: Scan,
 63        version: str,
 64        save: bool = False,
 65    ):
 66        self.scan = scan
 67        self.version = version
 68        self.save = save
 69
 70    def __repr__(self):
 71        return f"{self.__class__.__name__}-{self.version})"
 72
 73    def filter_images(
 74        self,
 75        frame_images: list[np.ndarray],
 76        masks: dict[scan_pipeline.MaskType, np.ndarray],
 77    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
 78        return masks
 79
 80
 81class DummyFeatureExtractor(scan_pipeline.FeatureExtractor):
 82    def __init__(
 83        self,
 84        scan: Scan,
 85        version: str,
 86        save: bool = False,
 87    ):
 88        self.scan = scan
 89        self.version = version
 90        self.save = save
 91
 92    def __repr__(self):
 93        return f"{self.__class__.__name__}-{self.version})"
 94
 95    def extract_features(
 96        self,
 97        frame_images: list[np.ndarray],
 98        masks: dict[scan_pipeline.MaskType, np.ndarray],
 99        events: EventArray,
100    ) -> pd.DataFrame:
101        features = pd.DataFrame({"mean_intensity": [np.mean(frame_images[0])]})
102        return features
103
104
class DummyFeatureFilter(scan_pipeline.FeatureFilter):
    """Feature filter that keeps every event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
        """
        Keep all events.
        :param events: the events to filter.
        :return: (events unchanged, a new empty EventArray of removed events).
        """
        return events, EventArray()
121
122
class DummyClassifier(scan_pipeline.EventClassifier):
    """Classifier that labels every event "dummy"."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def classify_events(self, events: EventArray) -> EventArray:
        """
        Add a metadata column filled with "dummy" to the events.
        :param events: the events to classify; mutated via add_metadata.
        :return: the same EventArray.
        """
        # NOTE(review): the column name embeds the event count, so runs over
        # differently sized EventArrays write different columns — confirm
        # this is intentional rather than a typo for "model_classification".
        events.add_metadata(
            pd.DataFrame(
                {f"model_classification{len(events)}": ["dummy"] * len(events)}
            )
        )
        return events
144
145
def test_scan_pipeline():
    """
    Run the tiling pipeline end-to-end with the dummy modules above.

    The dummy segmenter labels a single region per tile and the dummy filters
    drop nothing, so presumably exactly one event per tile is expected.
    """
    scan = Scan.load_yaml("tests/data")
    log_options = {
        sys.stderr: {"level": "DEBUG", "colorize": True},
    }
    version = "2024-10-30"
    pipeline = scan_pipeline.TilingScanPipeline(
        scan,
        output_path="tests/data",
        preprocessors=[DummyPreprocessor(scan, version)],
        segmenters=[DummySegmenter(scan, version)],
        image_filters=[DummyImageFilter(scan, version)],
        feature_extractors=[DummyFeatureExtractor(scan, version)],
        tile_feature_filters=[DummyFeatureFilter(scan, version)],
        tile_event_classifiers=[DummyClassifier(scan, version)],
        scan_feature_filters=[DummyFeatureFilter(scan, version)],
        scan_event_classifiers=[DummyClassifier(scan, version)],
        max_workers=1,
        log_options=log_options,
    )
    events = pipeline.run()
    # The ROI tiling belongs to the loaded Scan, not to the scan_pipeline module.
    roi = scan.roi[0]
    assert len(events) == roi.tile_rows * roi.tile_cols
class DummyPreprocessor(csi_analysis.pipelines.scan_pipeline.TilePreprocessor):
class DummyPreprocessor(scan_pipeline.TilePreprocessor):
    """No-op tile preprocessor used to exercise the scan pipeline in tests."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        """
        NOTE(review): the base-class contract says implementations must have a
        logging.Logger as self.log; this dummy does not set one itself, so
        presumably the base class or pipeline supplies it — confirm.
        :param scan: scan metadata, which may be used for inferring parameters.
        :param version: a version string, recommended to be an ISO date.
        :param save: whether to save the immediate results of this module.
        """
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def preprocess(self, frame_images: list[np.ndarray]) -> list[np.ndarray]:
        """Return the frames unchanged (identity preprocessing)."""
        return frame_images

Abstract class for a tile preprocessor.

DummyPreprocessor(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
13    def __init__(
14        self,
15        scan: Scan,
16        version: str,
17        save: bool = False,
18    ):
19        """
20        Must have a logging.Logger as self.log.
21        :param scan: scan metadata, which may be used for inferring parameters.
22        :param version: a version string, recommended to be an ISO date.
23        :param save: whether to save the immediate results of this module.
24        """
25        self.scan = scan
26        self.version = version
27        self.save = save

Must have a logging.Logger as self.log.

Parameters
  • scan: scan metadata, which may be used for inferring parameters.
  • version: a version string, recommended to be an ISO date.
  • save: whether to save the immediate results of this module.
scan
version
save = False
def preprocess(self, frame_images: list[numpy.ndarray]) -> list[numpy.ndarray]:
32    def preprocess(self, frame_images: list[np.ndarray]) -> list[np.ndarray]:
33        return frame_images

Preprocess the frames of a tile, preferably in-place. Should return the frames in the same order. No coordinate system changes should occur here, as they are handled elsewhere.

Parameters
  • frame_images: a list of np.ndarrays, each representing a frame.
Returns

a list of np.ndarrays, each representing a frame.

class DummySegmenter(csi_analysis.pipelines.scan_pipeline.TileSegmenter):
class DummySegmenter(scan_pipeline.TileSegmenter):
    """Test segmenter that labels one fixed square region of each tile."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save
        # List of output mask types that this segmenter can output; must exist.
        # NOTE(review): this advertises every MaskType although segment() only
        # returns MaskType.EVENT — confirm that is what the pipeline expects.
        self.mask_types = list(scan_pipeline.MaskType)

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def segment(
        self, frame_images: list[np.ndarray]
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Build an EVENT mask shaped like the first frame, with the square
        [100:200, 100:200] set to 1 and everything else 0.
        :param frame_images: the frames of one tile; only the first frame's
            shape is used.
        :return: a dict mapping MaskType.EVENT to a uint16 mask.
        """
        # Allocate as uint16 directly instead of float64 followed by a cast.
        mask = np.zeros(frame_images[0].shape, dtype=np.uint16)
        mask[100:200, 100:200] = 1
        return {scan_pipeline.MaskType.EVENT: mask}

Abstract class for a tile segmenter.

DummySegmenter(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
37    def __init__(
38        self,
39        scan: Scan,
40        version: str,
41        save: bool = False,
42    ):
43        self.scan = scan
44        self.version = version
45        self.save = save
46        # List of output mask types that this segmenter can output; must exist
47        self.mask_types = [mask_type for mask_type in scan_pipeline.MaskType]
scan
version
save = False
mask_types
def segment( self, frame_images: list[numpy.ndarray]) -> dict[csi_analysis.pipelines.scan_pipeline.MaskType, numpy.ndarray]:
52    def segment(
53        self, frame_images: list[np.ndarray]
54    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
55        mask = np.zeros(frame_images[0].shape).astype(np.uint16)
56        mask[100:200, 100:200] = 1
57        return {scan_pipeline.MaskType.EVENT: mask}

Segments the frames of a tile to enumerated mask(s), not modifying images. Mask(s) should be returned in a dict with labeled types.

Parameters
  • frame_images: a list of np.ndarrays, each representing a frame.
Returns

a dict of np.ndarrays, each representing a mask.

class DummyImageFilter(csi_analysis.pipelines.scan_pipeline.ImageFilter):
class DummyImageFilter(scan_pipeline.ImageFilter):
    """Pass-through image filter: returns the masks it is given."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def filter_images(
        self,
        frame_images: list[np.ndarray],
        masks: dict[scan_pipeline.MaskType, np.ndarray],
    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
        """
        Filter nothing.
        :param frame_images: unused.
        :param masks: the masks to "filter".
        :return: the same masks dict, unmodified.
        """
        return masks

Abstract class for an image-based event filter.

DummyImageFilter(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
61    def __init__(
62        self,
63        scan: Scan,
64        version: str,
65        save: bool = False,
66    ):
67        self.scan = scan
68        self.version = version
69        self.save = save
scan
version
save = False
def filter_images( self, frame_images: list[numpy.ndarray], masks: dict[csi_analysis.pipelines.scan_pipeline.MaskType, numpy.ndarray]) -> dict[csi_analysis.pipelines.scan_pipeline.MaskType, numpy.ndarray]:
74    def filter_images(
75        self,
76        frame_images: list[np.ndarray],
77        masks: dict[scan_pipeline.MaskType, np.ndarray],
78    ) -> dict[scan_pipeline.MaskType, np.ndarray]:
79        return masks

Using images and masks, returns new masks that should have filtered out unwanted objects from the existing masks. Should not be in-place, i.e. should not modify images or masks. Returns a dict of masks that will overwrite the existing masks on identical keys.

Parameters
  • frame_images: a list of np.ndarrays, each representing a frame.
  • masks: a dict of np.ndarrays, each representing a mask.
Returns

a dict of np.ndarrays, each representing a mask; now filtered.

class DummyFeatureExtractor(csi_analysis.pipelines.scan_pipeline.FeatureExtractor):
 82class DummyFeatureExtractor(scan_pipeline.FeatureExtractor):
 83    def __init__(
 84        self,
 85        scan: Scan,
 86        version: str,
 87        save: bool = False,
 88    ):
 89        self.scan = scan
 90        self.version = version
 91        self.save = save
 92
 93    def __repr__(self):
 94        return f"{self.__class__.__name__}-{self.version})"
 95
 96    def extract_features(
 97        self,
 98        frame_images: list[np.ndarray],
 99        masks: dict[scan_pipeline.MaskType, np.ndarray],
100        events: EventArray,
101    ) -> pd.DataFrame:
102        features = pd.DataFrame({"mean_intensity": [np.mean(frame_images[0])]})
103        return features

Abstract class for a feature extractor.

DummyFeatureExtractor(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
83    def __init__(
84        self,
85        scan: Scan,
86        version: str,
87        save: bool = False,
88    ):
89        self.scan = scan
90        self.version = version
91        self.save = save
scan
version
save = False
def extract_features( self, frame_images: list[numpy.ndarray], masks: dict[csi_analysis.pipelines.scan_pipeline.MaskType, numpy.ndarray], events: csi_images.csi_events.EventArray) -> pandas.core.frame.DataFrame:
 96    def extract_features(
 97        self,
 98        frame_images: list[np.ndarray],
 99        masks: dict[scan_pipeline.MaskType, np.ndarray],
100        events: EventArray,
101    ) -> pd.DataFrame:
102        features = pd.DataFrame({"mean_intensity": [np.mean(frame_images[0])]})
103        return features

Using images, masks, and events, returns new features as a pd.DataFrame.

Parameters
  • frame_images: a list of np.ndarrays, each representing a frame.
  • masks: a dict of np.ndarrays, each representing a mask.
  • events: an EventArray, potentially with populated feature data.
Returns

a pd.DataFrame of newly computed feature data.

class DummyFeatureFilter(csi_analysis.pipelines.scan_pipeline.FeatureFilter):
class DummyFeatureFilter(scan_pipeline.FeatureFilter):
    """Feature filter that keeps every event."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
        """
        Keep all events.
        :param events: the events to filter.
        :return: (events unchanged, a new empty EventArray of removed events).
        """
        return events, EventArray()

Abstract class for a feature-based event filter.

DummyFeatureFilter(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
107    def __init__(
108        self,
109        scan: Scan,
110        version: str,
111        save: bool = False,
112    ):
113        self.scan = scan
114        self.version = version
115        self.save = save
scan
version
save = False
def filter_features( self, events: csi_images.csi_events.EventArray) -> tuple[csi_images.csi_events.EventArray, csi_images.csi_events.EventArray]:
120    def filter_features(self, events: EventArray) -> tuple[EventArray, EventArray]:
121        return events, EventArray()

Removes events from an event array based on feature values.

Parameters
  • events: an EventArray with populated features.
Returns

two EventArray objects: tuple[remaining, filtered]

class DummyClassifier(csi_analysis.pipelines.scan_pipeline.EventClassifier):
class DummyClassifier(scan_pipeline.EventClassifier):
    """Classifier that labels every event "dummy"."""

    def __init__(
        self,
        scan: Scan,
        version: str,
        save: bool = False,
    ):
        self.scan = scan
        self.version = version
        self.save = save

    def __repr__(self):
        # Format: "<ClassName>-<version>" (no stray closing parenthesis).
        return f"{self.__class__.__name__}-{self.version}"

    def classify_events(self, events: EventArray) -> EventArray:
        """
        Add a metadata column filled with "dummy" to the events.
        :param events: the events to classify; mutated via add_metadata.
        :return: the same EventArray.
        """
        # NOTE(review): the column name embeds the event count, so runs over
        # differently sized EventArrays write different columns — confirm
        # this is intentional rather than a typo for "model_classification".
        events.add_metadata(
            pd.DataFrame(
                {f"model_classification{len(events)}": ["dummy"] * len(events)}
            )
        )
        return events

Abstract class for an event classifier.

DummyClassifier(scan: csi_images.csi_scans.Scan, version: str, save: bool = False)
125    def __init__(
126        self,
127        scan: Scan,
128        version: str,
129        save: bool = False,
130    ):
131        self.scan = scan
132        self.version = version
133        self.save = save
scan
version
save = False
def classify_events( self, events: csi_images.csi_events.EventArray) -> csi_images.csi_events.EventArray:
138    def classify_events(self, events: EventArray) -> EventArray:
139        events.add_metadata(
140            pd.DataFrame(
141                {f"model_classification{len(events)}": ["dummy"] * len(events)}
142            )
143        )
144        return events

Classifies events based on features, then populates the metadata.

Parameters
  • events: an EventArray with populated features.
Returns

an EventArray with populated metadata.

def test_scan_pipeline():
def test_scan_pipeline():
    """
    Run the tiling pipeline end-to-end with the dummy modules above.

    The dummy segmenter labels a single region per tile and the dummy filters
    drop nothing, so presumably exactly one event per tile is expected.
    """
    scan = Scan.load_yaml("tests/data")
    log_options = {
        sys.stderr: {"level": "DEBUG", "colorize": True},
    }
    version = "2024-10-30"
    pipeline = scan_pipeline.TilingScanPipeline(
        scan,
        output_path="tests/data",
        preprocessors=[DummyPreprocessor(scan, version)],
        segmenters=[DummySegmenter(scan, version)],
        image_filters=[DummyImageFilter(scan, version)],
        feature_extractors=[DummyFeatureExtractor(scan, version)],
        tile_feature_filters=[DummyFeatureFilter(scan, version)],
        tile_event_classifiers=[DummyClassifier(scan, version)],
        scan_feature_filters=[DummyFeatureFilter(scan, version)],
        scan_event_classifiers=[DummyClassifier(scan, version)],
        max_workers=1,
        log_options=log_options,
    )
    events = pipeline.run()
    # The ROI tiling belongs to the loaded Scan, not to the scan_pipeline module.
    roi = scan.roi[0]
    assert len(events) == roi.tile_rows * roi.tile_cols