Source code for motion_correction.desktop.flim_aligner

from motion_correction.motion_correction import _flow_warp, calculate_correction
from motion_correction.algorithms import _CorrectionAlgorithm
import os
import sparse
import numpy as np
from tqdm import tqdm
from enum import Enum
from sparse import GCXS
from pathlib import Path
import matplotlib.pyplot as plt
from motion_correction.pqreader import load_ptfile
from .utility import join_path, save_sequence_images
from numba import njit

# np.warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)


class SimMetric(Enum):
    """Similarity metrics that can be reported for frame alignment.

    Each member's value is the string used to look the metric up in the
    ``"metrics"`` mapping returned by ``calculate_correction`` (see
    ``FlimAligner.apply_correction_intensity``).
    """

    # NOTE(review): the expansions of these abbreviations are not visible in
    # this file (presumably NCC = normalised cross-correlation, MSE = mean
    # squared error, ...) -- confirm against the metrics implementation.
    NCC = "ncc"
    SPM = "spm"
    MSE = "mse"
    NRM = "nrm"
    SSI = "ssi"
class FlimAligner:
    """
    A class for aligning and correcting FLIM (Fluorescence Lifetime Imaging
    Microscopy) data.

    Typical usage is ``set_methods`` -> ``apply_correction_intensity`` ->
    ``export_results`` -> ``apply_correction_flim``.

    Attributes:
        ptfile (str): The path to the input PT3/PTU file.
        transforms (numpy.ndarray): Transformations estimated from intensity
            frame alignment; the last axis is indexed by frame.
        global_method: The global alignment method handed to
            ``calculate_correction``.
        local_method: The local alignment method handed to
            ``calculate_correction``.
        channel (int): The channel index selected for intensity frame
            alignment.
        shape (tuple): The shape of the FLIM stack data.
        flim_dict: Sparse representation of the raw FLIM data, set when the
            loader returns a ``(data, shape)`` tuple.
        flim_frames (numpy.ndarray): FLIM intensity frames before correction.
        flim_frames_corrected (numpy.ndarray): Corrected FLIM intensity frames.
        curve_fit (numpy.ndarray): Decay-curve histogram data before
            correction.
        curve_fit_corrected (numpy.ndarray): Corrected decay-curve histogram
            data (float).
        curve_fit_corrected_int (numpy.ndarray): Corrected decay-curve
            histogram data as integers.
        save_dir (str): The directory where results will be saved.
        sim_metric (SimMetric): The similarity metric reported for alignment.
        old_sim: Similarity values before alignment for all intensity frames.
        new_sim: Similarity values after alignment for all intensity frames.
        meta (dict): Metadata associated with the FLIM data.

    Methods:
        set_methods(global_method=None, local_method=None):
            Set global and local alignment methods.
        set_sim_metric(sim=SimMetric.NCC):
            Set the similarity metric for alignment.
        set_channel(channel):
            Select the channel used for intensity frame alignment.
        get_intensity_stack(pt_file, is_raw=False):
            Load and prepare FLIM intensity data.
        apply_correction_intensity(ptfile, ref_frame_idx=0):
            Estimate and apply the correction on the intensity frames.
        export_results(save_dir=None):
            Export visualisations of the corrected results to files.
        apply_correction_flim():
            Apply the estimated correction to the raw FLIM data.
    """

    def __init__(self):
        """
        Initializes a FlimAligner object with default attribute values.
        """
        self.ptfile = None
        self.transforms = None
        self.global_method = None
        self.local_method = None
        self.channel = 0
        self.shape = None
        self.flim_dict = None
        self.flim_frames = None
        self.flim_frames_corrected = None
        self.curve_fit = None
        self.curve_fit_corrected = None
        self.curve_fit_corrected_int = None
        self.save_dir = None
        self.sim_metric = SimMetric.NCC
        self.old_sim = None
        self.new_sim = None
        self.meta = None

    @staticmethod
    def _default_save_dir():
        """Return the fallback output directory (``<cwd>/save_dir``)."""
        return os.path.join(os.getcwd(), "save_dir")

    def set_methods(
        self,
        global_method: _CorrectionAlgorithm | None = None,
        local_method: _CorrectionAlgorithm | None = None,
    ):
        """
        Set the global and local alignment methods.

        Arguments whose exact type is ``_CorrectionAlgorithm`` are validated
        against the expected ``algorithm_type`` and replaced by their
        ``.value``; anything else (including ``None``) is stored unchanged.

        Args:
            global_method: The global alignment method.
            local_method: The local alignment method.

        Raises:
            AssertionError: If a method's ``algorithm_type`` does not match
                the slot it is passed in.
        """
        # The exact-type check is kept deliberately: widening it to
        # isinstance() would also unwrap subclasses, which may not define
        # `.value`.
        if type(global_method) is _CorrectionAlgorithm:
            assert (
                global_method.algorithm_type == "global"
            ), "global_method must be a global algorithm"
            global_method = global_method.value
        if type(local_method) is _CorrectionAlgorithm:
            # Validate the *local* method here; checking `global_method`
            # would inspect the wrong (and possibly already unwrapped) object.
            assert (
                local_method.algorithm_type == "local"
            ), "local_method must be a local algorithm"
            local_method = local_method.value
        self.global_method = global_method
        self.local_method = local_method

    def set_sim_metric(self, sim=SimMetric.NCC):
        """
        Set the similarity metric for frame alignment.

        Args:
            sim (SimMetric): The similarity metric to use.

        Raises:
            ValueError: If ``sim`` is not a ``SimMetric`` member.
        """
        if not isinstance(sim, SimMetric):
            raise ValueError(f"sim must be a SimMetric member, got {sim!r}")
        self.sim_metric = sim

    def set_channel(self, channel: int):
        """
        Select the channel used for intensity frame alignment.

        Takes effect the next time ``get_intensity_stack`` is called.

        Args:
            channel (int): The channel index.
        """
        self.channel = channel

    def get_intensity_stack(self, pt_file, is_raw=False):
        """
        Load and prepare FLIM intensity data from a PT3/PTU file.

        If the loader returns a tuple it is stored as ``(flim_dict, shape)``;
        otherwise the selected channel is summed over its last axis to form
        the intensity frames.

        Args:
            pt_file (str): The path to the PT3/PTU file.
            is_raw (bool): Forwarded to ``load_ptfile``.
        """
        self.ptfile = pt_file
        data, self.meta = load_ptfile(self.ptfile, is_raw)
        if isinstance(data, tuple):
            self.flim_dict, self.shape = data[0], data[1]
            return
        # Collapse the last axis of the selected channel into intensities.
        self.flim_frames = (
            data[:, :, self.channel, :].sum(axis=-1).astype(np.int64)
        )
        del data  # release the full stack before any further allocation
        self.shape = self.flim_frames.shape

    def apply_correction_intensity(self, ptfile, ref_frame_idx=0):
        """
        Estimate the motion correction on the intensity frames and apply it.

        Populates ``transforms``, ``old_sim``, ``new_sim`` and
        ``flim_frames_corrected``.

        Args:
            ptfile (str): The path to the PT3/PTU file.
            ref_frame_idx (int): The index of the reference frame for
                alignment.

        Raises:
            FileNotFoundError: If ``ptfile`` cannot be found. The error is
                propagated rather than swallowed, because continuing would
                operate on stale or missing data.
            AssertionError: If ``ref_frame_idx`` is outside the frame range.
        """
        self.get_intensity_stack(ptfile, is_raw=False)
        assert 0 <= ref_frame_idx < self.shape[2]
        results = calculate_correction(
            self.flim_frames, ref_frame_idx, self.local_method, self.global_method
        )
        metric = results["metrics"][self.sim_metric.value]
        self.transforms = results["combined_transforms"]
        self.old_sim = metric["original"]
        self.new_sim = metric["corrected"]
        self.flim_frames_corrected = results["corrected_intensity_data_stack"]

    def export_results(self, save_dir=None):
        """
        Export visualisations of the corrected results to files.

        Writes the original and aligned frame sequences, a side-by-side image
        of the summed intensities, and a plot of the similarity metric before
        and after alignment.

        Args:
            save_dir (str): The directory where results will be saved.
                Defaults to ``<cwd>/save_dir``.
        """
        self.save_dir = self._default_save_dir() if save_dir is None else save_dir
        Path(self.save_dir).mkdir(parents=True, exist_ok=True)
        name = os.path.basename(self.ptfile)

        # Frame sequences: move the frame axis to the front for the writer.
        for suffix, frames in (
            ("original", self.flim_frames),
            ("aligned", self.flim_frames_corrected),
        ):
            save_sequence_images(
                join_path(self.save_dir, f"{name}_{suffix}.mp4"),
                np.moveaxis(frames, [0, 1, 2], [1, 2, 0]),
            )

        # Summed intensity images, original next to aligned.
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))
        panels = (
            ("original", self.flim_frames),
            ("aligned", self.flim_frames_corrected),
        )
        for ax, (title, frames) in zip(axes, panels):
            image = ax.imshow(np.sum(frames, axis=2))
            plt.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
            ax.set_title(title)
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        fig.savefig(
            os.path.join(self.save_dir, f"{name}_intensity_images.svg"), format="svg"
        )
        plt.close(fig)

        # Similarity metric per frame, before and after alignment.
        plt.plot(self.old_sim, linewidth=2, label="Original")
        plt.plot(self.new_sim, linewidth=2, label="Aligned")
        plt.ylabel(self.sim_metric)
        plt.xlabel("Frame")
        plt.legend(loc="best")
        plt.tight_layout()
        plt.savefig(os.path.join(self.save_dir, f"{name}_sim_plot.svg"), format="svg")
        plt.close()
        print(f"Visualization results have been saved to {self.save_dir}")

    def apply_correction_flim(self):
        """
        Apply the estimated correction to the raw FLIM data and export it.

        Reloads the raw data from ``self.ptfile``, warps every channel of
        every frame with the transforms estimated by
        ``apply_correction_intensity``, accumulates the decay histograms
        (``curve_fit``, ``curve_fit_corrected``, ``curve_fit_corrected_int``)
        and writes the re-streamed records to ``<name>_corrected.pt3`` in
        ``self.save_dir`` (``<cwd>/save_dir`` if none has been set yet).

        Raises:
            RuntimeError: If ``apply_correction_intensity`` has not been run.
        """
        if self.ptfile is None or self.transforms is None or self.meta is None:
            raise RuntimeError(
                "apply_correction_intensity() must be called before "
                "apply_correction_flim()"
            )
        if self.save_dir is None:
            self.save_dir = self._default_save_dir()
        Path(self.save_dir).mkdir(parents=True, exist_ok=True)

        tuple_data, _ = load_ptfile(
            self.ptfile, is_raw=True
        )  # shape: H x W x C x F x nanotime
        flim_data_dict, shape = tuple_data
        # Rows 0-4 hold the coordinates, row 5 the values of the sparse stack.
        coords = flim_data_dict[:5, :]
        data = flim_data_dict[5, :]
        coo = sparse.COO(coords, data, shape)
        gcxs = GCXS.from_coo(coo, compressed_axes=[2, 3, 4])
        del flim_data_dict

        num_rows, num_cols, num_channels, num_frames, num_nanotimes = shape
        hist_shape = (num_nanotimes, num_rows, num_cols)
        self.curve_fit = np.zeros(hist_shape, dtype=np.uint16)
        self.curve_fit_corrected = np.zeros(hist_shape, dtype=np.float32)
        self.curve_fit_corrected_int = np.zeros(hist_shape, dtype=np.uint16)

        imghdr = self.meta["imghdr"]
        header_variables = np.array(
            [imghdr[1], imghdr[6], imghdr[7], imghdr[3], imghdr[4], imghdr[2]],
            dtype=np.uint64,
        )
        line_start_bit = header_variables[3]
        line_stop_bit = header_variables[4]
        frame_bit = header_variables[5]
        # The header stores 1-based bit positions; convert them to bit masks.
        line_start_marker = 2 ** (line_start_bit - 1)
        line_stop_marker = 2 ** (line_stop_bit - 1)
        frame_marker = 2 ** (frame_bit - 1)

        # Collect per-frame record chunks and concatenate once at the end,
        # rather than re-copying ever-growing arrays on every frame. Seeding
        # with an empty uint32 array keeps the result dtype (and the
        # zero-frame case) identical to incremental concatenation.
        sync_chunks = [np.array([], dtype=np.uint32)]
        chan_chunks = [np.array([], dtype=np.uint32)]
        tcspc_chunks = [np.array([], dtype=np.uint32)]
        ts_index = 0
        current_ts = 0
        corrected_frames = np.zeros(
            (num_rows, num_cols, num_channels, num_nanotimes), dtype=np.uint16
        )
        for frame_idx in tqdm(range(num_frames), desc="Aligning raw data"):
            flow = self.transforms[:, :, :, frame_idx]
            for ch in range(num_channels):
                frame = gcxs[:, :, ch, frame_idx, :].todense().astype(np.float32)
                # NOTE(review): transpose(2, 1, 0) yields nanotime x W x H,
                # while the warped output below is documented as Z x H x W --
                # confirm the row/column order is intended.
                self.curve_fit += frame.astype(np.uint16).transpose(2, 1, 0)
                warped, warped_int = _flow_warp(frame, flow)  # Z x H x W
                self.curve_fit_corrected += warped
                self.curve_fit_corrected_int += warped_int
                corrected_frames[:, :, ch, :] = np.moveaxis(
                    warped_int, [0, 1, 2], [2, 0, 1]
                )
            sync, chan, tcspc, current_ts, ts_index = _stream_one_frame(
                corrected_frames,
                line_start_marker,
                line_stop_marker,
                frame_marker,
                current_ts,
                ts_index,
            )
            sync_chunks.append(sync)
            chan_chunks.append(chan)
            tcspc_chunks.append(tcspc)
        timestamps = np.concatenate(sync_chunks, axis=0)
        detectors = np.concatenate(chan_chunks, axis=0)
        nanotimes = np.concatenate(tcspc_chunks, axis=0)

        # save stream data to pt3 file
        # Record layout: [channel | nanotime (12 bit) | timestamp (16 bit)].
        time_bit = 16
        dtime_bit = 12
        t3records = (
            np.left_shift(detectors.astype(np.uint32), time_bit + dtime_bit)
            | np.left_shift(nanotimes.astype(np.uint32), time_bit)
            | timestamps.astype(np.uint16)  # wraps; overflow records mark the wrap
        )
        stem = os.path.splitext(os.path.basename(self.ptfile))[0]
        filename = os.path.join(self.save_dir, stem + "_corrected.pt3")
        header_sections = (
            "header",
            "dispcurve",
            "params",
            "repeatgroup",
            "hardware",
            "router",
            "ttmode",
        )
        with open(filename, "wb") as f:
            for section in header_sections:
                f.write(np.array(self.meta[section]).tobytes())
            f.write(np.array(self.meta["imghdr"]).tobytes())
            f.write(t3records.astype(np.uint32).tobytes())
        print(f"Data exported to {filename}")
