"""Spectral analysis of point-tracked motion via Variable-Q Transform (VQT).
``VQT_Analyzer`` converts point trajectories (from :mod:`face_rhythm.point_tracking`)
into per-point spectrograms with a Variable-Q transform, applies 1/f and
per-timepoint normalization, and writes the resulting complex or magnitude
tensors out for downstream :mod:`face_rhythm.decomposition`.
"""
from typing import Union
from pathlib import Path
import math
import numpy as np
import torch
from tqdm.auto import tqdm
from .util import FR_Module
from . import helpers
[docs]
class VQT_Analyzer(FR_Module):
"""
Computes normalized Variable-Q Transform (VQT) spectrograms for point
displacement traces. RH 2022
Args:
params_VQT (dict):
Keyword arguments forwarded to ``vqt.VQT`` (the Variable Q-Transform
implementation in :mod:`vqt`). Notable keys include ``Fs_sample``
(sampling rate in Hz), ``Q_lowF`` and ``Q_highF`` (Q-factors at the
low and high frequency bounds), ``F_min`` and ``F_max`` (frequency
range in Hz), ``n_freq_bins``, ``window_type``, ``downsample_factor``,
and ``take_abs``. (Default is the dict shown in the signature)
batch_size (int):
Number of points processed per VQT batch. (Default is ``10``)
device (str):
Torch device on which the VQT model is run (e.g. ``'cpu'`` or
``'cuda'``). (Default is ``'cpu'``)
normalization_factor (float):
Strength of the per-timepoint power normalization, in the range
``[0, 1]``. ``0`` disables normalization; ``1`` forces every time
point to have equal total power. (Default is ``0.99``)
spectrogram_exponent (float):
Exponent applied to the spectrogram magnitudes prior to
normalization. (Default is ``1.0``)
one_over_f_exponent (float):
Exponent for the 1/f correction; the spectrogram is multiplied by
``freqs ** one_over_f_exponent``. ``0`` disables the correction.
(Default is ``1.0``)
verbose (int):
Verbosity level. ``0`` is silent, higher values print and show
progress bars. (Default is ``1``)
Attributes:
spectrograms (dict):
Dict mapping each input key to its normalized spectrogram array.
Populated by :meth:`transform_all`.
x_axis (dict):
Dict mapping each input key to the time axis (in samples) of its
spectrogram. Populated by :meth:`transform_all`.
freqs (dict):
Dict mapping each input key to the frequency bin centers (in Hz).
Populated by :meth:`transform_all`.
point_positions (torch.Tensor):
Reshaped reference positions used to subtract offsets from traces.
vqt_model (vqt.VQT):
The underlying VQT filter-bank model.
config (dict):
Constructor arguments echoed for ``FR_Module`` serialization.
run_data (dict):
Output payload (filters, frequencies, spectrograms, axes) used by
``FR_Module`` for export.
"""
def __init__(
    self,
    params_VQT: dict={
        'Fs_sample': 90,
        'Q_lowF': 3,
        'Q_highF': 20,
        'F_min': 1,
        'F_max': 40,
        'n_freq_bins': 55,
        'window_type': 'hann',
        'symmetry': 'center',
        'taper_asymmetric': True,
        'downsample_factor': 8,
        'padding': 'valid',
        'fft_conv': True,
        'fast_length': True,
        'take_abs': True,
        'filters': None,
        'plot_pref': False,
    },
    batch_size: int=10,
    device: str='cpu',
    normalization_factor: float=0.99,
    spectrogram_exponent: float=1.0,
    one_over_f_exponent: float=1.0,
    verbose: int=1,
):
    """
    Initializes the analyzer, builds the VQT filter bank, and stores
    config metadata for ``FR_Module`` serialization.

    Args:
        params_VQT (dict):
            Keyword arguments unpacked into ``vqt.VQT``. A shallow copy is
            taken, so neither the caller's dict nor the shared signature
            default is ever held by reference on the instance.
        batch_size (int):
            Stored (cast to ``int``) as ``self._batch_size``.
        device (str):
            Stored as ``self._device``. The VQT model itself is placed on
            the CPU by this constructor.
        normalization_factor (float):
            Stored (cast to ``float``) as ``self._normalization_factor``.
        spectrogram_exponent (float):
            Stored (cast to ``float``) as ``self._spectrogram_exponent``.
        one_over_f_exponent (float):
            Stored (cast to ``float``) as ``self._one_over_f_exponent``.
        verbose (int):
            Stored (cast to ``int``) as ``self._verbose``.
    """
    super().__init__()

    ## The signature default is a single dict object shared by every call.
    ## Keep a private shallow copy so that editing ``self._params_VQT`` or
    ## ``self.config['params_VQT']`` on one instance cannot change the
    ## defaults seen by instances created later.
    params_VQT = dict(params_VQT)

    ## Set attributes
    self._params_VQT = params_VQT
    self._batch_size = int(batch_size)
    self._device = device
    self._normalization_factor = float(normalization_factor)
    self._spectrogram_exponent = float(spectrogram_exponent)
    self._one_over_f_exponent = float(one_over_f_exponent)
    self._verbose = int(verbose)

    ## Result containers; ``None`` until they are populated
    self.spectrograms = None
    self.x_axis = None
    self.freqs = None
    self._demo_spectrogram = None

    ## Initialize VQT filters (imported lazily, only when an analyzer is built)
    import vqt
    self.vqt_model = vqt.VQT(**params_VQT)
    self.vqt_model.cpu()

    ## For FR_Module compatibility
    self.config = {
        'params_VQT': params_VQT,
        'batch_size': batch_size,
        'device': device,
        'normalization_factor': normalization_factor,
        'spectrogram_exponent': spectrogram_exponent,
        'one_over_f_exponent': one_over_f_exponent,
        'verbose': verbose,
    }
    self.run_info = {
    }
    self.run_data = {
        ## Filter bank tensors exported as numpy arrays
        'VQT': {key: getattr(self.vqt_model, key).cpu().detach().numpy() for key in ['filters', 'wins']},
        'frequencies': self.vqt_model.freqs,
    }
