"""Image alignment pipeline and video frame ingestion.
``Image_preparation_pipeline`` builds a clean reference image for registration
from a sequence of frames by downsampling, masking with a VQT spectrogram to
keep only low-spectral-variance (non-behavior) frames, and applying CLAHE.
Also provides SFTP / local video frame extractors used to seed alignment.
"""
from typing import Dict, Optional
import os
import base64
import io
from PIL import Image
import numpy as np
import ffmpeg
import matplotlib.pyplot as plt
import cv2
from tqdm.auto import tqdm
from face_rhythm import rois
class Image_preparation_pipeline():
    """
    Builds a clean reference image for registration by downsampling frames,
    selecting frames with low spectral variance via a VQT spectrogram, and
    applying CLAHE contrast enhancement. RH 2023

    Args:
        ds_factor (int):
            Spatial downsampling factor applied before spectral analysis.
            (Default is ``20``)
        ptile_specVar_keep (float):
            Percentile cutoff for the per-time-bin mean spectral magnitude;
            time bins at or below this percentile are kept as
            **low-variance** (non-behavior) frames. (Default is ``10``)
        ptile_intensity_keep (float):
            Stored for backwards compatibility only; the current
            implementation does not use it. (Default is ``90``)
        params_vqt (Optional[Dict]):
            Keyword arguments forwarded to ``vqt.VQT`` for the spectral
            analysis. Recognized keys include ``Fs_sample`` (sample rate),
            ``Q_lowF``, ``Q_highF`` (quality factors at the low and high
            frequency bounds), ``F_min``, ``F_max`` (frequency range),
            ``n_freq_bins`` (number of frequency bins), ``window_type``,
            ``downsample_factor``, ``fft_conv`` (use FFT-based convolution),
            and ``plot_pref``. If ``None``, a fresh copy of
            ``_DEFAULT_PARAMS_VQT`` is used. (Default is ``None``)
        clip_limit (float):
            ``clipLimit`` argument forwarded to OpenCV CLAHE.
            (Default is ``2.0``)
        grid_size (int):
            Tile grid size forwarded to OpenCV CLAHE. (Default is ``20``)
        verbose (bool):
            If ``True``, prints progress messages and shows intermediate
            plots. (Default is ``True``)
    """
    # Template for the default VQT parameters. It is copied per instance so
    # that mutating one pipeline's ``params_vqt`` cannot leak into others
    # (a mutable default argument would be shared by every instance).
    _DEFAULT_PARAMS_VQT = {
        "Fs_sample": 250,
        "Q_lowF": 3.5,
        "Q_highF": 20,
        "F_min": 0.5,
        "F_max": 60,
        "n_freq_bins": 50,
        "window_type": 'hann',
        "downsample_factor": 10,
        "fft_conv": True,
        "plot_pref": False,
    }

    def __init__(
        self,
        ds_factor: int = 20,
        ptile_specVar_keep: float = 10,
        ptile_intensity_keep: float = 90,
        params_vqt: Optional[Dict] = None,
        clip_limit: float = 2.0,
        grid_size: int = 20,
        verbose: bool = True,
    ) -> None:
        """Initializes the pipeline and stores all preprocessing parameters."""
        self.ds_factor = ds_factor
        self.ptile_specVar_keep = ptile_specVar_keep
        self.ptile_intensity_keep = ptile_intensity_keep
        self.params_vqt = dict(self._DEFAULT_PARAMS_VQT) if params_vqt is None else params_vqt
        self.clip_limit = clip_limit
        self.grid_size = grid_size
        self.verbose = verbose

    def downsample(self, images: np.ndarray, ds_factor: Optional[int] = None) -> np.ndarray:
        """
        Spatially downsamples a stack of images by ``ds_factor`` using bilinear
        interpolation, collapsing a trailing color channel by mean.

        Args:
            images (np.ndarray):
                Input image stack. shape: *(n_frames, H, W)* or
                *(n_frames, H, W, C)*.
            ds_factor (Optional[int]):
                Integer downsampling factor. If ``None``, ``self.ds_factor``
                is used. (Default is ``None``)

        Returns:
            (np.ndarray):
                images_ds (np.ndarray):
                    Downsampled image stack.
                    shape: *(n_frames, H // ds_factor, W // ds_factor)*,
                    dtype: *float32*.
        """
        import torch
        ds_factor = self.ds_factor if ds_factor is None else ds_factor
        images_tensor = torch.as_tensor(images, dtype=torch.float32)
        if images_tensor.ndim == 4:
            images_tensor = images_tensor.mean(-1)
        n_frames, height, width = images_tensor.shape
        # interpolate expects (batch, channels, H, W): treat frames as channels
        # of a single batch element so all frames are resized in one call.
        images_ds = torch.nn.functional.interpolate(
            images_tensor[None, ...],
            size=(height // ds_factor, width // ds_factor),
            mode='bilinear',
            align_corners=False,
        )[0].numpy()
        if self.verbose:
            print(f"Images shape: {images.shape}")
            print(f"Images downsampled shape: {images_ds.shape}")
            plt.figure()
            plt.imshow(images_ds[0])
        return images_ds

    def find_low_spectral_variance_idx(
        self,
        images_ds: np.ndarray,
        images: np.ndarray,
        ptile_specVar_keep: Optional[float] = None,
        ptile_intensity_keep: Optional[float] = None,
        params_vqt: Optional[Dict] = None,
    ) -> np.ndarray:
        """
        Selects frames whose mean VQT spectral magnitude lies at or below the
        ``ptile_specVar_keep`` percentile and returns a normalized mean image
        over those frames for use as a registration reference.

        Args:
            images_ds (np.ndarray):
                Downsampled image stack used to compute spectrograms.
                shape: *(n_frames, H_ds, W_ds)*.
            images (np.ndarray):
                Full-resolution image stack used to build the reference image.
                shape: *(n_frames, H, W)* or *(n_frames, H, W, C)*.
            ptile_specVar_keep (Optional[float]):
                Percentile cutoff for the mean spectral magnitude; time bins
                at or below this percentile are kept. If ``None``,
                ``self.ptile_specVar_keep`` is used. (Default is ``None``)
            ptile_intensity_keep (Optional[float]):
                Unused; retained for backwards compatibility. If ``None``,
                ``self.ptile_intensity_keep`` is used. (Default is ``None``)
            params_vqt (Optional[Dict]):
                Keyword arguments forwarded to ``vqt.VQT``. If ``None``,
                ``self.params_vqt`` is used. (Default is ``None``)

        Returns:
            (np.ndarray):
                im (np.ndarray):
                    Square-rooted mean of the kept full-resolution frames,
                    divided by its maximum when that maximum is positive.
                    shape: *(H, W)*.
        """
        import vqt
        import torch
        ptile_specVar_keep = self.ptile_specVar_keep if ptile_specVar_keep is None else ptile_specVar_keep
        # Resolved only for API compatibility; no code path uses it.
        ptile_intensity_keep = self.ptile_intensity_keep if ptile_intensity_keep is None else ptile_intensity_keep
        params_vqt = self.params_vqt if params_vqt is None else params_vqt

        n_frames, h_ds, w_ds = images_ds.shape[0], images_ds.shape[1], images_ds.shape[2]
        my_vqt = vqt.VQT(**params_vqt)
        # One time series per downsampled pixel: (n_pixels, n_frames).
        traces = torch.as_tensor(images_ds, dtype=torch.float32).reshape(n_frames, -1).T
        spec = my_vqt.forward(traces)
        # spec dim 0 is the pixel axis (it is reshaped back to H_ds x W_ds);
        # dims 1 and 2 are presumably (frequency, time) — confirm against vqt.
        n_dim1, n_dim2 = spec.shape[1], spec.shape[2]
        spec = spec.permute(1, 2, 0).reshape(n_dim1, n_dim2, h_ds, w_ds)
        xAxis = my_vqt.get_xAxis(n_samples=n_frames)
        if self.verbose:
            ## plot a random spectrogram
            plt.figure()
            plt.imshow(spec[:, :, np.random.randint(0, spec.shape[-2]), np.random.randint(0, spec.shape[-1])], aspect='auto')
        vals = spec.numpy().mean((0, 2, 3))
        # `<=` matches the documented "at or below" contract and guarantees at
        # least one bin (the minimum) is kept, avoiding a NaN mean of nothing.
        idx = np.int64(xAxis[np.where(vals <= np.percentile(vals, ptile_specVar_keep))[0]])

        # Index first, then collapse color: identical result to averaging the
        # whole stack's channels, but only touches the selected frames.
        frames = images[idx]
        if frames.ndim == 4:
            frames = frames.mean(-1)
        im = frames.mean(0) ** 0.5
        peak = im.max()
        if peak > 0:
            im /= peak
        if self.verbose:
            plt.figure()
            plt.imshow(im)
        return im

    def apply_clahe(
        self,
        image: np.ndarray,
        clip_limit: Optional[float] = None,
        grid_size: Optional[int] = None,
    ) -> np.ndarray:
        """
        Applies CLAHE contrast enhancement to a single image via
        ``rois.Image_Aligner.augment_images``.

        Args:
            image (np.ndarray):
                Input image. shape: *(H, W)*.
            clip_limit (Optional[float]):
                CLAHE ``clipLimit`` argument. If ``None``, ``self.clip_limit``
                is used. (Default is ``None``)
            grid_size (Optional[int]):
                CLAHE tile grid size. If ``None``, ``self.grid_size`` is used.
                (Default is ``None``)

        Returns:
            (np.ndarray):
                im_aug (np.ndarray):
                    CLAHE-enhanced image (first element returned by
                    ``augment_images``).
        """
        import functools
        clip_limit = self.clip_limit if clip_limit is None else clip_limit
        grid_size = self.grid_size if grid_size is None else grid_size
        aligner = rois.Image_Aligner(verbose=self.verbose)
        augmenter = functools.partial(
            aligner.augment_images,
            use_CLAHE=True,
            CLAHE_grid_size=grid_size,
            CLAHE_clipLimit=clip_limit,
            CLAHE_normalize=True,
        )
        im_aug = augmenter(ims=[image])[0]
        if self.verbose:
            plt.figure()
            plt.imshow(im_aug)
        return im_aug

    def apply_pipeline(
        self,
        images: np.ndarray,
    ) -> np.ndarray:
        """
        Runs the full reference-image pipeline using the instance's stored
        parameters: downsample, select low-spectral-variance frames, then
        apply CLAHE.

        Args:
            images (np.ndarray):
                Input image stack. shape: *(n_frames, H, W)* or
                *(n_frames, H, W, C)*.

        Returns:
            (np.ndarray):
                im_aug (np.ndarray):
                    CLAHE-enhanced reference image.
        """
        if self.verbose:
            print("downsampling...")
        images_ds = self.downsample(images)
        if self.verbose:
            print("computing spectrograms...")
        im = self.find_low_spectral_variance_idx(images_ds, images)
        if self.verbose:
            print("applying CLAHE...")
        return self.apply_clahe(im)
def get_frames(path, time_start, time_end, verbose=False):
    """
    Reads a contiguous range of frames from a local video using OpenCV by
    seeking with ``cv2.CAP_PROP_POS_FRAMES``. Stops early if the requested
    range extends past EOF rather than raising.

    Args:
        path (str):
            Path to the local video file.
        time_start (float):
            Start time, in seconds, of the read window.
        time_end (float):
            End time, in seconds, of the read window. The number of frames
            requested is ``int((time_end - time_start) * fps)``.
        verbose (bool):
            If ``True``, prints the detected fps / sample range and displays
            a tqdm progress bar over the seek loop. (Default is ``False``)

    Returns:
        (np.ndarray):
            ims (np.ndarray):
                Decoded frames stacked along axis 0, with the dtype returned
                by ``cv2.VideoCapture.read`` (typically *uint8*).

    Raises:
        ValueError:
            If the video cannot be opened, or if no frames could be read in
            the requested interval.
    """
    vc = cv2.VideoCapture(path)
    try:
        if not vc.isOpened():
            raise ValueError(f"Could not open video file: {path}")
        fps = vc.get(cv2.CAP_PROP_FPS)
        sample_start = time_start * fps
        sample_end = time_end * fps
        num_frames = int(sample_end - sample_start)
        if verbose:
            print(f"Found fps: {fps}")
            print(f"sample_start: {sample_start}, sample_end: {sample_end}")
            print(f"num_frames: {num_frames}")
        frames = []
        for i in tqdm(range(num_frames), disable=not verbose):
            # Explicit seek per frame (kept from the original implementation).
            vc.set(cv2.CAP_PROP_POS_FRAMES, int(sample_start + i))
            ret, frame = vc.read()
            if not ret:
                # Requested range runs past EOF: return what was read.
                break
            frames.append(frame)
    finally:
        # Always free the decoder handle, even if reading raised.
        vc.release()
    if not frames:
        raise ValueError(
            f"No frames could be read from {path} in the interval "
            f"[{time_start}, {time_end}] s."
        )
    return np.stack(frames)