@njit
def _stream_one_frame(
    corrected_frames,
    LineStartMarker,
    LineStopMarker,
    FrameMarker,
    current_ts=0,
    ts_index=0,
):
    """Serialise one corrected frame into sync / channel / tcspc record arrays.

    For every row a line-start marker is emitted, followed by the photon
    records of each channel (one record per count, ``chan = ch + 1``,
    ``tcspc`` = nanotime bin) and a line-stop marker; a frame marker ends the
    frame. Marker and overflow records use ``chan = 15``; overflow records
    carry ``tcspc = 0`` and are emitted whenever the running timestamp
    crosses the next multiple of 65536.

    Args:
        corrected_frames: Integer counts indexed as
            ``[row, col, channel, nanotime]``.
        LineStartMarker: ``tcspc`` value written for line-start records.
        LineStopMarker: ``tcspc`` value written for line-stop records.
        FrameMarker: ``tcspc`` value written for the frame record.
        current_ts: Running timestamp at the start of this frame.
        ts_index: Number of 65536-wide timestamp windows already passed.

    Returns:
        ``(sync, chan, tcspc, current_ts, ts_index)``: the three ``uint32``
        record arrays trimmed to the number of records written, plus the
        updated timestamp state to pass to the next frame.
    """
    num_rows, num_cols, num_channels, num_nanotimes = corrected_frames.shape

    # Size the buffers for the worst case. Besides one record per photon, two
    # markers per row and one frame marker, every overflow record needs a
    # slot. Each overflow raises ts_index by one and only fires while
    # ts >= 65536 * (1 + ts_index), and ts never exceeds
    # current_ts + num_rows * num_cols * num_channels, which gives the bound
    # below. Compiled code does not bounds-check, so an undersized buffer
    # (e.g. a frame with almost no photons) would silently corrupt memory.
    max_overflows = max(
        0, (current_ts + num_rows * num_cols * num_channels) // 65536 - ts_index
    )
    total_entries = int(
        corrected_frames.sum() + 2 * num_rows + 1 + max_overflows
    )

    # Initialize arrays
    sync = np.zeros(total_entries, dtype=np.uint32)
    chan = np.zeros(total_entries, dtype=np.uint32)
    tcspc = np.zeros(total_entries, dtype=np.uint32)
    idx = 0
    for r in range(num_rows):
        # LineStartMarker
        sync[idx] = current_ts
        chan[idx] = 15
        tcspc[idx] = LineStartMarker
        idx += 1
        for ch in range(num_channels):
            for c in range(num_cols):
                # Every channel reuses the same per-pixel timestamps.
                ts = current_ts + c
                if ts >= 65536 * (1 + ts_index):
                    # Overflow record: timestamp entered the next window.
                    ts_index += 1
                    sync[idx] = ts
                    chan[idx] = 15
                    tcspc[idx] = 0
                    idx += 1
                # Total photon count of this pixel (zero writes nothing).
                num_non_zeros = np.sum(corrected_frames[r, c, ch, :])
                sync[idx : idx + num_non_zeros] = np.repeat(ts, num_non_zeros)
                chan[idx : idx + num_non_zeros] = np.repeat(ch + 1, num_non_zeros)
                # Expand the histogram: bin k repeated counts[k] times.
                tcspc[idx : idx + num_non_zeros] = np.repeat(
                    np.arange(num_nanotimes), corrected_frames[r, c, ch, :]
                )
                idx += num_non_zeros
        current_ts += num_cols
        if current_ts >= 65536 * (1 + ts_index):
            ts_index += 1
            sync[idx] = current_ts
            chan[idx] = 15
            tcspc[idx] = 0
            idx += 1
        # LineStopMarker
        sync[idx] = current_ts
        chan[idx] = 15
        tcspc[idx] = LineStopMarker
        idx += 1
    # FrameMarker
    sync[idx] = current_ts
    chan[idx] = 15
    tcspc[idx] = FrameMarker
    idx += 1
    # Drop the unused tail of the worst-case-sized buffers.
    sync = sync[:idx]
    chan = chan[:idx]
    tcspc = tcspc[:idx]
    return sync, chan, tcspc, current_ts, ts_index