## Append the self.run_info data to self.run_data
# self.run_data.update(self.run_info)
# self.run_data['config'] = self.config
[docs]
def cleanup(self):
"""
Deletes every attribute on the instance and triggers garbage collection
to free large tensors held by the analyzer.
"""
import gc
print(f"FR: Deleting all attributes")
while len(self.__dict__.keys()) > 0:
key = list(self.__dict__.keys())[0]
del self.__dict__[key]
gc.collect()
gc.collect()
def _check_inputs(self, points_tracked: dict, point_positions: np.ndarray):
"""
Validates the structure, dtype, and shape of the inputs to
:meth:`transform_all` and :meth:`demo_transform`.
Args:
points_tracked (dict):
Mapping from a name to a tracked-points array of shape
*(n_frames, n_points, 2)*.
point_positions (np.ndarray):
Reference positions of the tracked points. shape:
*(n_points, 2)*.
"""
## Assertions
### Assert that the points_tracked dict is valid
assert isinstance(points_tracked, (dict,)), f"points_tracked must be a dict, not {type(points_tracked)}. See docstring for details."
### Assert that points_tracked contains 3D numpy arrays of shape(n_frames, n_points, 2)
for key, value in points_tracked.items():
assert isinstance(value, np.ndarray), f"points_tracked must contain numpy arrays, not {type(value)}. See docstring for details."
assert value.ndim == 3, f"points_tracked must contain 3D numpy arrays, not {value.ndim}D. Shape should be (n_frames, n_points, 2). See docstring for details."
assert value.shape[2] == 2, f"points_tracked must contain 3D numpy arrays of shape(n_frames, n_points, 2), not {value.shape}. See docstring for details."
### Assert that point_positions is a 2D numpy array of shape(n_points, 2)
assert isinstance(point_positions, np.ndarray), f"point_positions must be a numpy array, not {type(point_positions)}. See docstring for details."
assert point_positions.ndim == 2, f"point_positions must be a 2D numpy array, not {point_positions.ndim}D. Shape should be (n_points, 2). See docstring for details."
assert point_positions.shape[1] == 2, f"point_positions must be a 2D numpy array of shape(n_points, 2), not {point_positions.shape}. See docstring for details."
## Prepare normalization function
def _normalize_spectrogram(self, spectrogram):
"""
Normalizes a spectrogram by dividing each time point by the mean total
power across frequencies and points, then applies a 1/f correction.
Args:
spectrogram (torch.Tensor):
Raw VQT output. shape: *(n_points * 2, n_freq_bins, n_frames)*.
May be real (when ``take_abs=True``) or complex.
Returns:
(torch.Tensor):
spectrogram (torch.Tensor):
Normalized spectrogram split into x/y components. shape:
*(2, n_points, n_freq_bins, n_frames)*. Same dtype family as
the input (real or complex).
"""
## Check inputs
s = spectrogram
if torch.is_complex(s) == False:
s_exp = s ** self._spectrogram_exponent
s_mean = torch.mean(torch.sum(s_exp , dim=1) , dim=0) ## Mean of the summed power across all frequencies and points. Shape (n_frames,)
s_norm = s_exp / ((self._normalization_factor * s_mean[None,None,:]) + (1-self._normalization_factor)) ## Normalize the spectrogram by the mean power across all frequencies and points. Shape (n_points, n_freq_bins, n_frames)
elif torch.is_complex(s) == True:
s_mag = torch.abs(s)
s_phase = torch.angle(s)
s_exp = s_mag ** self._spectrogram_exponent
s_mean = torch.mean(torch.sum(s_exp , dim=1) , dim=0) ## Mean of the summed power across all frequencies and points. Shape (n_frames,)
s_mag_norm = s_exp / ((self._normalization_factor * s_mean[None,None,:]) + (1-self._normalization_factor)) ## Normalize the spectrogram by the mean power across all frequencies and points. Shape (n_points, n_freq_bins, n_frames)
s_norm = torch.polar(s_mag_norm, s_phase)
## Do 1/f correction
s_norm = s_norm * torch.as_tensor(self.vqt_model.freqs[None,:,None] ** self._one_over_f_exponent, dtype=torch.float32) if self._one_over_f_exponent != 0 else s_norm
s_norm_rs = s_norm.reshape(2, int(s_norm.shape[0]/2), s_norm.shape[1], s_norm.shape[2])
return s_norm_rs
def _prepare_displacements(self, traces, point_positions):
"""
Reshapes tracked-point traces from *(n_frames, n_points, 2)* to
*(n_points * 2, n_frames)* (Fortran order) and subtracts the reference
positions to yield displacement traces.
Args:
traces (np.ndarray):
Tracked point coordinates. shape: *(n_frames, n_points, 2)*.
point_positions (torch.Tensor):
Flattened reference positions. shape: *(n_points * 2,)*.
Returns:
(torch.Tensor):
displacements (torch.Tensor):
Per-component displacement traces. shape:
*(n_points * 2, n_frames)*, dtype: *float32*.
"""
out = torch.as_tensor(traces.reshape((traces.shape[0], -1), order='F').T, dtype=torch.float32) - point_positions[:, None]
return out
def _prepare_pointPositions(self, point_positions):
"""
Flattens reference point positions from *(n_points, 2)* to
*(n_points * 2,)* using Fortran order.
Args:
point_positions (np.ndarray):
Reference positions. shape: *(n_points, 2)*.
Returns:
(torch.Tensor):
point_positions (torch.Tensor):
Flattened reference positions. shape: *(n_points * 2,)*,
dtype: *float32*.
"""
return torch.as_tensor(point_positions.reshape((-1,), order='F').T, dtype=torch.float32)
def __repr__(self):
"""Returns a concise string representation of the analyzer."""
return f"{self.__class__.__name__}( normalization_factor={self._normalization_factor}, spectrogram_exponent={self._spectrogram_exponent}, VQT={self.vqt_model}, verbose={self._verbose} )"
def __getitem__(self, index):
"""Indexes into ``self.spectrograms`` by key."""
return self.spectrograms(index)
def __len__(self):
"""Returns the number of stored spectrogram entries."""
return len(self.spectrograms)
def __iter__(self):
"""Iterates over ``(key, spectrogram)`` pairs in ``self.spectrograms``."""
return iter(self.spectrograms.items())
def __call__(self, points_tracked: dict, point_positions: np.ndarray, name_points: str='0'):
"""
Computes spectrograms for one entry of ``points_tracked``. Thin wrapper
around :meth:`transform`; see that method for details.
Args:
points_tracked (dict):
Mapping from a name to a tracked-points array of shape
*(n_frames, n_points, 2)*.
point_positions (np.ndarray):
Reference positions of the tracked points. shape:
*(n_points, 2)*.
name_points (str):
Key into ``points_tracked`` selecting which array to transform.
(Default is ``'0'``)
"""
return self.transform(points_tracked[name_points], point_positions)