"""General-purpose helpers: video I/O wrappers, path tools, image warping, downloads.
Collected utilities used across the face-rhythm package. Notable groups:
* Video readers (``VideoReaderWrapper``, ``BufferedVideoReader``) around
``decord`` / ``torchcodec`` with pre-fetch threads.
* Path and file helpers (``find_paths``, ``prepare_filepath_for_saving``,
download + hash verification, zip extraction).
* Image registration helpers (``find_geometric_transformation``, remap-index
and flow-field conversions) used by :mod:`face_rhythm.rois`.
* Parameter dictionary utilities (``fill_missing_keys_with_defaults``,
``flatten_dict``) and a handful of numerical / plotting / device utilities.
Some routines are adapted from Rich Hakim's ``basic_neural_processing_modules``.
"""
import multiprocessing as mp
import threading
from typing import Union
import time
import gc
import json
from pathlib import Path
import copy
import re
from typing import List, Optional, Tuple, Union, Dict, Any, Callable, MutableMapping
import os
import sys
from functools import partial
import warnings
import numpy as np
import cv2
try:
import decord
except ImportError:
decord = None
import torch
from tqdm.auto import tqdm
import yaml
import zipfile
import pickle
import scipy
import scipy.sparse
import scipy.signal
import tensorly as tl
[docs]
def prepare_cv2_imshow():
    """
    Primes the OpenCV display loop by briefly showing a dummy frame. RH 2022

    ``cv2.imshow`` can crash the Python kernel when it is first called after
    ``av`` or ``decord`` have been imported. This function opens a throwaway
    window named ``'startup'``, keeps it up for about one second
    (``cv2.waitKey(1000)``) and then destroys it, so that subsequent
    ``cv2.imshow`` calls work safely.
    """
    import numpy as np
    import cv2

    ## (text, bottom-left origin, font scale, line thickness) per caption
    captions = (
        ("WELCOME TO FACE RHYTHM!", (10, 50), 0.7, 2),
        ("Prepping CV2", (10, 100), 0.7, 2),
        ("Calling this figure allows cv2.imshow ", (10, 150), 0.5, 1),
        ("to work without crashing if this function", (10, 170), 0.5, 1),
        ("is called before importing av and decord", (10, 190), 0.5, 1),
    )

    ## Black float64 canvas, height=300, width=400, 3 channels
    canvas = np.zeros((300, 400, 3))
    for text, origin, scale, thickness in captions:
        cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale, (255, 255, 255), thickness)

    cv2.imshow('startup', canvas)
    cv2.waitKey(1000)
    cv2.destroyWindow('startup')
#############################################################################################################################################################################
################################################################################### BNPM ####################################################################################
######################################################################## EVERYTHING BELOW IS FROM THE #######################################################################
############################################################# BASIC NEURAL PROCESSING MODULES (BNPM) REPOSITORY #############################################################
####################################################### https://github.com/RichieHakim/basic_neural_processing_modules ######################################################
#############################################################################################################################################################################
#####################################################################################################################################
########################################################### PATH HELPERS ############################################################
#####################################################################################################################################
[docs]
def find_paths(
    dir_outer: Union[str, List[str]],
    reMatch: str = 'filename',
    reMatch_in_path: Optional[str] = None,
    find_files: bool = True,
    find_folders: bool = False,
    depth: int = 0,
    natsorted: bool = True,
    alg_ns: Optional[str] = None,
    verbose: bool = False,
) -> List[str]:
    """
    Searches for files and/or folders recursively in a directory using a regex
    match.
    RH 2022-2023

    Args:
        dir_outer (Union[str, List[str]]):
            Path(s) to the directory(ies) to search. If a list of directories,
            then all directories will be searched.
        reMatch (str):
            Regular expression to match. Each file or folder name encountered
            will be compared using ``re.search(reMatch, filename)``. If the
            output is not ``None``, the file will be included in the output.
            If ``None``, no name matching is done.
        reMatch_in_path (Optional[str]):
            Additional regular expression to match anywhere in the upper path
            (``os.path.dirname`` of each candidate). Useful for finding
            files/folders in specific subdirectories. If ``None``, then no
            additional matching is done. \n
            (Default is ``None``)
        find_files (bool):
            Whether to find files. (Default is ``True``)
        find_folders (bool):
            Whether to find folders. (Default is ``False``)
        depth (int):
            Maximum folder depth to search. (Default is *0*). \n
                * depth=0 means only search the outer directory.
                * depth=2 means search the outer directory and two levels of
                  subdirectories below it
        natsorted (bool):
            Whether to sort the output using natural sorting with the natsort
            package. The ``natsort`` package is only imported when this is
            ``True``. (Default is ``True``)
        alg_ns (str):
            Algorithm to use for natural sorting. See ``natsort.ns`` or
            https://natsort.readthedocs.io/en/4.0.4/ns_class.html/ for options.
            Default is PATH. Other commons are INT, FLOAT, VERSION. (Default is
            ``None``)
        verbose (bool):
            Whether to print the paths found. (Default is ``False``)

    Returns:
        (List[str]):
            paths (List[str]):
                Paths to matched files and/or folders in the directory.
                Duplicates are removed. When ``natsorted`` is ``False`` the
                paths are returned in discovery order.
    """
    def fn_match(path):
        ## True if reMatch hits the basename AND reMatch_in_path hits the dirname
        if (reMatch is not None) and (re.search(reMatch, os.path.basename(path)) is None):
            return False
        if (reMatch_in_path is not None) and (re.search(reMatch_in_path, os.path.dirname(path)) is None):
            return False
        return True

    def get_paths_recursive_inner(dir_inner, depth_end, depth=0):
        paths = []
        for name in os.listdir(dir_inner):
            path = os.path.join(dir_inner, name)
            if os.path.isdir(path):
                if find_folders and fn_match(path):
                    if verbose:
                        print(f'Found folder: {path}')
                    paths.append(path)
                ## Descend regardless of whether the folder itself matched
                if depth < depth_end:
                    paths += get_paths_recursive_inner(path, depth_end, depth=depth + 1)
            elif find_files and fn_match(path):
                if verbose:
                    print(f'Found file: {path}')
                paths.append(path)
        return paths

    ## Types treated as "a single path" rather than "a collection of paths"
    types_pathLike = (
        str,
        Path,
        os.PathLike,
        np.str_,
        bytes,
        memoryview,
        np.bytes_,
        re.Pattern,
        re.Match,
    )
    if isinstance(dir_outer, types_pathLike):
        dir_outer = [dir_outer]

    ## dict.fromkeys de-duplicates while keeping a deterministic (discovery) order
    paths = list(dict.fromkeys(
        p for d in dir_outer for p in get_paths_recursive_inner(str(d), depth, depth=0)
    ))

    if natsorted:
        ## Imported lazily so that natsort is only required when sorting is requested
        import natsort
        if alg_ns is None:
            alg_ns = natsort.ns.PATH
        paths = natsort.natsorted(paths, alg=alg_ns)
    return paths
#####################################################################################################################################
########################################################### FILE HELPERS ############################################################
#####################################################################################################################################
[docs]
def prepare_path(
    path: str,
    mkdir: bool = False,
    exist_ok: bool = True,
) -> str:
    """
    Checks a file or directory path before saving or loading. RH 2023

    Behaviour: \n
        * Existing path: accepted when ``exist_ok`` is ``True``; otherwise an
          ``AssertionError`` is raised.
        * Missing path with a suffix (treated as a file): accepted if the
          parent directory exists; otherwise the parent is created when
          ``mkdir`` is ``True`` and an ``AssertionError`` is raised when it
          is ``False``.
        * Missing path without a suffix (treated as a directory): created
          when ``mkdir`` is ``True``; otherwise an ``AssertionError`` is
          raised. \n

    Args:
        path (str):
            Path to be checked.
        mkdir (bool):
            If ``True``, creates the missing directory (or the missing parent
            directory of a file path). (Default is ``False``)
        exist_ok (bool):
            If ``True``, an already-existing path is accepted. (Default is
            ``True``)

    Returns:
        (str):
            path (str):
                Resolved absolute path.
    """
    try:
        resolved = Path(path).resolve()
    except FileNotFoundError as err:
        print(f'Invalid path: {path}')
        raise err

    ## Case 1: the path already exists
    if resolved.exists():
        if resolved.is_dir():
            kind = 'dir'
        elif resolved.is_file():
            kind = 'file'
        else:
            kind = 'neither'
        assert kind != 'neither', f'Path: {path} is neither a file nor a directory.'
        assert exist_ok, f'{path} already exists and exist_ok=False.'
        return str(resolved)

    ## Case 2: the path does not exist. The suffix decides file vs. directory.
    raw = Path(path)
    looks_like_file = (resolved.suffix != '')

    if looks_like_file:
        parent = raw.parent
        if not parent.exists():
            if mkdir:
                parent.mkdir(parents=True, exist_ok=True)
            else:
                assert False, f'File: {path} does not exist, Parent directory: {Path(path).parent} does not exist, and mkdir=False.'
    else:
        if mkdir:
            raw.mkdir(parents=True, exist_ok=True)
        else:
            assert False, f'{path} does not exist and mkdir=False.'

    return str(resolved)
[docs]
def prepare_filepath_for_saving(
    filepath: str,
    mkdir: bool = False,
    allow_overwrite: bool = True
) -> str:
    """
    Validates a file path that is about to be written to.

    Thin wrapper around ``prepare_path`` in which ``allow_overwrite`` plays
    the role of ``exist_ok``.

    Args:
        filepath (str):
            The file path to be prepared for saving.
        mkdir (bool):
            If ``True``, creates the parent directory if it does not exist.
            (Default is ``False``)
        allow_overwrite (bool):
            If ``True``, an already-existing file at ``filepath`` is
            accepted. (Default is ``True``)

    Returns:
        (str):
            path (str):
                The resolved file path for saving.
    """
    path_resolved = prepare_path(
        filepath,
        mkdir=mkdir,
        exist_ok=allow_overwrite,
    )
    return path_resolved
[docs]
def prepare_filepath_for_loading(
    filepath: str,
    must_exist: bool = True
) -> str:
    """
    Validates a file path that is about to be read from.

    Args:
        filepath (str):
            The file path to be prepared for loading.
        must_exist (bool):
            If ``True``, a file must exist at the specified path; otherwise
            an ``AssertionError`` is raised. If ``False``, the path may or
            may not exist. A missing path is still validated by
            ``prepare_path`` with ``mkdir=False``. (Default is ``True``)

    Returns:
        (str):
            path (str):
                The resolved file path for loading.
    """
    ## An existing path is always acceptable when loading, so exist_ok is
    ## fixed to True. Forwarding must_exist as exist_ok (as was done before)
    ## rejected files that already exist whenever must_exist=False.
    path = prepare_path(filepath, mkdir=False, exist_ok=True)
    if must_exist:
        assert Path(path).is_file(), f'{path} is not a file.'
    return path
[docs]
def prepare_directory_for_saving(
    directory: str,
    mkdir: bool = False,
    exist_ok: bool = True
) -> str:
    """
    Validates a directory path that files are about to be saved into. This
    function is rarely used.

    Thin wrapper around ``prepare_path``; all arguments are forwarded
    unchanged.

    Args:
        directory (str):
            The directory path to be prepared for saving.
        mkdir (bool):
            If ``True``, creates the directory if it does not exist.
            (Default is ``False``)
        exist_ok (bool):
            If ``True``, an already-existing directory is accepted.
            (Default is ``True``)

    Returns:
        (str):
            path (str):
                The resolved directory path for saving.
    """
    path_resolved = prepare_path(
        directory,
        mkdir=mkdir,
        exist_ok=exist_ok,
    )
    return path_resolved
[docs]
def prepare_directory_for_loading(
    directory: str,
    must_exist: bool = True
) -> str:
    """
    Validates a directory path that files are about to be loaded from. This
    function is rarely used.

    Args:
        directory (str):
            The directory path to be prepared for loading.
        must_exist (bool):
            If ``True``, a directory must exist at the specified path;
            otherwise an ``AssertionError`` is raised. If ``False``, the path
            may or may not exist. A missing path is still validated by
            ``prepare_path`` with ``mkdir=False``. (Default is ``True``)

    Returns:
        (str):
            path (str):
                The resolved directory path for loading.
    """
    ## An existing path is always acceptable when loading, so exist_ok is
    ## fixed to True. Forwarding must_exist as exist_ok (as was done before)
    ## rejected directories that already exist whenever must_exist=False.
    path = prepare_path(directory, mkdir=False, exist_ok=True)
    if must_exist:
        assert Path(path).is_dir(), f'{path} is not a directory.'
    return path
[docs]
def pickle_save(
    obj: Any,
    filepath: str,
    mode: str = 'wb',
    zipCompress: bool = False,
    mkdir: bool = False,
    allow_overwrite: bool = True,
    **kwargs_zipfile: Dict[str, Any],
) -> None:
    """
    Saves an object to a pickle file using `pickle.dump`.
    Allows for zipping of the file.
    RH 2022

    Args:
        obj (Any):
            The object to save.
        filepath (str):
            The path to save the object to.
        mode (str):
            The mode to open the file in. Ignored when ``zipCompress`` is
            ``True`` (the archive is always opened with ``'w'``). Options
            are: \n
                * ``'wb'``: Write binary.
                * ``'ab'``: Append binary.
                * ``'xb'``: Exclusive write binary. Raises FileExistsError if the
                  file already exists. \n
            (Default is ``'wb'``)
        zipCompress (bool):
            If ``True``, the pickled bytes are written as an entry named
            ``'data'`` inside a zip archive, which is similar to
            ``savez_compressed`` in numpy (with ``zipfile.ZIP_DEFLATED``).
            Useful for saving redundant and/or sparse arrays objects.
            (Default is ``False``)
        mkdir (bool):
            If ``True``, creates parent directory if it does not exist. (Default
            is ``False``)
        allow_overwrite (bool):
            If ``True``, allows overwriting of existing file. (Default is
            ``True``)
        kwargs_zipfile (Dict[str, Any]):
            Keyword arguments that will be passed into `zipfile.ZipFile`.
            compression=``zipfile.ZIP_DEFLATED`` unless ``'compression'`` is
            given explicitly.
            See https://docs.python.org/3/library/zipfile.html#zipfile-objects.
            Other options for 'compression' are (input can be either int or object): \n
                * ``0``:  zipfile.ZIP_STORED (no compression)
                * ``8``:  zipfile.ZIP_DEFLATED (usual zip compression)
                * ``12``: zipfile.ZIP_BZIP2 (bzip2 compression) (usually not as
                  good as ZIP_DEFLATED)
                * ``14``: zipfile.ZIP_LZMA (lzma compression) (usually better
                  than ZIP_DEFLATED but slower)
    """
    path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)

    if zipCompress:
        ## Merge rather than replace: the documented default compression must
        ## still apply when the caller only passes other options (e.g.
        ## compresslevel). An explicit 'compression' kwarg wins.
        kwargs_merged = {'compression': zipfile.ZIP_DEFLATED, **kwargs_zipfile}
        with zipfile.ZipFile(path, 'w', **kwargs_merged) as f:
            f.writestr('data', pickle.dumps(obj))
    else:
        with open(path, mode) as f:
            pickle.dump(obj, f)
[docs]
def pickle_load(
    filepath: str,
    zipCompressed: bool = False,
    mode: str = 'rb',
) -> Any:
    """
    Reads a pickled object back from disk.
    RH 2022

    NOTE: unpickling can execute arbitrary code. Only load files from
    trusted sources.

    Args:
        filepath (str):
            Path to the pickle file. The file must exist.
        zipCompressed (bool):
            If ``True``, the file is treated as a zip archive and the object
            is unpickled from the archive entry named ``'data'``. ``mode`` is
            not used in that case.
            (Default is ``False``)
        mode (str):
            The mode to open the file in when ``zipCompressed`` is ``False``.
            (Default is ``'rb'``)

    Returns:
        (Any):
            obj (Any):
                The object loaded from the pickle file.
    """
    path = prepare_filepath_for_loading(filepath, must_exist=True)

    ## Plain pickle file
    if not zipCompressed:
        with open(path, mode) as handle:
            return pickle.load(handle)

    ## Zip archive holding the pickled bytes under the entry 'data'
    with zipfile.ZipFile(path, 'r') as archive:
        payload = archive.read('data')
    return pickle.loads(payload)
[docs]
def json_save(
    obj: Any,
    filepath: str,
    indent: int = 4,
    mode: str = 'w',
    mkdir: bool = False,
    allow_overwrite: bool = True,
) -> None:
    """
    Writes an object to a json file with `json.dump`.
    RH 2022

    Args:
        obj (Any):
            The object to save. Must be serializable by ``json.dump``.
        filepath (str):
            The path to save the object to.
        indent (int):
            Number of spaces for indentation in the output json file. (Default
            is *4*)
        mode (str):
            The mode to open the file in. ``json.dump`` writes ``str``, so a
            text mode is required: \n
                * ``'w'``: Write.
                * ``'a'``: Append.
                * ``'x'``: Exclusive write. Raises FileExistsError if the
                  file already exists. \n
            (Default is ``'w'``)
        mkdir (bool):
            If ``True``, creates parent directory if it does not exist. (Default
            is ``False``)
        allow_overwrite (bool):
            If ``True``, allows overwriting of existing file. (Default is
            ``True``)
    """
    path_out = prepare_filepath_for_saving(
        filepath,
        mkdir=mkdir,
        allow_overwrite=allow_overwrite,
    )
    with open(path_out, mode) as handle:
        json.dump(obj, handle, indent=indent)
[docs]
def json_load(
    filepath: str,
    mode: str = 'r',
) -> Any:
    """
    Reads an object from a json file with `json.load`.
    RH 2022

    Args:
        filepath (str):
            Path to the json file. The file must exist.
        mode (str):
            The mode to open the file in. (Default is ``'r'``)

    Returns:
        (Any):
            obj (Any):
                The object loaded from the json file.
    """
    path_in = prepare_filepath_for_loading(filepath, must_exist=True)
    with open(path_in, mode) as handle:
        loaded = json.load(handle)
    return loaded
[docs]
def yaml_save(
    obj: object,
    filepath: str,
    indent: int = 4,
    mode: str = 'w',
    mkdir: bool = False,
    allow_overwrite: bool = True,
) -> None:
    """
    Writes an object to a YAML file with ``yaml.dump``.
    RH 2022

    Args:
        obj (object):
            The object to be saved.
        filepath (str):
            Path to save the object to.
        indent (int):
            The number of spaces for indentation in the saved YAML file.
            (Default is *4*)
        mode (str):
            Mode to open the file in. \n
                * ``'w'``: write (default)
                * ``'wb'``: write binary
                * ``'ab'``: append binary
                * ``'xb'``: exclusive write binary. Raises ``FileExistsError`` if
                  file already exists. \n
            NOTE(review): ``yaml.dump`` is called without an ``encoding``
            here, so the binary modes presumably fail when writing; confirm
            before relying on them. \n
            (Default is ``'w'``)
        mkdir (bool):
            If ``True``, creates the parent directory if it does not exist.
            (Default is ``False``)
        allow_overwrite (bool):
            If ``True``, allows overwriting of existing files. (Default is
            ``True``)
    """
    path_out = prepare_filepath_for_saving(
        filepath,
        mkdir=mkdir,
        allow_overwrite=allow_overwrite,
    )
    with open(path_out, mode) as handle:
        yaml.dump(obj, handle, indent=indent)
[docs]
def yaml_load(
    filepath: str,
    mode: str = 'r',
    loader: object = yaml.FullLoader,
) -> object:
    """
    Reads an object from a YAML file with ``yaml.load``.
    RH 2022

    Args:
        filepath (str):
            Path to the YAML file to load. The file must exist.
        mode (str):
            Mode to open the file in. (Default is ``'r'``)
        loader (object):
            The YAML loader class passed to ``yaml.load`` as ``Loader``. \n
                * ``yaml.FullLoader``: Loads the full YAML language. Avoids
                  arbitrary code execution. (Default for PyYAML 5.1+)
                * ``yaml.SafeLoader``: Loads a subset of the YAML language, safely.
                  This is recommended for loading untrusted input.
                * ``yaml.UnsafeLoader``: The original Loader code that could be
                  easily exploitable by untrusted data input.
                * ``yaml.BaseLoader``: Only loads the most basic YAML. All scalars
                  are loaded as strings. \n
            (Default is ``yaml.FullLoader``)

    Returns:
        (object):
            loaded_obj (object):
                The object loaded from the YAML file.
    """
    path_in = prepare_filepath_for_loading(filepath, must_exist=True)
    with open(path_in, mode) as handle:
        loaded_obj = yaml.load(handle, Loader=loader)
    return loaded_obj
[docs]
def download_file(
    url: Optional[str],
    path_save: str,
    check_local_first: bool = True,
    check_hash: bool = False,
    hash_type: str = 'MD5',
    hash_hex: Optional[str] = None,
    mkdir: bool = False,
    allow_overwrite: bool = True,
    write_mode: str = 'wb',
    verbose: bool = True,
    chunk_size: int = 1024,
) -> Optional[bool]:
    """
    Downloads a file from a URL to a local path using requests. Checks if file
    already exists locally and verifies the hash of the downloaded file against
    a provided hash if required.
    RH 2023

    Args:
        url (Optional[str]):
            URL of the file to download. If ``None``, then no download is
            attempted.
        path_save (str):
            Path to save the file to.
        check_local_first (bool):
            Whether to check if the file already exists locally. If ``True`` and
            the file exists locally, the download is skipped. If ``True`` and
            ``check_hash`` is also ``True``, the hash of the local file is
            checked. If the hash matches, the download is skipped. If the hash
            does not match, the file is downloaded. (Default is ``True``)
        check_hash (bool):
            Whether to check the hash of the local or downloaded file against
            ``hash_hex``. (Default is ``False``)
        hash_type (str):
            Type of hash to use. Options are: ``'MD5'``, ``'SHA1'``,
            ``'SHA256'``, ``'SHA512'``. (Default is ``'MD5'``)
        hash_hex (Optional[str]):
            Hash to compare to, in hexadecimal format (e.g., 'a1b2c3d4e5f6...').
            Can be generated using ``hash_file()`` or ``hashlib.hexdigest()``.
            If ``check_hash`` is ``True``, ``hash_hex`` must be provided.
            (Default is ``None``)
        mkdir (bool):
            If ``True``, creates the parent directory of ``path_save`` if it
            does not exist. (Default is ``False``)
        allow_overwrite (bool):
            If ``True``, an existing file at ``path_save`` may be overwritten
            by the download. (Default is ``True``)
        write_mode (str):
            Write mode for saving the file. Options include: ``'wb'`` (write
            binary), ``'ab'`` (append binary), ``'xb'`` (write binary, fail if
            file exists). (Default is ``'wb'``)
        verbose (bool):
            If ``True``, prints status messages. (Default is ``True``)
        chunk_size (int):
            Size of chunks in which to download the file. (Default is *1024*)

    Returns:
        (Optional[bool]):
            success (Optional[bool]):
                * ``True``: a usable file is present at ``path_save`` (either
                  it already existed and passed the requested checks, or the
                  download and any hash check succeeded).
                * ``False``: the request failed, the server did not answer
                  with status 200, the download was incomplete, or the hash
                  of the downloaded file did not match.
                * ``None``: a download was needed but ``url`` is ``None``.
    """
    def _log(msg):
        if verbose:
            print(msg)

    # Check if file already exists locally
    if check_local_first:
        if os.path.isfile(path_save):
            _log(f'File already exists locally: {path_save}')
            if not check_hash:
                return True
            # Check hash of local file
            hash_local = hash_file(path_save, type_hash=hash_type)
            if hash_local == hash_hex:
                _log('Hash of local file matches provided hash_hex.')
                return True
            _log('Hash of local file does not match provided hash_hex.')
            _log(f'Hash of local file: {hash_local}')
            _log(f'Hash provided in hash_hex: {hash_hex}')
            _log('Downloading file...')
        else:
            _log(f'File does not exist locally: {path_save}. Will attempt download from {url}')

    # Download file
    if url is None:
        _log('No URL provided. No download attempted.')
        return None

    ## Imported only once a download is actually needed, so the local-file
    ## paths above work without the requests package.
    import requests

    try:
        response = requests.get(url, stream=True)
    except requests.exceptions.RequestException as e:
        _log(f'Error downloading file: {e}')
        return False

    ## With stream=True the connection stays open until the body is consumed
    ## or the response is closed; the context manager guarantees the close on
    ## every exit path.
    with response:
        # Check response
        if response.status_code != 200:
            _log(f'Error downloading file. Response status code: {response.status_code}')
            return False

        # Create parent directory if it does not exist
        prepare_filepath_for_saving(path_save, mkdir=mkdir, allow_overwrite=allow_overwrite)

        # Download file with progress bar
        total_size = int(response.headers.get('content-length', 0))
        wrote = 0
        with open(path_save, write_mode) as f, \
                tqdm(total=total_size, disable=(not verbose), unit='B', unit_scale=True, unit_divisor=1024) as pbar:
            for data in response.iter_content(chunk_size):
                wrote += len(data)
                f.write(data)
                pbar.update(len(data))

    ## NOTE(review): compares bytes written against the Content-Length header;
    ## this may misreport if the server applies a content encoding. Confirm.
    if total_size != 0 and wrote != total_size:
        print("ERROR, something went wrong")
        return False

    # Check hash (skipped entirely when the result would be neither checked nor printed)
    if not (check_hash or verbose):
        return True
    hash_local = hash_file(path_save, type_hash=hash_type)
    if not check_hash:
        _log(f'Hash of downloaded file: {hash_local}')
        return True
    if hash_local == hash_hex:
        _log('Hash of downloaded file matches hash_hex.')
        return True
    _log('Hash of downloaded file does not match hash_hex.')
    _log(f'Hash of downloaded file: {hash_local}')
    _log(f'Hash provided in hash_hex: {hash_hex}')
    return False
[docs]
def hash_file(
    path: str,
    type_hash: str = 'MD5',
    buffer_size: int = 65536,
) -> str:
    """
    Returns the hexadecimal digest of a file's contents.
    RH 2022

    The file is read in chunks of ``buffer_size`` bytes so that large files
    do not need to fit in memory.

    Args:
        path (str):
            Path to the file to be hashed.
        type_hash (str):
            Type of hash to use. (Default is ``'MD5'``). Either \n
                * ``'MD5'``: MD5 hash algorithm.
                * ``'SHA1'``: SHA1 hash algorithm.
                * ``'SHA256'``: SHA256 hash algorithm.
                * ``'SHA512'``: SHA512 hash algorithm.
        buffer_size (int):
            Buffer size (in bytes) for reading the file.
            65536 corresponds to 64KB. (Default is *65536*)

    Returns:
        (str):
            hash_val (str):
                The computed hash of the file, as a hex string.

    Raises:
        ValueError:
            If ``type_hash`` is not one of the supported names.
    """
    import hashlib

    ## Supported algorithm names mapped to their hashlib constructors
    constructors = {
        'MD5': hashlib.md5,
        'SHA1': hashlib.sha1,
        'SHA256': hashlib.sha256,
        'SHA512': hashlib.sha512,
    }
    if type_hash not in constructors:
        raise ValueError(f'{type_hash} is not a valid hash type.')
    hasher = constructors[type_hash]()

    with open(path, 'rb') as f:
        ## iter() with a sentinel stops at the first empty read (end of file)
        for chunk in iter(lambda: f.read(buffer_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()
[docs]
def get_dir_contents(
    directory: str,
) -> Tuple[List[str], List[str]]:
    """
    Lists the immediate sub-folders and files of a directory (subdirectories
    are not descended into).
    RH 2021

    Only the first entry produced by ``os.walk`` is used. If the walk
    produces nothing (e.g. the directory cannot be listed), two empty lists
    are returned.

    Args:
        directory (str):
            The path to the directory.

    Returns:
        (tuple): tuple containing:
            folders (List[str]):
                A list of folder names.
            files (List[str]):
                A list of file names.
    """
    top_level = next(os.walk(directory, followlinks=False), None)
    if top_level is None:
        return [], []
    _, folders, files = top_level
    return folders, files
[docs]
def compare_file_hashes(
    hash_dict_true: Dict[str, Tuple[str, str]],
    dir_files_test: Optional[str] = None,
    paths_files_test: Optional[List[str]] = None,
    verbose: bool = True,
) -> Tuple[bool, Dict[str, bool], Dict[str, str]]:
    """
    Compares hashes of files in a directory or list of paths to provided hashes.
    RH 2022

    Files are matched to entries of ``hash_dict_true`` by file name, and each
    file is hashed with ``hash_file`` using its default hash type.

    Args:
        hash_dict_true (Dict[str, Tuple[str, str]]):
            Dictionary of hashes to compare. Each entry should be in the format:
            *{'key': ('filename', 'hash')}*.
        dir_files_test (str):
            Path to directory containing the files to compare hashes.
            Unused if paths_files_test is not ``None``. (Optional)
        paths_files_test (List[str]):
            List of paths to files to compare hashes.
            dir_files_test is used if ``None``. (Optional)
        verbose (bool):
            If ``True``, failed comparisons are printed out. (Default is ``True``)

    Returns:
        (tuple): tuple containing:
            total_result (bool):
                ``True`` if all hashes match, ``False`` otherwise.
            individual_results (Dict[str, bool]):
                Dictionary indicating whether each hash matched.
            paths_matching (Dict[str, str]):
                Dictionary of paths that matched. Each entry is in the format:
                *{'key': 'path'}*.

    Raises:
        ValueError:
            If both ``dir_files_test`` and ``paths_files_test`` are ``None``.
    """
    if paths_files_test is None:
        if dir_files_test is None:
            raise ValueError('Must provide either dir_files_test or paths_files_test.')
        ## make a dict of {filename: path} for each file in dir_files_test
        dir_resolved = Path(dir_files_test).resolve()
        files_test = {
            filename: (dir_resolved / filename).as_posix()
            for filename in get_dir_contents(dir_files_test)[1]
        }
        ## Where the files came from; used only in the "not found" message
        source_desc = f'test directory: {dir_files_test}'
    else:
        files_test = {Path(path).name: path for path in paths_files_test}
        source_desc = 'paths_files_test'

    paths_matching = {}
    results_matching = {}
    for key, (filename, hash_true) in hash_dict_true.items():
        match = True
        if filename not in files_test:
            if verbose:
                print(f'{filename} not found in {source_desc}.')
            match = False
        elif hash_true != hash_file(files_test[filename]):
            if verbose:
                print(f'{filename} hash mismatch with {key, filename}.')
            match = False
        if match:
            paths_matching[key] = files_test[filename]
        results_matching[key] = match

    return all(results_matching.values()), results_matching, paths_matching
#####################################################################################################################################
############################################################# INDEXING ##############################################################
#####################################################################################################################################
[docs]
def make_batches(
    iterable,
    batch_size=None,
    num_batches=None,
    min_batch_size=0,
    return_idx=False,
    length=None
):
    """
    Generates batches of data from an iterable. RH 2021

    Args:
        iterable (Iterable):
            Iterable to be batched. Must support slicing
            (``iterable[start:end]``).
        batch_size (Optional[int]):
            Size of each batch. If ``None``, ``batch_size`` is computed from
            ``num_batches`` as ``ceil(length / num_batches)``. (Default is
            ``None``)
        num_batches (Optional[int]):
            Number of batches to make. Used only when ``batch_size`` is
            ``None``. (Default is ``None``)
        min_batch_size (int):
            Minimum size of each batch. Batches smaller than this are
            skipped. (Default is ``0``)
        return_idx (bool):
            If ``True``, yields ``(batch, [start, end])`` tuples instead of
            just the batch. (Default is ``False``)
        length (Optional[int]):
            Length of the iterable. If ``None``, uses ``len(iterable)``. Useful
            when the iterable does not implement ``__len__``. (Default is
            ``None``)

    Returns:
        (Generator):
            output (Generator):
                Yields successive batches from ``iterable``. If ``return_idx``
                is ``True``, yields ``(batch, [start, end])`` tuples. Yields
                nothing when the length is 0.

    Raises:
        TypeError:
            If both ``batch_size`` and ``num_batches`` are ``None``. Because
            this is a generator, the error surfaces on first iteration.
    """
    n = len(iterable) if length is None else length

    if batch_size is None:
        if num_batches is None:
            raise TypeError('Either batch_size or num_batches must be provided.')
        ## Clamp to 1 so that an empty input (n == 0) does not produce a zero
        ## step, which range() rejects.
        batch_size = max(int(np.ceil(n / num_batches)), 1)

    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        ## Only the final batch can be short, so stopping here is equivalent
        ## to skipping it.
        if (end - start) < min_batch_size:
            break
        if return_idx:
            yield iterable[start:end], [start, end]
        else:
            yield iterable[start:end]
[docs]
def cp_to_dense(cp, weights=None):
    """
    Reconstructs a dense tensor from a CP-format list of factor matrices. RH 2022

    Computes ``sum_r weights[r] * cp[0][:, r] (x) cp[1][:, r] (x) ...`` where
    ``(x)`` is the outer product. Each weight is applied exactly once per
    rank component.

    Args:
        cp (List[np.ndarray]):
            List of length ``n_modes`` of 2D factor matrices, each with
            shape *(len_dim, rank)*. This is the format Tensorly uses for
            its ``'cp'`` representation. Elements may be NumPy arrays or
            ``torch.Tensor`` (matching dtype).
        weights (Optional[np.ndarray]):
            Per-rank weights of length ``rank``. If ``None``, uses a vector
            of ones. (Default is ``None``)

    Returns:
        (np.ndarray):
            dense (np.ndarray):
                Reconstructed dense tensor. shape: *(len_dim_0, len_dim_1, ...)*.
                Same array library as ``cp[0]``.

    Raises:
        TypeError:
            If ``cp[0]`` is neither a ``np.ndarray`` nor a ``torch.Tensor``.
    """
    rank = cp[0].shape[1]
    n_modes = len(cp)

    ## e.g. for 3 modes: 'r,ar,br,cr->abc'. The leading 'r' operand is the
    ## weight vector, so it scales each rank component once. (Multiplying the
    ## weights into every factor would scale by weights**n_modes instead.)
    mode_letters = [chr(97 + m) for m in range(n_modes)]
    str_einsum = 'r,' + ','.join([letter + 'r' for letter in mode_letters]) + '->' + ''.join(mode_letters)

    if weights is None:
        weights = np.ones(rank)

    ## check if numpy or torch
    if isinstance(cp[0], np.ndarray):
        einsum = np.einsum
        weights = np.array(weights).astype(cp[0].dtype)
    elif isinstance(cp[0], torch.Tensor):
        einsum = torch.einsum
        weights = torch.as_tensor(weights).type(cp[0].dtype).to(cp[0].device)
    else:
        raise TypeError(f'cp factors must be np.ndarray or torch.Tensor, got {type(cp[0])}.')

    dense = einsum(str_einsum, weights, *cp)
    return dense
#####################################################################################################################################
######################################################### CONTAINER HELPERS #########################################################
#####################################################################################################################################
[docs]
class Lazy_repeat_item():
    """
    Container that hands back one and the same item for every index. RH 2021

    Args:
        item (Any):
            Item to repeat on every access.
        pseudo_length (Optional[int]):
            Reported length of the container. If ``None``, no bound is
            enforced and ``__getitem__`` returns ``item`` for any index.
            (Default is ``None``)

    Attributes:
        item (Any):
            The repeated item.
        pseudo_length (Optional[int]):
            Stored pseudo length.
    """
    def __init__(self, item, pseudo_length=None):
        """Stores the item to repeat and the optional pseudo length."""
        super().__init__()
        self.item = item
        self.pseudo_length = pseudo_length

    def __getitem__(self, i):
        """
        Returns ``self.item``.

        Args:
            i (int):
                Index requested. Ignored when ``pseudo_length`` is ``None``.
                Otherwise an ``IndexError`` is raised unless
                ``i < pseudo_length``.

        Returns:
            (Any):
                item (Any):
                    The repeated item.
        """
        limit = self.pseudo_length
        if limit is not None and not (i < limit):
            raise IndexError('Index out of bounds')
        return self.item

    def __len__(self):
        """Returns the stored ``pseudo_length``."""
        return self.pseudo_length

    def __repr__(self):
        """Returns the ``repr`` of the repeated item."""
        return repr(self.item)
[docs]
def deep_update_dict(dictionary, key, new_val=None, new_key=None, in_place=False):
    """
    Updates a value or renames a key inside a nested dictionary. RH 2022

    Args:
        dictionary (Dict):
            Dictionary to update.
        key (List[str]):
            Hierarchical path of string keys leading to the entry to update.
            Each element corresponds to a nesting level. A single string is
            treated as a one-element path; tuples and other sequences of
            keys are accepted as well.
        new_val (Optional[Any]):
            New value to assign. If ``None``, the existing value is kept
            (so this function cannot be used to set a value to ``None``).
            (Default is ``None``)
        new_key (Optional[str]):
            If provided, ``key[-1]`` is removed and replaced with ``new_key``
            (mapping to ``new_val`` if given, otherwise to the existing value).
            If ``new_key`` equals ``key[-1]``, the entry is kept and only its
            value is updated. (Default is ``None``)
        in_place (bool):
            If ``True``, updates ``dictionary`` in place and returns ``None``.
            If ``False``, returns a deep-copied updated dictionary. (Default
            is ``False``)

    Returns:
        (Optional[Dict]):
            output (Optional[Dict]):
                Updated dictionary when ``in_place`` is ``False``; otherwise
                ``None``.

    Raises:
        AssertionError:
            If any element of ``key`` is not found at its nesting level.
        IndexError:
            If ``key`` is an empty sequence.

    Example:
        .. highlight:: python
        .. code-block:: python

            deep_update_dict(params, ['dataloader_kwargs', 'prefetch_factor'], val)
    """
    keys = [key] if isinstance(key, str) else list(key)
    if len(keys) == 0:
        raise IndexError('key must contain at least one element.')

    root = dictionary if in_place else copy.deepcopy(dictionary)

    ## Walk down to the dictionary that holds the final key, validating the
    ## whole path before anything is modified.
    node = root
    for depth, k in enumerate(keys):
        if k not in node:
            ## Raised explicitly (instead of via an assert statement) so the
            ## check still runs under ``python -O``; the exception type is
            ## kept as AssertionError for backward compatibility.
            raise AssertionError(f"RH ERROR, key: '{k}' is not found")
        if depth < len(keys) - 1:
            node = node[k]

    last = keys[-1]
    val = node[last] if new_val is None else new_val
    if new_key is not None and new_key != last:
        del node[last]
        node[new_key] = val
    else:
        ## Covers both "no rename" and "rename to the same name"; the latter
        ## must not delete the entry.
        node[last] = val

    return None if in_place else root
[docs]
def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str ='.') -> MutableMapping:
    """
    Collapses a nested mapping into a single-level dictionary. RH 2022

    Every key is converted with ``str`` and the path components are joined
    with ``sep``. Adapted from https://stackoverflow.com/a/6027615.

    Args:
        d (MutableMapping):
            Mapping to flatten.
        parent_key (str):
            Prefix placed in front of every produced key. Used internally
            for recursion. (Default is ``''``)
        sep (str):
            Separator placed between path components. Used internally for
            recursion. (Default is ``'.'``)

    Returns:
        (Dict):
            flattened (Dict):
                Flat dictionary whose keys are the joined paths.
    """
    flat = {}
    for name, value in d.items():
        ## An empty prefix means we are at the top level: no separator needed
        if parent_key:
            path = str(parent_key) + str(sep) + str(name)
        else:
            path = str(name)

        if isinstance(value, MutableMapping):
            flat.update(flatten_dict(value, path, sep=sep))
        else:
            flat[path] = value
    return flat
[docs]
def find_subDict_key(d: dict, s: str, max_depth: int=9999999):
    """
    Walks a nested dictionary and collects every key matching a regex.

    Args:
        d (dict):
            Dictionary to search.
        s (str):
            Regex pattern tested against each key with ``re.search``.
        max_depth (int):
            Maximum depth to descend. ``1`` searches only the top level, ``2``
            searches the first and second levels, etc. (Default is
            ``9999999``)

    Returns:
        (List[Tuple[List[str], Any]]):
            k_all (List[Tuple[List[str], Any]]):
                List of 2-tuples ``(path, value)`` where ``path`` is the list
                of keys leading to the matched entry and ``value`` is the
                value stored under the matched key. Results are in
                depth-first order: a matching key is reported before any
                matches found inside its value.
    """
    matches = []

    def _walk(node, path, levels_left):
        ## `levels_left` counts how many more levels may still be inspected
        if not levels_left > 0:
            return
        for key, value in node.items():
            here = path + [key]
            if re.search(s, key):
                matches.append((here, value))
            if isinstance(value, dict):
                _walk(value, here, levels_left - 1)

    _walk(d, [], max_depth)
    return matches
## parameter dictionary helpers ##
[docs]
def fill_in_dict(
    d: Dict,
    defaults: Dict,
    verbose: bool = True,
    hierarchy: List[str] = ['dict'],
):
    """
    Fills in a dictionary in place with values from ``defaults`` for missing
    keys, recursing into nested dictionaries. RH 2023

    Args:
        d (Dict):
            Dictionary to fill in (modified in place).
        defaults (Dict):
            Dictionary of default values. Inserted values are deep-copied
            so ``d`` never shares mutable objects with ``defaults``.
        verbose (bool):
            If ``True``, prints a message each time a default value is
            inserted, at every nesting level. (Default is ``True``)
        hierarchy (List[str]):
            Path of keys leading to ``d``. Used internally for recursion
            and only read, never mutated. (Default is ``['dict']``)

    Raises:
        AssertionError:
            If a key maps to a dict in ``defaults`` but to a non-dict in
            ``d``.
    """
    from copy import deepcopy

    for key in defaults:
        if key not in d:
            if verbose:
                location = ' > '.join(str(h) for h in hierarchy)
                print(f"Key '{key}' not found in params dictionary: {location}. Using default value: {defaults[key]}")
            d[key] = deepcopy(defaults[key])
        elif isinstance(defaults[key], dict):
            assert isinstance(d[key], dict), f"Key '{key}' is a dict in defaults, but not in params. {' > '.join([f'{str(h)}' for h in hierarchy])}."
            ## Forward `verbose` so nested levels honor the caller's choice
            fill_in_dict(d[key], defaults[key], verbose=verbose, hierarchy=hierarchy+[key])
[docs]
def check_keys_subset(
    d,
    default_dict,
    error_on_missing_keys=True,
    hierarchy=['defaults'],
):
    """
    Recursively confirms that ``d`` contains no key that is absent from
    ``default_dict``. RH 2023

    Args:
        d (Dict):
            Dictionary whose keys are validated.
        default_dict (Dict):
            Dictionary that defines the permitted keys.
        error_on_missing_keys (bool):
            If ``True``, an unknown key in ``d`` triggers an
            ``AssertionError``. If ``False``, a warning is emitted and the
            unknown key is skipped. (Default is ``True``)
        hierarchy (List[str]):
            Path of keys leading to ``d``. Used internally for recursion.
            (Default is ``['defaults']``)
    """
    allowed = list(default_dict.keys())

    def _complaint(name):
        ## Built lazily so the path is only formatted when a key is unknown
        return f"Key '{name}' not found in defaults dictionary: {' > '.join([f'{str(h)}' for h in hierarchy])}."

    for name in d.keys():
        if error_on_missing_keys:
            assert name in allowed, _complaint(name)
        elif name not in allowed:
            warnings.warn(_complaint(name))
            continue

        ## Only descend when both sides hold a dictionary under this key
        both_nested = isinstance(default_dict[name], dict) and isinstance(d[name], dict)
        if both_nested:
            check_keys_subset(
                d=d[name],
                default_dict=default_dict[name],
                error_on_missing_keys=error_on_missing_keys,
                hierarchy=hierarchy+[name],
            )
[docs]
def prepare_params(
    params,
    defaults,
    error_on_missing_keys=True,
    verbose=True,
):
    """
    Checks ``params`` against ``defaults`` and completes it with defaults.

    Performs the following: \n
        * Checks that all keys in ``params`` are also in ``defaults``.
        * Fills in any missing keys in ``params`` with values from ``defaults``.
        * Returns a deepcopy of the filled-in ``params``. \n

    Args:
        params (Dict):
            Dictionary of parameters. It is not modified.
        defaults (Dict):
            Dictionary of defaults.
        error_on_missing_keys (bool):
            If ``True``, raises an error when a key in ``params`` is not in
            ``defaults``. If ``False``, emits a warning instead. (Default is
            ``True``)
        verbose (bool):
            If ``True``, prints messages while filling in defaults. (Default
            is ``True``)

    Returns:
        (Dict):
            params_out (Dict):
                Validated and default-filled deepcopy of ``params``.
    """
    ## Anything other than a dictionary is rejected up front
    assert isinstance(params, dict), f"p must be a dict. Got {type(params)} instead."

    ## Every key supplied by the caller must be known to `defaults`
    check_keys_subset(d=params, default_dict=defaults, error_on_missing_keys=error_on_missing_keys)

    ## Complete a private copy so the caller's dictionary stays untouched
    filled = copy.deepcopy(params)
    fill_in_dict(d=filled, defaults=defaults, verbose=verbose)
    return filled
#####################################################################################################################################
############################################################## VIDEO ################################################################
#####################################################################################################################################
## VideoReaderWrapper is only defined if decord is available.
## When using backend='torchcodec', decord is not required.
if decord is not None:
    class VideoReaderWrapper(decord.VideoReader):
        """
        Subclass of ``decord.VideoReader`` that works around a memory leak.

        Calls ``self.seek(0)`` after initialization and after every
        ``__getitem__`` so that decord releases buffered frames. Adapted from
        https://github.com/dmlc/decord/issues/208#issuecomment-1157632702.

        Attributes:
            path (str):
                Path to the video file: the first positional argument, or
                the ``uri`` keyword argument when no positional argument
                was given.
        """
        def __init__(self, *args, **kwargs):
            """Forwards all arguments to ``decord.VideoReader`` and resets the seek position."""
            super().__init__(*args, **kwargs)
            self.seek(0)
            ## The source may arrive positionally or via decord's ``uri`` keyword;
            ## indexing ``args[0]`` unconditionally would fail for the latter.
            self.path = args[0] if len(args) > 0 else kwargs.get('uri')

        def __getitem__(self, key):
            """Returns frames at ``key`` and resets the underlying decord seek position."""
            frames = super().__getitem__(key)
            self.seek(0)
            return frames
else:
    ## decord is unavailable: keep the name defined so callers can test for None
    VideoReaderWrapper = None
def _is_torchcodec_load_error(exc: BaseException) -> bool:
    """
    Reports whether ``exc`` looks like a torchcodec shared-library load failure.

    The decision is a case-sensitive substring search of ``str(exc)`` for a
    fixed set of marker phrases.

    Args:
        exc (BaseException):
            Exception to inspect.

    Returns:
        (bool):
            is_load_error (bool):
                ``True`` if any marker phrase occurs in the message.
    """
    text = str(exc)
    for needle in (
        "Could not load libtorchcodec",
        "Could not load this library",
        "FFmpeg is not properly installed",
        "libtorchcodec",
        "libavutil",
        "libavcodec",
        "libavformat",
        "Library not loaded: @rpath/libav",
    ):
        if needle in text:
            return True
    return False
def _torchcodec_unavailable_message() -> str:
    """
    Builds the user-facing text shown when torchcodec cannot be loaded.

    Returns:
        (str):
            message (str):
                Explanation of the failure, installation advice, and a
                one-line example of selecting the decord backend instead.
    """
    explanation = (
        "torchcodec could not be imported or could not load its "
        "FFmpeg-linked shared libraries."
    )
    remedy = (
        "Install FFmpeg shared libraries visible to the dynamic loader, "
        "or install torchcodec and ffmpeg from conda-forge before "
        "installing face-rhythm."
    )
    fallback = (
        "To use the bundled decord backend instead, construct your video "
        "reader with backend='decord':"
    )
    example = " BufferedVideoReader(paths_videos=..., backend='decord')"
    return " ".join((explanation, remedy, fallback)) + "\n" + example
[docs]
class TorchCodecVideoReader:
    """
    Video reader backed by ``torchcodec.decoders.VideoDecoder`` with a
    workaround for torchcodec issue #905.

    Provides the same ``__getitem__`` / ``__len__`` / ``get_avg_fps``
    interface as ``VideoReaderWrapper`` (decord) so it can be used as a
    drop-in replacement inside ``BufferedVideoReader``.

    Frames are returned as ``torch.Tensor`` with shape ``(H, W, C)`` and
    dtype ``uint8`` (NHWC layout), matching the output of decord's torch
    bridge.

    **Issue #905 workaround:** torchcodec's sequential access path skips
    cursor reset; after reading ``n - has_b_frames`` frames on a single
    decoder, FFmpeg's H.264 drain emits ``has_b_frames`` AVFrames with
    ``pts = INT64_MIN``, which the internal PTS filter rejects, causing
    ``EndOfFileException`` before the last frames are decoded. The fix is
    to serve only frames ``[0, n - SAFETY)`` from the primary decoder
    (which never reaches the drain), and route the trailing ``SAFETY``
    frames through a fresh decoder that takes the non-sequential seek
    branch (``avformat_seek_file`` + flush + forward-decode from keyframe)
    where ``pkt_dts`` remains valid. ``SAFETY = max(has_b_frames, 2)``.
    torchcodec's ``VideoStreamMetadata`` does not currently expose
    ``has_b_frames``, so SAFETY always defaults to 2, which matches
    ``ffprobe``-reported ``has_b_frames=2`` for H.264 AVIs and is a safe
    overestimate for ``has_b_frames=0`` files.

    The tail decoder is lazily created on first tail access and cached for
    the lifetime of this reader (one decoder recreation per video pass,
    versus one per chunk with earlier workarounds).

    Negative frame indices are resolved relative to the end of the video
    *before* routing, so ``reader[-1]`` is served by the tail decoder like
    any other tail frame.

    Thread safety is guaranteed by an internal lock — required because
    ``BufferedVideoReader`` loads slots from background threads.

    Args:
        path_video (str):
            Path to the video file.
        device (str):
            Decode device. ``'cpu'`` for software decode,
            ``'cuda'`` or ``'cuda:0'`` for NVDEC hardware decode.
            NVDEC requires torchcodec built with CUDA support and an
            FFmpeg built with ``--enable-cuda``.
        num_ffmpeg_threads (int):
            Number of FFmpeg internal threads for decoding.
            ``0`` lets FFmpeg choose automatically (recommended).
    """
    def __init__(self, path_video: str, device: str = 'cpu', num_ffmpeg_threads: int = 0):
        """
        Opens ``path_video`` with torchcodec and prepares the issue-#905 tail-decoder cache.

        Raises:
            ImportError:
                If torchcodec is missing or its FFmpeg-linked libraries
                cannot be loaded.
        """
        self.path = str(path_video)
        self._device = device
        self._num_ffmpeg_threads = num_ffmpeg_threads
        self._lock = threading.Lock()

        try:
            self._decoder = self._make_fresh_decoder()
        except (ImportError, ModuleNotFoundError) as e:
            raise ImportError(_torchcodec_unavailable_message()) from e
        except (OSError, RuntimeError) as e:
            ## Only translate errors that look like library-load failures;
            ## anything else (e.g. a bad video file) propagates unchanged.
            if _is_torchcodec_load_error(e):
                raise ImportError(_torchcodec_unavailable_message()) from e
            raise

        self._num_frames = len(self._decoder)

        ## SAFETY = max(has_b_frames, 2). torchcodec VideoStreamMetadata does
        ## not expose has_b_frames, so getattr always falls back to 2 here.
        ## 2 is correct for standard H.264 main/high AVIs (confirmed by ffprobe)
        ## and is a safe overestimate for has_b_frames=0 files.
        _has_b_frames = getattr(self._decoder.metadata, 'has_b_frames', None)
        self._tail_safety = max(_has_b_frames if _has_b_frames is not None else 2, 2)
        self._safe_last = max(self._num_frames - self._tail_safety, 0)

        ## Tail decoder: lazily created on first tail-range access; cached to
        ## avoid one-decoder-per-call cost on sequential tail reads.
        self._tail_decoder = None

    def _make_fresh_decoder(self):
        """Return a new VideoDecoder for self.path with the configured options."""
        ## Imported lazily so this module loads even without torchcodec
        from torchcodec.decoders import VideoDecoder
        return VideoDecoder(
            self.path,
            seek_mode='exact',
            dimension_order='NHWC',
            device=self._device,
            num_ffmpeg_threads=self._num_ffmpeg_threads,
        )

    def _read_one(self, idx: int):
        """
        Read a single frame by index, routing tail frames to a fresh decoder.

        Args:
            idx (int):
                Frame index in ``[-len(self), len(self))``. Negative values
                count from the end of the video.

        Returns:
            (torch.Tensor):
                frame (torch.Tensor): Shape ``(H, W, C)``, dtype ``uint8``.

        Raises:
            IndexError:
                If ``idx`` falls outside ``[-len(self), len(self))``.
        """
        requested = idx
        ## Resolve negative indices here: forwarding them raw would send
        ## tail frames such as -1 to the primary decoder and defeat the
        ## issue-#905 routing below.
        if idx < 0:
            idx += self._num_frames
        if not 0 <= idx < self._num_frames:
            raise IndexError(
                f"TorchCodecVideoReader: frame index {requested} is out of range "
                f"for a video with {self._num_frames} frames."
            )

        if idx < self._safe_last:
            ## Safe range: primary decoder, sequential access, no drain risk.
            return self._decoder.get_frame_at(idx).data

        ## Tail range: use a fresh decoder (takes the non-sequential seek
        ## branch inside torchcodec), sidestepping the #905 drain bug.
        if self._tail_decoder is None:
            self._tail_decoder = self._make_fresh_decoder()
        return self._tail_decoder.get_frame_at(idx).data

    def __len__(self) -> int:
        """Return the number of frames reported by the decoder at construction."""
        return self._num_frames

    def __getitem__(self, key):
        """
        Index frames by int, ``np.integer``, slice, list, or ``np.ndarray``.

        All multi-frame access is implemented via scalar ``get_frame_at``
        calls internally (avoids ``get_frames_at`` / ``get_frames_in_range``,
        which share the buggy sequential access path from issue #905).

        Args:
            key (int | np.integer | slice | list | np.ndarray):
                Frame index or indices. Negative indices count from the
                end of the video.

        Returns:
            (torch.Tensor):
                frames (torch.Tensor):
                    Single frame ``(H, W, C)`` for scalar key, or batch
                    ``(N, H, W, C)`` for slice / list / ndarray key.

        Raises:
            TypeError:
                If ``key`` is of an unsupported type.
            IndexError:
                If any requested index is out of range.
        """
        with self._lock:
            if isinstance(key, (int, np.integer)):
                return self._read_one(int(key))

            if isinstance(key, slice):
                indices = list(range(*key.indices(self._num_frames)))
            elif isinstance(key, (list, np.ndarray)):
                indices = [int(i) for i in key]
            else:
                raise TypeError(
                    f"TorchCodecVideoReader: unsupported key type {type(key)}. "
                    "Expected int, np.integer, slice, list, or np.ndarray."
                )

            if not indices:
                return self._empty_batch()
            return torch.stack([self._read_one(i) for i in indices])

    def get_avg_fps(self) -> float:
        """Return the average frame rate of the video."""
        return self._decoder.metadata.average_fps

    def _empty_batch(self):
        """
        Return a 4D NHWC empty tensor matching the decoder's frame shape.

        The tensor is allocated on the configured decode device so that an
        empty result lives on the same device as non-empty batches. The
        channel count is fixed at 3.
        """
        meta = self._decoder.metadata
        return torch.empty(
            (0, meta.height, meta.width, 3),
            dtype=torch.uint8,
            device=self._device,
        )
[docs]
class BufferedVideoReader:
"""
Reads frames from one or more videos with a chunked memory buffer and
optional background prefetching. RH 2022
Sequential batches of frames can be read quickly because buffers are
filled by background threads. In many cases, batches can be consumed
without waiting for the next chunk to finish loading.
Optimal use:
1. Create a ``BufferedVideoReader`` object.
2. EITHER set ``method_getitem='continuous'`` and iterate over the object
(fastest path), OR request batches of frames sequentially (going
backwards is slow because buffers move forward).
3. Each batch should fit inside a single buffer slot. Slices that span
multiple buffer slots require concatenation and are slow. With a
buffer size of 1000 frames, ``[0:1000], [1000:2000], ...`` is fast,
while ``[0:1700]``, ``[1700:3200]``, ``[0:990]``, ``[990:1010]`` are
slow (too big, overlapping, backwards, or crossing slot boundaries).
Args:
video_readers (Optional[list]):
List of video reader objects (``decord.VideoReader`` or
``TorchCodecVideoReader``). A single reader is also accepted.
If ``None``, ``paths_videos`` must be provided. (Default is
``None``)
paths_videos (Optional[list]):
List of paths to videos. A single ``str`` is also accepted. If
``None``, ``video_readers`` must be provided. If both are
supplied, ``video_readers`` wins. (Default is ``None``)
buffer_size (int):
Number of frames per buffer slot. Avoid indexing more than
``buffer_size`` frames at a time or across slot boundaries (e.g.
across ``idx % buffer_size == 0``); these require concatenating
buffers and are slow. (Default is ``1000``)
prefetch (int):
Number of buffers to prefetch ahead. ``0`` disables prefetching.
A single buffer slot only contains frames from one video, so
``buffer_size <= video length`` is recommended. (Default is ``2``)
posthold (int):
Number of buffers to keep loaded behind the current position.
``0`` disables posthold. Useful when iterating backwards.
(Default is ``1``)
method_getitem (str):
Indexing mode for ``__getitem__``. One of \n
* ``'continuous'``: index across all videos as a single
concatenated sequence; ``reader[idx_frames_slice]``.
* ``'by_video'``: index requires a ``(idx_video, idx_frames)``
tuple; ``reader[(idx_video, slice)]``. \n
(Default is ``'continuous'``)
starting_seek_position (int):
Starting frame index for the iterator. Used only when
``method_getitem == 'continuous'`` and iterating. (Default is
``0``)
backend (str):
Video decoding backend. One of \n
* ``'torchcodec'``: uses ``torchcodec.decoders.VideoDecoder``.
Frame-accurate seeking, actively maintained, supports CPU and
GPU (NVDEC) decode. Includes a workaround for torchcodec
issue #905 (sequential-access drain bug near EOF in H.264
AVIs): frames in ``[0, n - SAFETY)`` come from a persistent
decoder; the trailing ``SAFETY = max(has_b_frames, 2)`` frames
go through a fresh decoder cached as the tail decoder.
* ``'decord'``: uses ``decord.VideoReader``. Well-tested
fallback and the only backend available on Windows. Provided
by the ``decord2`` PyPI package on Linux/macOS (with vendored
FFmpeg 8 wheels for py3.10-3.14) and by ``eva_decord`` on
Windows. Both are installed by face-rhythm's default
dependencies. \n
Only used when ``paths_videos`` is provided. (Default is
``'torchcodec'``)
device (str):
Device for video decoding when using torchcodec. ``'cpu'``
decodes on CPU. ``'cuda'`` or ``'cuda:0'`` decodes on GPU using
NVDEC; frames are returned as CUDA tensors. GPU decode requires
an NVIDIA GPU, torchcodec installed with CUDA support, and
FFmpeg built with ``--enable-cuda``. (Default is ``'cpu'``)
decord_backend (str):
Backend used by decord when loading frames (``'torch'``,
``'numpy'``, ``'mxnet'``, ...). Only used when
``backend='decord'``. (Default is ``'torch'``)
decord_ctx (object):
Context used by decord when loading frames (e.g. ``decord.cpu()``,
``decord.gpu()``). Only used when ``backend='decord'``. (Default
is ``None``)
verbose (int):
Verbosity level. ``0`` silences output, ``1`` prints warnings,
``2`` prints warnings and info. (Default is ``1``)
Attributes:
num_frames_total (int):
Total number of frames across all videos.
num_videos (int):
Number of videos being read.
metadata (Dict[str, list]):
Per-video metadata (path, length, fps, frame size, channels).
frame_rate (float):
Median frame rate across all videos.
frame_height_width (List[int]):
Common ``[H, W]`` shared by all videos.
num_channels (int):
Common number of channels shared by all videos.
slots (List[List[Optional[torch.Tensor]]]):
Buffer slots holding chunks of decoded frames.
boundaries (List[List[Tuple[int, int]]]):
Inclusive ``(start, end)`` frame index for each slot.
lookup (pandas.DataFrame):
Lookup table mapping continuous frame index to ``(video, slot)``.
"""
def __init__(
    self,
    video_readers: list=None,
    paths_videos: list=None,
    buffer_size: int=1000,
    prefetch: int=2,
    posthold: int=1,
    method_getitem: str='continuous',
    starting_seek_position: int=0,
    backend: str='torchcodec',
    device: str='cpu',
    decord_backend: str='torch',
    decord_ctx=None,
    verbose: int=1,
):
    """
    Initializes the reader, opens the videos, and prepares buffer slots.

    Raises:
        AssertionError:
            If neither ``video_readers`` nor ``paths_videos`` is given, if
            ``paths_videos`` is not a list of ``str``, if ``video_readers``
            is not a list, if ``method_getitem`` is invalid, or if
            ``backend='decord'`` is requested with paths while decord is
            not installed.
        ImportError:
            If ``backend='torchcodec'`` and torchcodec (or its FFmpeg
            libraries) cannot be loaded.
        ValueError:
            If ``backend`` is not ``'torchcodec'`` or ``'decord'`` when
            opening from paths.
    """
    import pandas as pd

    self._verbose = verbose
    self.buffer_size = buffer_size
    self.prefetch = prefetch
    self.posthold = posthold
    self._backend = backend
    self._device = device
    self._decord_backend = decord_backend
    ## A decord context only makes sense when decord itself is importable
    self._decord_ctx = (decord.cpu(0) if decord_ctx is None else decord_ctx) if decord is not None else None

    ## Check inputs
    _single_reader_types = (TorchCodecVideoReader,)
    if decord is not None:
        _single_reader_types = (decord.VideoReader, TorchCodecVideoReader)
    if isinstance(video_readers, _single_reader_types):
        video_readers = [video_readers]
    if isinstance(paths_videos, str):
        paths_videos = [paths_videos]
    assert (video_readers is not None) or (paths_videos is not None), "Must provide either video_readers or paths_videos"

    ## If both video_readers and paths_videos are provided, use the video_readers and print a warning
    if (video_readers is not None) and (paths_videos is not None):
        ## Warnings are documented to be silenced at verbose=0
        if self._verbose > 0:
            print(f"FR WARNING: Both video_readers and paths_videos were provided. Using video_readers and ignoring path_videos.")
        paths_videos = None

    ## If paths are specified, create video reader objects
    if paths_videos is not None:
        if self._verbose > 1:
            print(f"FR: Loading video reader objects (backend='{self._backend}')...")
        assert isinstance(paths_videos, list), "paths_videos must be list of str"
        assert all([isinstance(p, str) for p in paths_videos]), "paths_videos must be list of str"
        if self._backend == 'torchcodec':
            if self._verbose > 1:
                print(f"FR: Video decode device: {self._device}")
            try:
                video_readers = [TorchCodecVideoReader(path_video, device=self._device) for path_video in tqdm(paths_videos, disable=(self._verbose < 2))]
            except (ImportError, ModuleNotFoundError) as e:
                raise ImportError(_torchcodec_unavailable_message()) from e
            except (OSError, RuntimeError) as e:
                if _is_torchcodec_load_error(e):
                    raise ImportError(_torchcodec_unavailable_message()) from e
                raise
        elif self._backend == 'decord':
            assert decord is not None, (
                "FR ERROR: decord is not installed (this is unexpected — "
                "decord2/eva_decord are face-rhythm required dependencies). "
                "Install with: pip install decord2 (or eva_decord on Windows)"
            )
            video_readers = [VideoReaderWrapper(path_video, ctx=self._decord_ctx) for path_video in tqdm(paths_videos, disable=(self._verbose < 2))]
        else:
            raise ValueError(f"FR ERROR: Unknown video backend '{self._backend}'. Use 'torchcodec' or 'decord'.")
        self.paths_videos = paths_videos
    else:
        if self._verbose > 1:
            print(f"FR: Using provided video reader objects...")
        assert isinstance(video_readers, list), "video_readers must be a list of video reader objects"
        ## Provided readers must expose a `.path` attribute
        self.paths_videos = [v.path for v in video_readers]

    ## Assert that method_getitem is valid
    assert method_getitem in ['continuous', 'by_video'], "method_getitem must be 'continuous' or 'by_video'"

    ## Set decord bridge if using decord backend.
    ## Guarded so that pre-built readers combined with backend='decord'
    ## do not dereference a missing decord module.
    if (self._backend == 'decord') and (decord is not None):
        decord.bridge.set_bridge(self._decord_backend)

    self.paths_videos = [str(path) for path in self.paths_videos]  ## ensure paths are str
    self.video_readers = video_readers
    self._cumulative_frame_end = np.cumsum([len(video_reader) for video_reader in self.video_readers])
    self._cumulative_frame_start = np.concatenate([[0], self._cumulative_frame_end[:-1]])
    ## Provisional value; replaced by the `_get_metadata` result just below
    self.num_frames_total = self._cumulative_frame_end[-1]
    self.method_getitem = method_getitem

    ## Get metadata about videos: lengths, fps, frame size, etc.
    self.metadata, self.num_frames_total, self.frame_rate, self.frame_height_width, self.num_channels = self._get_metadata(self.video_readers)
    ## Get number of videos
    self.num_videos = len(self.video_readers)

    ## Set iterator starting frame
    if self._verbose > 1:
        print(f"FR: Setting iterator starting frame to {starting_seek_position}")
    self.set_iterator_frame_idx(starting_seek_position)

    ## Initialize the buffer
    ### Make a list containing a slot for each buffer chunk
    self.slots = [[None] * np.ceil(len(d)/self.buffer_size).astype(int) for d in self.video_readers]
    ### Make a list containing the bounding indices for each buffer video chunk. Upper bound should be min(buffer_size, num_frames)
    self.boundaries = [[(i*self.buffer_size, min((i+1)*self.buffer_size, len(d))-1) for i in range(len(s))] for d, s in zip(self.video_readers, self.slots)]
    ### Make a lookup table for the buffer slot that contains each frame
    self.lookup = {
        'video': np.concatenate([np.array([ii]*len(s), dtype=int) for ii, s in enumerate(self.slots)]).tolist(),
        'slot': np.concatenate([np.arange(len(s)) for s in self.slots]).tolist(),
        'start_frame': np.concatenate([np.array([s[0] for s in b]) for b in self.boundaries]).astype(int).tolist(),
        'end_frame': np.concatenate([np.array([s[1] for s in b]) for b in self.boundaries]).astype(int).tolist(),
    }
    ## Offset each slot's start by its video's first continuous frame index
    self.lookup['start_frame_continuous'] = (np.array(self.lookup['start_frame']) + np.array(self._cumulative_frame_start[self.lookup['video']])).tolist()
    self.lookup = pd.DataFrame(self.lookup)
    self._start_frame_continuous = self.lookup['start_frame_continuous'].values

    ## Make a list for which slots are loaded or loading
    self.loading = []
    self.loaded = []
def _get_metadata(self, video_readers):
    """
    Collects per-video metadata: length, fps, frame size, channels.

    The first frame of every video is decoded to determine the frame
    shape. The frame count of each video is checked *before* that decode
    so that an empty video yields the explicit "at least one frame" error
    instead of whatever error the decoder raises for index 0.

    Args:
        video_readers (List[object]):
            List of decoder objects (``decord.VideoReader`` or
            ``TorchCodecVideoReader``).

    Returns:
        (tuple): tuple containing:
            metadata (Dict[str, list]):
                Per-video metadata with keys ``'paths_videos'``,
                ``'num_frames'``, ``'frame_rate'``,
                ``'frame_height_width'``, ``'num_channels'``.
            num_frames_total (int):
                Total number of frames across all videos.
            frame_rate (float):
                Median frame rate across videos.
            frame_height_width (List[int]):
                Common ``[height, width]`` of all videos.
            num_channels (int):
                Common channel count across videos.

    Raises:
        AssertionError:
            If any video has no frames, or if the videos differ in frame
            shape or channel count.
    """
    ## make video metadata dataframe
    if self._verbose > 1:
        print("FR: Collecting video metadata...")
    metadata = {"paths_videos": self.paths_videos}
    num_frames, frame_rate, frame_height_width, num_channels = [], [], [], []
    for v in tqdm(video_readers, disable=(self._verbose < 2)):
        n_frames = int(len(v))
        ## Assert that all videos must have at least one frame (checked before decoding frame 0)
        assert n_frames > 0, "FR ERROR: All videos must have at least one frame"
        num_frames.append(n_frames)
        frame_rate.append(float(v.get_avg_fps()))
        frame_tmp = v[0]
        frame_height_width.append([int(n) for n in frame_tmp.shape[:2]])
        num_channels.append(int(frame_tmp.shape[2]))
    metadata["num_frames"] = num_frames
    metadata["frame_rate"] = frame_rate
    metadata["frame_height_width"] = frame_height_width
    metadata["num_channels"] = num_channels

    ## Assert that all videos must have the same shape
    assert all([n == metadata["frame_height_width"][0] for n in metadata["frame_height_width"]]), "FR ERROR: All videos must have the same shape"
    ## Assert that all videos must have the same number of channels
    assert all([n == metadata["num_channels"][0] for n in metadata["num_channels"]]), "FR ERROR: All videos must have the same number of channels"

    ## get frame rate
    frame_rates = metadata["frame_rate"]
    ## warn if any video's frame rate is very different from others
    max_diff = float((np.max(frame_rates) - np.min(frame_rates)) / np.mean(frame_rates))
    if (max_diff > 0.1) and (self._verbose > 0):
        print(f"FR WARNING: max frame rate difference is large: {max_diff*100:.2f}%")
    frame_rate = float(np.median(frame_rates))

    num_frames_total = int(np.sum(metadata["num_frames"]))
    frame_height_width = metadata["frame_height_width"][0]
    num_channels = metadata["num_channels"][0]

    return metadata, num_frames_total, frame_rate, frame_height_width, num_channels
def _load_slots(self, idx_slots: list, wait_for_load: Union[bool, list]=False):
    """
    Starts background threads that fill the requested buffer slots.

    Slots that are already loaded are left alone. Slots that are already
    loading are not restarted, but can still be waited on. Each newly
    started thread receives the previously started thread (from this call)
    so that loads run one after another.

    Args:
        idx_slots (list):
            List of ``(idx_video, idx_buffer)`` tuples identifying the
            slots to load. A single non-list value is wrapped in a list.
        wait_for_load (Union[bool, List[bool]]):
            If ``True``, blocks until each slot is loaded before
            returning. If ``False``, returns immediately. If a list, each
            entry corresponds to the slot at the same position in
            ``idx_slots``. (Default is ``False``)
    """
    ## Normalize both arguments to lists of equal meaning
    if not isinstance(idx_slots, list):
        idx_slots = [idx_slots]
    if not isinstance(wait_for_load, list):
        wait_for_load = [wait_for_load] * len(idx_slots)

    chatty = self._verbose > 1
    if chatty:
        print(f"FR: Loading slots {idx_slots} in the background. Waiting: {wait_for_load}")
        print(f"FR: Loaded: {self.loaded}, Loading: {self.loading}")

    ## Most recently started thread; passed on so loads are chained serially
    thread = None
    for slot, block_until_done in zip(idx_slots, wait_for_load):
        if chatty and (slot in self.loaded):
            print(f"FR: Slot {slot} already loaded")
        if chatty and (slot in self.loading):
            print(f"FR: Slot {slot} already loading")

        if (slot not in self.loading) and (slot not in self.loaded):
            ## Fresh slot: register it and start a loader thread
            if chatty:
                print(f"FR: Loading slot {slot}")
            self.loading.append(slot)
            thread = threading.Thread(target=self._load_slot, args=(slot, thread))
            thread.start()
            if block_until_done:
                if chatty:
                    print(f"FR: Waiting for slot {slot} to load")
                thread.join()
                if chatty:
                    print(f"FR: Slot {slot} loaded")
        elif slot in self.loading:
            ## Another thread owns this slot: poll until it reports done
            if block_until_done:
                if chatty:
                    print(f"FR: Waiting for slot {slot} to load")
                while slot in self.loading:
                    time.sleep(0.01)
                if chatty:
                    print(f"FR: Slot {slot} loaded")
def _load_slot(self, idx_slot: tuple, blocking_thread: threading.Thread=None):
    """
    Loads a single buffer slot, optionally after another thread finishes.

    Populates ``self.slots[idx_slot[0]][idx_slot[1]]`` with decoded
    frames. If ``blocking_thread`` is provided, this call ``join()`` s
    that thread before reading, ensuring serial decode order.

    A failed read is retried indefinitely with a 1 s pause between
    attempts; the failure messages (including the exception text) are
    printed only when ``verbose > 0``.

    Args:
        idx_slot (Tuple[int, int]):
            ``(idx_video, idx_buffer)`` identifying the slot to load.
        blocking_thread (Optional[threading.Thread]):
            Thread to wait on before loading. (Default is ``None``)
    """
    ## Set backend of decord to PyTorch (only needed for decord, and only
    ## possible when decord is importable)
    if (self._backend == 'decord') and (decord is not None):
        decord.bridge.set_bridge(self._decord_backend)

    ## Wait for the previous slot to finish loading
    if blocking_thread is not None:
        blocking_thread.join()

    ## Load the slot
    idx_video, idx_buffer = idx_slot
    idx_frame_start, idx_frame_end = self.boundaries[idx_video][idx_buffer]
    while True:
        try:
            ## Boundaries are inclusive, hence the +1 on the slice end
            self.slots[idx_video][idx_buffer] = self.video_readers[idx_video][idx_frame_start:idx_frame_end+1]
            break
        except Exception as e:
            ## Deliberate best-effort retry; all reporting honors verbose=0
            if self._verbose > 0:
                print(f"FR WARNING: Failed to load slot {idx_slot}. Likely causes are: 1) File is partially corrupted, 2) You are trying to go back to a file that was recently removed from a slot.")
                print(f"    Sleeping for 1s, then will try loading again. Decord error below:")
                print(e)
            time.sleep(1)

    ## Mark the slot as loaded before removing it from the loading list, so
    ## the slot is never absent from both lists at once
    self.loaded.append(idx_slot)
    self.loading.remove(idx_slot)
def _delete_slots(self, idx_slots: list):
    """
    Releases buffer slots by replacing their contents with ``None``.

    Only slots currently listed in ``self.loaded`` are touched; each one is
    also removed from ``self.loaded``.

    Args:
        idx_slots (List[Tuple[int, int]]):
            ``(idx_video, idx_buffer)`` tuples identifying the slots to
            delete.
    """
    chatty = self._verbose > 1
    if chatty:
        print(f"FR: Deleting slots {idx_slots}")

    ## Snapshot the candidates first: `idx_slots` may be `self.loaded`
    ## itself, which is mutated inside the loop below
    candidates = [slot for slot in idx_slots if slot in self.loaded]
    for slot in candidates:
        ## A duplicate entry has already been handled on its first visit
        if slot not in self.loaded:
            continue
        self.slots[slot[0]][slot[1]] = None
        self.loaded.remove(slot)
        if chatty:
            print(f"FR: Deleted slot {slot}")
[docs]
def delete_all_slots(self):
    """Releases every loaded slot by handing ``self.loaded`` to ``_delete_slots``."""
    if self._verbose > 1:
        print(f"FR: Deleting all slots")
    self._delete_slots(self.loaded)
[docs]
def wait_for_loading(self):
    """Polls every 10 ms until ``self.loading`` is empty, i.e. no slot is still being loaded."""
    if self._verbose > 1:
        print(f"FR: Waiting for all slots to load")
    while self.loading:
        time.sleep(0.01)
[docs]
def get_frames_from_single_video_index(self, idx: tuple):
"""
Returns frames from a single video by ``(video, frame)`` index.
If ``idx`` is an ``int`` or ``slice`` it is interpreted as a video
index and a new ``BufferedVideoReader`` is constructed over the
selected videos.
Args:
idx (Union[int, slice, Tuple[int, Union[int, slice]]]):
Either ``(idx_video, idx_frames)`` to read frames from one
video, or an ``int`` / ``slice`` to spawn a reader over a
subset of videos.
Returns:
(Union[torch.Tensor, BufferedVideoReader]):
frames (Union[torch.Tensor, BufferedVideoReader]):
Decoded frames with shape *(num_frames, H, W, C)* when
``idx`` is a tuple, or a new reader when ``idx`` selects
videos.
"""
## if idx is an int or slice, use idx to make a new BufferedVideoReader of just those videos
idx = slice(idx, idx+1) if isinstance(idx, int) else idx
if isinstance(idx, slice):
## convert to a slice
print(f"FR: Returning new buffered video reader(s). Videos={idx.start} to {idx.stop}.") if self._verbose > 1 else None
return BufferedVideoReader(
video_readers=self.video_readers[idx],
buffer_size=self.buffer_size,
prefetch=self.prefetch,
method_getitem='continuous',
starting_seek_position=0,
backend=self._backend,
device=self._device,
decord_backend=self._decord_backend,
decord_ctx=self._decord_ctx,
verbose=self._verbose,
)
print(f"FR: Getting item {idx}") if self._verbose > 1 else None
## Assert that idx is a tuple of (int, int) or (int, slice)
assert isinstance(idx, tuple), f"idx must be: int, tuple of (int, int), or (int, slice). Got {type(idx)}"
assert len(idx) == 2, f"idx must be: int, tuple of (int, int), or (int, slice). Got {len(idx)} elements"
assert isinstance(idx[0], int), f"idx[0] must be an int. Got {type(idx[0])}"
assert isinstance(idx[1], int) or isinstance(idx[1], slice), f"idx[1] must be an int or a slice. Got {type(idx[1])}"
## Get the index of the video and the slice of frames
idx_video, idx_frames = idx
## If idx_frames is a single integer, convert it to a slice
idx_frames = slice(idx_frames, idx_frames+1) if isinstance(idx_frames, int) else idx_frames
## Bound the range of the slice
idx_frames = slice(max(idx_frames.start, 0), min(idx_frames.stop, len(self.video_readers[idx_video])))
## Assert that slice is not empty
assert idx_frames.start < idx_frames.stop, f"Slice is empty: idx:{idx}"
## Get the start and end indices for the slice of frames
idx_frame_start = idx_frames.start if idx_frames.start is not None else 0
idx_frame_end = idx_frames.stop if idx_frames.stop is not None else len(self.video_readers[idx_video])
idx_frame_step = idx_frames.step if idx_frames.step is not None else 1
## Get the indices of the slots that contain the frames
idx_slots = [(idx_video, i) for i in range(idx_frame_start // self.buffer_size, ((idx_frame_end-1) // self.buffer_size)+1)]
print(f"FR: Slots to load: {idx_slots}") if self._verbose > 1 else None
## Load the prefetch slots
idx_slot_lookuptable = np.where((self.lookup['video']==idx_slots[-1][0]) * (self.lookup['slot']==idx_slots[-1][1]))[0][0]
if self.prefetch > 0:
idx_slots_prefetch = [(self.lookup['video'][ii], self.lookup['slot'][ii]) for ii in range(idx_slot_lookuptable+1, idx_slot_lookuptable+self.prefetch+1) if ii < len(self.lookup)]
else:
idx_slots_prefetch = []
## Load the slots
self._load_slots(idx_slots + idx_slots_prefetch, wait_for_load=[True]*len(idx_slots) + [False]*len(idx_slots_prefetch))
## Delete the slots that are no longer needed.
### Find slots before the posthold to delete
idx_slots_delete = [(self.lookup['video'][ii], self.lookup['slot'][ii]) for ii in range(idx_slot_lookuptable-self.posthold) if ii >= 0]
### Delete all previous slots
self._delete_slots(idx_slots_delete)
# ### All slots from old videos should be deleted.
# self._delete_slots([idx_slot for idx_slot in self.loaded if idx_slot[0] < idx_video])
# ### All slots from previous buffers should be deleted.
# self._delete_slots([idx_slot for idx_slot in self.loaded if idx_slot[0] == idx_video and idx_slot[1] < idx_frame_start // self.buffer_size])
## Get the frames from the slots
idx_frames_slots = [slice(max(idx_frame_start - self.boundaries[idx_slot[0]][idx_slot[1]][0], 0), min(idx_frame_end - self.boundaries[idx_slot[0]][idx_slot[1]][0], self.buffer_size), idx_frame_step) for idx_slot in idx_slots]
print(f"FR: Frames within slots: {idx_frames_slots}") if self._verbose > 1 else None
## Get the frames. Then concatenate them along the first dimension using torch.cat
### Skip the concatenation if there is only one slot
if len(idx_slots) == 1:
frames = self.slots[idx_slots[0][0]][idx_slots[0][1]][idx_frames_slots[0]]
else:
print(f"FR: Warning. Slicing across multiple slots is SLOW. Consider increasing buffer size or adjusting batching method.") if self._verbose > 1 else None
frames = torch.cat([self.slots[idx_slot[0]][idx_slot[1]][idx_frames_slot] for idx_slot, idx_frames_slot in zip(idx_slots, idx_frames_slots)], dim=0)
# ## Squeeze if there is only one frame
# frames = frames.squeeze(0) if frames.shape[0] == 1 else frames
return frames
[docs]
def get_frames_from_continuous_index(self, idx):
"""
Returns frames addressed by a continuous (concatenated) frame index.
The videos are treated as one long sequence of frames; ``idx`` is the
index of the frames within this sequence.
Args:
idx (Union[int, slice]):
Frame index. If an ``int``, a single frame is returned. If a
``slice``, the corresponding batch of frames is returned.
Returns:
(torch.Tensor):
frames (torch.Tensor):
Stacked frames. shape: *(num_frames, height, width, num_channels)*.
"""
## Assert that idx is an int or a slice
assert isinstance(idx, (int, np.intp)) or isinstance(idx, slice), f"idx must be an int or a slice. Got {type(idx)}"
idx = int(idx) if isinstance(idx, (np.intp)) else idx
## If idx is a single integer, convert it to a slice
idx = slice(idx, idx+1) if isinstance(idx, int) else idx
## Assert that the slice is not empty
assert idx.start < idx.stop, f"Slice is empty: idx:{idx}"
## Assert that the slice is not out of bounds
assert idx.stop <= self.num_frames_total, f"Slice is out of bounds: idx:{idx}"
## Find the video and frame indices
idx_video_start = np.searchsorted(self._cumulative_frame_start, idx.start, side='right') - 1
idx_video_end = np.searchsorted(self._cumulative_frame_end, idx.stop, side='left')
## Get the frames using the __getitem__ method
### This needs to be done one video at a time
frames = []
for idx_video in range(idx_video_start, idx_video_end+1):
## Get the start and end indices for the slice of frames
idx_frame_start = idx.start - self._cumulative_frame_start[idx_video] if idx_video == idx_video_start else 0
idx_frame_end = idx.stop - self._cumulative_frame_start[idx_video] if idx_video == idx_video_end else len(self.video_readers[idx_video])
## Get the frames
print(f"FR: Getting frames from video {idx_video} from {idx_frame_start} to {idx_frame_end}") if self._verbose > 1 else None
frames.append(self.get_frames_from_single_video_index((idx_video, slice(idx_frame_start, idx_frame_end, idx.step))))
## Concatenate the frames if there are multiple videos
frames = torch.cat(frames, dim=0) if len(frames) > 1 else frames[0]
return frames
[docs]
def set_iterator_frame_idx(self, idx):
"""
Sets the starting frame for the iterator.
Args:
idx (int):
Frame index from which the iterator should start. Must be in
``'continuous'`` format, i.e. the index of the frame within
the concatenated sequence of all videos.
"""
self._iterator_frame = idx
def __getitem__(self, idx):
"""Dispatches to ``get_frames_from_single_video_index`` or ``get_frames_from_continuous_index`` based on ``method_getitem``."""
if self.method_getitem == 'by_video':
return self.get_frames_from_single_video_index(idx)
elif self.method_getitem == 'continuous':
return self.get_frames_from_continuous_index(idx)
else:
raise ValueError(f"Invalid method_getitem: {self.method_getitem}")
def __len__(self):
"""Returns the number of videos (``'by_video'``) or the total frame count (``'continuous'``)."""
if self.method_getitem == 'by_video':
return len(self.video_readers)
elif self.method_getitem == 'continuous':
return self.num_frames_total
def __repr__(self):
"""Returns a debug string summarizing buffer size, video count, prefetch state, and verbosity."""
if self.method_getitem == 'by_video':
return f"BufferedVideoReader(buffer_size={self.buffer_size}, num_videos={self.num_videos}, method_getitem='{self.method_getitem}', loaded={self.loaded}, prefetch={self.prefetch}, loading={self.loading}, verbose={self._verbose})"
elif self.method_getitem == 'continuous':
return f"BufferedVideoReader(buffer_size={self.buffer_size}, num_videos={self.num_videos}, total_frames={self.num_frames_total}, method_getitem='{self.method_getitem}', iterator_frame={self._iterator_frame}, prefetch={self.prefetch}, loaded={self.loaded}, loading={self.loading}, verbose={self._verbose})"
def __iter__(self):
"""
Iterates over the contents according to ``method_getitem``.
When ``method_getitem == 'by_video'``, yields a fresh
``BufferedVideoReader`` for each underlying video. When
``method_getitem == 'continuous'``, yields single frames pulled
directly from buffer slots, calling
``get_frames_from_continuous_index`` whenever a new slot must be
loaded.
Returns:
(Iterator):
iterator (Iterator):
Iterator over readers (``'by_video'``) or frames
(``'continuous'``).
"""
if self.method_getitem == 'by_video':
return iter([BufferedVideoReader(
video_readers=[self.video_readers[idx]],
buffer_size=self.buffer_size,
prefetch=self.prefetch,
method_getitem='continuous',
starting_seek_position=0,
backend=self._backend,
device=self._device,
decord_backend=self._decord_backend,
decord_ctx=self._decord_ctx,
verbose=self._verbose,
) for idx in range(len(self.video_readers))])
elif self.method_getitem == 'continuous':
## Initialise the buffers by loading the first frame in the sequence
self.get_frames_from_continuous_index(self._iterator_frame)
## Make lazy iterator over all frames
def lazy_iterator():
while self._iterator_frame < self.num_frames_total:
## Find slot for current frame idx
idx_video = np.searchsorted(self._cumulative_frame_start, self._iterator_frame, side='right') - 1
idx_slot_in_video = (self._iterator_frame - self._cumulative_frame_start[idx_video]) // self.buffer_size
idx_frame = self._iterator_frame - self._cumulative_frame_start[idx_video]
## If the frame is at the beginning of a slot, then use get_frames_from_single_video_index otherwise just grab directly from the slot
if (self._iterator_frame in self._start_frame_continuous):
yield self.get_frames_from_continuous_index(self._iterator_frame)[0]
else:
## Get the frame directly from the slot
yield self.slots[idx_video][idx_slot_in_video][idx_frame%self.buffer_size]
self._iterator_frame += 1
return iter(lazy_iterator())
[docs]
def save_gif(
    array,
    path,
    frameRate=5.0,
    loop=0,
    backend='PIL',
    kwargs_backend=None,
):
    """
    Saves an array of images as an animated GIF. RH 2023

    Args:
        array (Union[np.ndarray, list]):
            3D (grayscale) or 4D (color) array of images. If dtype is
            floating, values are interpreted in *[0, 1]*; if integer, in
            *[0, 255]*.
        path (str):
            Output path for the GIF.
        frameRate (float):
            Frame rate of the GIF in frames per second. (Default is ``5.0``)
        loop (int):
            Number of loops. ``0`` loops forever, ``1`` plays once, ``2``
            plays twice, etc. (Default is ``0``)
        backend (str):
            GIF writer backend. One of \n
                * ``'imageio'``
                * ``'PIL'`` \n
            (Default is ``'PIL'``)
        kwargs_backend (Optional[dict]):
            Extra keyword arguments forwarded to the chosen backend. The
            dictionary is copied and never modified. ``None`` means no
            extra arguments. (Default is ``None``)

    Raises:
        Exception:
            If ``backend`` is not one of the supported values.
    """
    array = np.stack(array, axis=0) if isinstance(array, list) else array
    array = grayscale_to_rgb(array) if array.ndim == 3 else array
    if np.issubdtype(array.dtype, np.floating):
        array = (array*255).astype('uint8')

    ## Work on a private copy. Updating the caller's dict (or a shared default
    ##  dict) in place would leak the 'loop' setting into later calls.
    kwargs = {} if kwargs_backend is None else dict(kwargs_backend)
    ## 'loop' is only forwarded when it differs from 1 (play once)
    if loop != 1:
        kwargs['loop'] = loop

    if backend == 'imageio':
        import imageio
        imageio.mimsave(
            path,
            array,
            format='GIF',
            duration=1000/frameRate,
            **kwargs,
        )
    elif backend == 'PIL':
        from PIL import Image
        frames = [Image.fromarray(array[i_frame]) for i_frame in range(array.shape[0])]
        frames[0].save(
            path,
            format='GIF',
            append_images=frames[1:],
            save_all=True,
            duration=1000/frameRate,
            **kwargs,
        )
    else:
        raise Exception(f'Unsupported backend {backend}')
[docs]
def grayscale_to_rgb(array):
    """
    Converts a grayscale image or movie to RGB by repeating the channel. RH 2023

    Args:
        array (Union[np.ndarray, torch.Tensor, list, tuple]):
            Grayscale image(s). Lists / tuples of arrays or tensors are
            stacked along a new leading axis first.

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            rgb (Union[np.ndarray, torch.Tensor]):
                Same backend as the input with an extra trailing channel
                dimension of size 3.

    Raises:
        Exception:
            If the input (or the first element of a list / tuple) is
            neither a ``np.ndarray`` nor a ``torch.Tensor``. Without this
            check the function would silently return ``None``.
    """
    if isinstance(array, (list, tuple)):
        ## The backend is chosen from the first element
        if isinstance(array[0], np.ndarray):
            array = np.stack(array, axis=0)
        elif isinstance(array[0], torch.Tensor):
            array = torch.stack(array, dim=0)
        else:
            raise Exception(f'Failed to convert list of type {type(array[0])} to array')

    if isinstance(array, np.ndarray):
        return np.stack([array, array, array], axis=-1)
    elif isinstance(array, torch.Tensor):
        return torch.stack([array, array, array], dim=-1)
    else:
        raise Exception(f'Unsupported input type {type(array)}. Expected np.ndarray, torch.Tensor, or a list of those.')
########################################################################################################################################
############################################################# CONVOLUTION ##############################################################
########################################################################################################################################
[docs]
class Toeplitz_convolution2d:
    """
    Convolves a 2D array with a 2D kernel via Toeplitz matrix multiplication.
    RH 2022

    Allows **sparse** ``x`` inputs (``k`` must remain dense). Ideal when
    ``x`` is very sparse (density < 0.01), ``x`` is small (shape <
    *(1000, 1000)*), ``k`` is small (shape < *(100, 100)*), and the batch
    size is large (e.g. 1000+). Generally faster than
    ``scipy.signal.convolve2d`` when convolving many arrays with the same
    kernel. Memory footprint stays low because the Toeplitz matrix is held
    as a sparse matrix.

    See https://stackoverflow.com/a/51865516 and
    https://github.com/alisaaalehi/convolution_as_multiplication for an
    illustration. See
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.convolution_matrix.html
    for the 1D version, and
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.linalg.matmul_toeplitz.html
    for potential speedups.

    Args:
        x_shape (Tuple[int, int]):
            Shape of the 2D array to be convolved.
        k (np.ndarray):
            2D kernel to convolve with.
        mode (str):
            Convolution mode. One of \n
                * ``'full'``
                * ``'same'``
                * ``'valid'`` \n
            See ``scipy.signal.convolve2d`` for details. (Default is ``'same'``)
        dtype (Optional[np.dtype]):
            Data type for the Toeplitz matrix. Ideally matches the dtype of
            the input array. If ``None``, the dtype of ``k`` is used.
            (Default is ``None``)

    Attributes:
        k (np.ndarray):
            Flipped copy of the kernel used internally.
        mode (str):
            Convolution mode set in ``__init__``.
        x_shape (Tuple[int, int]):
            Stored input shape.
        dtype (np.dtype):
            Data type of the Toeplitz matrix.
        so (Tuple[int, int]):
            Output array size before cropping (the ``'full'`` size).
        dt (scipy.sparse.csr_matrix):
            The double-block Toeplitz matrix in sparse CSR form.

    Example:
        .. highlight:: python
        .. code-block:: python

            conv = Toeplitz_convolution2d(x_shape=x.shape, k=kernel, mode='same')
            y = conv(x)
    """
    def __init__(
        self,
        x_shape,
        k,
        mode='same',
        dtype=None,
    ):
        """Builds and caches the double-block Toeplitz matrix for the given shape and kernel."""
        self.k = k = np.flipud(k.copy())
        self.mode = mode
        self.x_shape = x_shape
        self.dtype = k.dtype if dtype is None else dtype

        if mode == 'valid':
            assert x_shape[0] >= k.shape[0] and x_shape[1] >= k.shape[1], "x must be larger than k in both dimensions for mode='valid'"

        ## 'size out' is the size of the un-cropped ('full') output array
        self.so = so = ((k.shape[0] + x_shape[0] - 1), (k.shape[1] + x_shape[1] - 1))

        ## One Toeplitz matrix per kernel row. ``k`` was flipped above, so
        ##  ``k[::-1]`` walks the rows in their original order. Each matrix
        ##  performs the 1D 'full' convolution of an input row with that kernel row.
        t = [scipy.sparse.diags(
            diagonals=np.ones((k.shape[1], x_shape[1]), dtype=self.dtype) * k_i[::-1][:, None],
            offsets=np.arange(-k.shape[1]+1, 1),
            shape=(so[1], x_shape[1]),
            dtype=self.dtype,
        ) for k_i in k[::-1]]
        ## Stack the row blocks and pad with empty blocks so the column block spans all output rows
        tc = scipy.sparse.vstack(t + [scipy.sparse.dia_matrix((t[0].shape), dtype=self.dtype)]*(x_shape[0]-1))

        ## Double-block Toeplitz matrix: the column block for input row ``ii``
        ##  is the stacked block shifted down by ``ii`` output rows (``ii * so[1]`` matrix rows).
        self.dt = scipy.sparse.hstack([self._roll_sparse(
            x=tc,
            shift=ii*so[1],
        ) for ii in range(x_shape[0])]).tocsr()

    def __call__(
        self,
        x,
        batching=True,
        mode=None,
    ):
        """
        Convolves an input array (or batch of arrays) with the stored kernel.

        Args:
            x (Union[np.ndarray, scipy.sparse.csc_matrix, scipy.sparse.csr_matrix]):
                Input array(s) to convolve. When ``batching`` is ``False``,
                a single 2D array of shape *(self.x_shape[0],
                self.x_shape[1])*. When ``batching`` is ``True``, multiple
                2D arrays that have been flattened into row vectors with
                shape *(n_arrays, self.x_shape[0] * self.x_shape[1])*.
            batching (bool):
                If ``False``, ``x`` is a single 2D array. If ``True``, ``x``
                is a 2D array whose rows are flattened 2D inputs. (Default
                is ``True``)
            mode (Optional[str]):
                Convolution mode. One of \n
                    * ``'full'``
                    * ``'same'``
                    * ``'valid'`` \n
                Overrides the mode set in ``__init__`` when provided.
                (Default is ``None``)

        Returns:
            (Union[np.ndarray, scipy.sparse.csr_matrix]):
                out (Union[np.ndarray, scipy.sparse.csr_matrix]):
                    When ``batching`` is ``True``, multiple convolved 2D
                    arrays flattened into row vectors with shape
                    *(n_arrays, height * width)*. When ``batching`` is
                    ``False``, a single convolved 2D array with shape
                    *(height, width)*.

        Raises:
            ValueError:
                If the effective mode is not ``'full'``, ``'same'`` or
                ``'valid'``.
        """
        if mode is None:
            mode = self.mode  ## use the mode that was set in the init if not specified
        if mode not in ('full', 'same', 'valid'):
            raise ValueError(f"mode must be 'full', 'same', or 'valid'. Got {mode!r}")

        issparse = scipy.sparse.issparse(x)

        if batching:
            x_v = x.T  ## transpose into column vectors
        else:
            x_v = x.reshape(-1, 1)  ## reshape 2D array into a column vector

        if issparse:
            x_v = x_v.tocsc()

        out_v = self.dt @ x_v  ## if sparse, then 'out_v' will be a sparse matrix

        ## Crop window (top, bottom, left, right) within the 'full' output
        if mode == 'full':
            p_t = 0
            p_b = self.so[0]+1
            p_l = 0
            p_r = self.so[1]+1
        elif mode == 'same':
            p_t = (self.k.shape[0]-1)//2
            p_b = -(self.k.shape[0]-1)//2
            p_l = (self.k.shape[1]-1)//2
            p_r = -(self.k.shape[1]-1)//2

            ## A stop of 0 would give an empty slice; use an index past the end instead
            p_b = self.x_shape[0]+1 if p_b==0 else p_b
            p_r = self.x_shape[1]+1 if p_r==0 else p_r
        else:  ## 'valid'
            p_t = (self.k.shape[0]-1)
            p_b = -(self.k.shape[0]-1)
            p_l = (self.k.shape[1]-1)
            p_r = -(self.k.shape[1]-1)

            ## A stop of 0 would give an empty slice; use an index past the end instead
            p_b = self.x_shape[0]+1 if p_b==0 else p_b
            p_r = self.x_shape[1]+1 if p_r==0 else p_r

        if batching:
            ## Select the cropped pixels of every flattened output with one boolean row mask
            idx_crop = np.zeros((self.so), dtype=np.bool_)
            idx_crop[p_t:p_b, p_l:p_r] = True
            idx_crop = idx_crop.reshape(-1)
            out = out_v[idx_crop,:].T
        else:
            if issparse:
                out = out_v.reshape((self.so)).tocsc()[p_t:p_b, p_l:p_r]
            else:
                out = out_v.reshape((self.so))[p_t:p_b, p_l:p_r]  ## reshape back into 2D array and crop
        return out

    def _roll_sparse(
        self,
        x,
        shift,
    ):
        """
        Shifts the row indices of a sparse matrix by ``shift`` (no wrap).

        Args:
            x (scipy.sparse matrix):
                Sparse matrix to shift. It is converted to COO format
                (always as a copy), so the input is never modified.
            shift (int):
                Row offset to add to the row indices.

        Returns:
            (scipy.sparse.coo_matrix):
                out (scipy.sparse.coo_matrix):
                    Copy of ``x`` with ``row`` shifted by ``shift``.
        """
        ## Row indices are only directly editable in COO format
        out = x.tocoo(copy=True)
        out.row += shift
        return out
[docs]
def cosine_kernel_2D(center=(5,5), image_size=(11,11), width=5):
    """
    Generates a 2D radial cosine kernel. RH 2021

    The kernel equals ``1`` at ``center``, follows a raised cosine of the
    Euclidean distance to ``center``, and is ``0`` for every pixel farther
    than ``|width| / 2`` from it.

    Args:
        center (Tuple[int, int]):
            ``(x, y)`` peak position, zero-indexed: ``x`` is the column
            index (axis 1) and ``y`` the row index (axis 0). Values are
            truncated with ``int()``. (Default is ``(5, 5)``)
        image_size (Tuple[int, int]):
            ``(n_rows, n_columns)`` of the output kernel, i.e. the shape of
            the returned array. (Default is ``(11, 11)``)
        width (float):
            Full width of one cycle of the cosine. (Default is ``5``)

    Returns:
        (np.ndarray):
            k_cos (np.ndarray):
                Cosine kernel. shape: *(image_size[0], image_size[1])*.
    """
    x, y = np.meshgrid(range(image_size[1]), range(image_size[0]))  # x: column index, y: row index
    dist = np.sqrt((y - int(center[1])) ** 2 + (x - int(center[0])) ** 2)
    dist_scaled = (dist/(width/2))*np.pi
    ## Clamp everything beyond half a cycle so the kernel stays at zero there
    ##  instead of repeating. The magnitude is compared (not the abs of a boolean).
    dist_scaled[np.abs(dist_scaled) > np.pi] = np.pi
    k_cos = (np.cos(dist_scaled) + 1)/2
    return k_cos
##########################################################################################################################################
############################################################# MATH FUNCTIONS #############################################################
##########################################################################################################################################
[docs]
def bounded_logspace(start, stop, num,):
    """
    Returns ``num`` values spaced evenly on a log scale from ``start`` to ``stop``, both included. RH 2022

    Args:
        start (float):
            First value of the output.
        stop (float):
            Last value of the output.
        num (int):
            Number of values to generate.

    Returns:
        (np.ndarray):
            output (np.ndarray):
                Logarithmically spaced values bounded by ``start`` and
                ``stop``. shape: *(num,)*.
    """
    ## The base cancels out of the result; it only has to be > 1
    base = 2
    log_base = np.log(base)
    exponent_first = np.log(start)/log_base
    exponent_last = np.log(stop)/log_base
    exponents = np.linspace(exponent_first, exponent_last, num, endpoint=True)
    return np.power(base, exponents)
[docs]
def gaussian(x=None, mu=0, sig=1, plot_pref=False):
    """
    Evaluates a normalized 1D Gaussian function on a grid. RH 2021

    Args:
        x (Optional[np.ndarray]):
            1D array of x positions. If ``None``, a grid of
            ``round(sig * 7)`` points spanning five sigma on each side of
            zero is used. (Default is ``None``)
        mu (float):
            Mean of the Gaussian. (Default is ``0``)
        sig (float):
            Standard deviation of the Gaussian. (Default is ``1``)
        plot_pref (bool):
            If ``True``, plots the Gaussian using matplotlib. matplotlib is
            only imported in that case. (Default is ``False``)

    Returns:
        (np.ndarray):
            gaus (np.ndarray):
                Gaussian evaluated at each value of ``x``.
    """
    if x is None:
        ## The sample count must be an integer, also when ``sig`` is a float
        x = np.linspace(-sig*5, sig*5, int(round(sig*7)), endpoint=True)

    gaus = 1/(np.sqrt(2*np.pi)*sig)*np.exp((-((x-mu)/sig) **2)/2)

    if plot_pref:
        ## Imported lazily so the numerical path does not require matplotlib
        import matplotlib.pyplot as plt
        plt.figure()
        plt.plot(x , gaus)
        plt.xlabel('x')
        plt.title(f'mu={mu}, sigma={sig}')

    return gaus
##########################################################################################################################################
########################################################### SPECTRAL ANALYSIS ############################################################
##########################################################################################################################################
[docs]
def torch_hilbert(x, N=None, dim=0):
    """
    Returns the analytic signal of a real tensor, computed with the FFT. RH 2022

    PyTorch counterpart of ``scipy.signal.hilbert``: the spectrum of ``x``
    is multiplied by a gain that zeroes the negative frequencies and
    doubles the positive ones, then transformed back.

    Args:
        x (torch.Tensor):
            Real-valued signal of arbitrary rank.
        N (Optional[int]):
            Number of Fourier components. If ``None``, uses ``x.shape[dim]``.
            (Default is ``None``)
        dim (int):
            Dimension along which to transform. (Default is ``0``)

    Returns:
        (torch.Tensor):
            xa (torch.Tensor):
                Complex analytic signal. Its length along ``dim`` is the
                number of Fourier components used.
    """
    assert x.is_complex() == False, "x should be real"
    n = N if N is not None else x.shape[dim]
    assert n >= 0, "N must be non-negative"

    spectrum = torch.fft.fft(input=x, n=n, dim=dim)

    ## Per-frequency gain: 1 for DC, 2 for positive frequencies, 0 for negative ones
    gain = torch.zeros(n, dtype=spectrum.dtype, device=spectrum.device)
    gain[0] = 1
    gain[1:(n+1)//2] = 2
    if n % 2 == 0:
        ## An even length has a Nyquist bin, which is kept with gain 1
        gain[n//2] = 1

    ## Shape the gain so it broadcasts along ``dim`` only
    if x.ndim > 1:
        shape_broadcast = [1] * x.ndim
        shape_broadcast[dim] = n
        gain = gain.reshape(shape_broadcast)

    return torch.fft.ifft(spectrum * gain, dim=dim)
[docs]
def make_VQT_filters(
    Fs_sample=1000,
    Q_lowF=3,
    Q_highF=20,
    F_min=10,
    F_max=400,
    n_freq_bins=55,
    win_size=501,
    symmetry='center',
    taper_asymmetric=True,
    plot_pref=False
):
    """
    Builds a bank of complex sinusoid filters for the VQT algorithm. RH 2022

    Setting ``Q_lowF == Q_highF`` produces a Constant-Q Transform (CQT)
    filter set. Differing values vary the Q factor geometrically across
    the frequency range.

    Args:
        Fs_sample (float):
            Sampling frequency of the signal. (Default is ``1000``)
        Q_lowF (float):
            Q factor for the lowest frequency. (Default is ``3``)
        Q_highF (float):
            Q factor for the highest frequency. (Default is ``20``)
        F_min (float):
            Lowest frequency. (Default is ``10``)
        F_max (float):
            Highest frequency (inclusive). (Default is ``400``)
        n_freq_bins (int):
            Number of frequency bins. (Default is ``55``)
        win_size (int):
            Window size in samples. Must be odd. (Default is ``501``)
        symmetry (str):
            Window symmetry. One of \n
                * ``'center'``: symmetric / two-sided window.
                * ``'left'``: one-sided window, only the left half is nonzero.
                * ``'right'``: one-sided window, only the right half is nonzero. \n
            (Default is ``'center'``)
        taper_asymmetric (bool):
            If ``True`` and ``symmetry != 'center'``, the center sample of
            the window is multiplied by 0.5 to taper the discontinuity.
            (Default is ``True``)
        plot_pref (bool):
            If ``True``, plots the filters and windows. (Default is
            ``False``)

    Returns:
        (tuple): tuple containing:
            filts_complex (torch.Tensor):
                Complex sinusoid filters. shape: *(n_freq_bins, win_size)*.
            freqs (np.ndarray):
                Filter center frequencies. shape: *(n_freq_bins,)*.
            wins (torch.Tensor):
                Gaussian window for each filter. shape: *(n_freq_bins, win_size)*.

    Raises:
        ValueError:
            If ``symmetry`` is not ``'center'``, ``'left'`` or ``'right'``.
    """
    assert win_size%2==1, "RH Error: win_size should be an odd integer"

    ## Make frequencies. Use a geometric spacing.
    freqs = np.geomspace(
        start=F_min,
        stop=F_max,
        num=n_freq_bins,
        endpoint=True,
        dtype=np.float32,
    )

    periods = 1 / freqs
    periods_inSamples = Fs_sample * periods

    ## Make sigmas for gaussian windows. Use a geometric spacing.
    sigma_all = np.geomspace(
        start=Q_lowF,
        stop=Q_highF,
        num=n_freq_bins,
        endpoint=True,
        dtype=np.float32,
    )
    sigma_all = sigma_all * periods_inSamples / 4

    ## Make windows
    ### Make windows gaussian
    ### The sample grid must be centered on index win_size//2, the same
    ###  center used by the cosine carrier, the heaviside mask and the taper
    ###  below. ``torch.arange(-win_size//2, win_size//2)`` parses as
    ###  ``(-win_size)//2`` and is off by one sample.
    x_win = torch.arange(win_size) - win_size//2
    wins = torch.stack([gaussian(x_win, 0, sig=sigma) for sigma in sigma_all])
    ### Make windows symmetric or asymmetric
    if symmetry=='center':
        pass
    else:
        ### Mask that keeps the left half of the window, center sample included
        heaviside = (torch.arange(win_size) <= win_size//2).float()
        if symmetry=='left':
            pass
        elif symmetry=='right':
            heaviside = torch.flip(heaviside, dims=[0])
        else:
            raise ValueError("symmetry must be 'center', 'left', or 'right'")
        wins *= heaviside
        ### Taper the center of the window by multiplying center sample of window by 0.5
        if taper_asymmetric:
            wins[:, win_size//2] = wins[:, win_size//2] * 0.5

    ## Windowed cosines, then the analytic (complex) version of each filter
    filts = torch.stack([torch.cos(torch.linspace(-np.pi, np.pi, win_size) * freq * (win_size/Fs_sample)) * win for freq, win in zip(freqs, wins)], dim=0)
    filts_complex = torch_hilbert(filts.T, dim=0).T

    ## Normalize each filter so that its absolute values sum to 1
    filts_complex = filts_complex / torch.sum(torch.abs(filts_complex), dim=1, keepdims=True)

    ## Plot
    if plot_pref:
        import matplotlib.pyplot as plt

        plt.figure()
        plt.plot(freqs)
        plt.xlabel('filter num')
        plt.ylabel('frequency (Hz)')

        plt.figure()
        plt.imshow(wins / torch.max(wins, 1, keepdims=True)[0], aspect='auto')
        plt.title('windows (gaussian)')

        plt.figure()
        plt.plot(sigma_all)
        plt.xlabel('filter num')
        plt.ylabel('window width (sigma of gaussian)')

        plt.figure()
        plt.imshow(torch.real(filts_complex) / torch.max(torch.real(filts_complex), 1, keepdims=True)[0], aspect='auto', cmap='bwr', vmin=-1, vmax=1)
        plt.title('filters (real component)')

        ## Frequency response of every filter
        worN=win_size*4
        filts_freq = np.array([scipy.signal.freqz(
            b=filt,
            fs=Fs_sample,
            worN=worN,
        )[1] for filt in filts_complex])

        filts_freq_xAxis = scipy.signal.freqz(
            b=filts_complex[0],
            worN=worN,
            fs=Fs_sample
        )[0]

        plt.figure()
        plt.plot(filts_freq_xAxis, np.abs(filts_freq.T));
        plt.xscale('log')
        plt.xlabel('frequency (Hz)')
        plt.ylabel('magnitude')

    return filts_complex, freqs, wins
[docs]
class VQT():
"""
Variable Q Transform implemented with PyTorch. RH 2022
Differs from librosa / nnAudio: this implementation does not iterate
lowpass filtering. Instead it convolves a fixed set of complex filters,
optionally returns the envelope via Hilbert transform, and downsamples.
Gradients propagate through the transform, and computation can run on
GPU. ``Q`` is the quality factor, roughly the number of cycles inside
four sigma (95%) of a Gaussian window.
Args:
Fs_sample (float):
Sampling frequency of the signal. (Default is ``1000``)
Q_lowF (float):
Q factor for the lowest frequency. (Default is ``3``)
Q_highF (float):
Q factor for the highest frequency. (Default is ``20``)
F_min (float):
Lowest frequency. (Default is ``10``)
F_max (float):
Highest frequency. (Default is ``400``)
n_freq_bins (int):
Number of frequency bins. (Default is ``55``)
win_size (int):
Window size in samples. Must be odd. (Default is ``501``)
symmetry (str):
Window symmetry passed through to ``make_VQT_filters``. One of \n
* ``'center'``
* ``'left'``
* ``'right'`` \n
(Default is ``'center'``)
taper_asymmetric (bool):
If ``True`` and ``symmetry != 'center'``, the center sample of
the window is multiplied by 0.5. (Default is ``True``)
downsample_factor (int):
Time-downsampling factor. The input is zero-padded to be a
multiple of this value. (Default is ``4``)
padding (str):
Convolution padding. ``'same'`` pads to keep output length
equal to input length; ``'valid'`` does not pad. (Default is
``'valid'``)
DEVICE_compute (str):
Device used for computation. (Default is ``'cpu'``)
DEVICE_return (str):
Device on which results are returned. (Default is ``'cpu'``)
batch_size (int):
Number of signals processed per batch. Reduce when out of
memory. (Default is ``1000``)
return_complex (bool):
If ``True``, returns the complex-valued transform; otherwise
returns its absolute value (envelope). ``downsample_factor``
must be ``1`` when ``True``. (Default is ``False``)
filters (Optional[torch.Tensor]):
Pre-built complex sinusoid filters. shape: *(n_freq_bins,
win_size)*. If ``None``, ``make_VQT_filters`` is called.
(Default is ``None``)
plot_pref (bool):
If ``True``, plots the filters. (Default is ``False``)
progressBar (bool):
If ``True``, displays a tqdm progress bar during ``__call__``.
(Default is ``True``)
Attributes:
filters (torch.Tensor):
Complex sinusoid filters used for convolution.
freqs (np.ndarray):
Filter center frequencies (only when filters were generated
internally).
wins (torch.Tensor):
Gaussian windows for each filter (only when filters were
generated internally).
using_custom_filters (bool):
``True`` if filters were supplied by the caller.
"""
def __init__(
self,
Fs_sample=1000,
Q_lowF=3,
Q_highF=20,
F_min=10,
F_max=400,
n_freq_bins=55,
win_size=501,
symmetry='center',
taper_asymmetric=True,
downsample_factor=4,
padding='valid',
DEVICE_compute='cpu',
DEVICE_return='cpu',
batch_size=1000,
return_complex=False,
filters=None,
plot_pref=False,
progressBar=True,
):
"""Builds filters (or accepts pre-built ones) and stores all transform parameters."""
## Prepare filters
if filters is not None:
## Use provided filters
self.using_custom_filters = True
self.filters = filters
else:
## Make new filters
self.using_custom_filters = False
self.filters, self.freqs, self.wins = make_VQT_filters(
Fs_sample=Fs_sample,
Q_lowF=Q_lowF,
Q_highF=Q_highF,
F_min=F_min,
F_max=F_max,
n_freq_bins=n_freq_bins,
win_size=win_size,
symmetry=symmetry,
taper_asymmetric=taper_asymmetric,
plot_pref=plot_pref,
)
## Gather parameters from arguments
self.Fs_sample, self.Q_lowF, self.Q_highF, self.F_min, self.F_max, self.n_freq_bins, self.win_size, self.downsample_factor, self.padding, self.DEVICE_compute, \
self.DEVICE_return, self.batch_size, self.return_complex, self.plot_pref, self.progressBar = \
Fs_sample, Q_lowF, Q_highF, F_min, F_max, n_freq_bins, win_size, downsample_factor, padding, DEVICE_compute, DEVICE_return, batch_size, return_complex, plot_pref, progressBar
def _helper_ds(self, X: torch.Tensor, ds_factor: int=4, return_complex: bool=False):
"""
Downsamples ``X`` along the last dimension via average pooling.
Args:
X (torch.Tensor):
Input tensor (real or complex).
ds_factor (int):
Downsampling factor. ``1`` returns ``X`` unchanged.
(Default is ``4``)
return_complex (bool):
If ``True``, treats ``X`` as complex and pools the real and
imaginary parts separately. (Default is ``False``)
Returns:
(torch.Tensor):
X_ds (torch.Tensor):
Downsampled tensor.
"""
if ds_factor == 1:
return X
elif return_complex == False:
return torch.nn.functional.avg_pool1d(X, kernel_size=[int(ds_factor)], stride=ds_factor, ceil_mode=True)
elif return_complex == True:
## Unfortunately, torch.nn.functional.avg_pool1d does not support complex numbers. So we have to split it up.
### Split X, shape: (batch_size, n_freq_bins, n_samples) into real and imaginary parts, shape: (batch_size, n_freq_bins, n_samples, 2)
Y = torch.view_as_real(X)
### Downsample each part separately, then stack them and make them complex again.
Z = torch.view_as_complex(torch.stack([torch.nn.functional.avg_pool1d(y, kernel_size=[int(ds_factor)], stride=ds_factor, ceil_mode=True) for y in [Y[...,0], Y[...,1]]], dim=-1))
return Z
def _helper_conv(self, arr, filters, take_abs, DEVICE):
"""
Convolves a batch of signals with complex filters using two ``conv1d`` calls.
Args:
arr (torch.Tensor):
Real input signals. shape: *(batch, n_samples)*.
filters (torch.Tensor):
Complex filter bank. shape: *(n_freq_bins, win_size)*.
take_abs (bool):
If ``True``, returns the magnitude of the result.
DEVICE (str):
Device used for the convolution.
Returns:
(torch.Tensor):
out (torch.Tensor):
Filtered output, complex if ``take_abs`` is ``False``.
"""
out = torch.complex(
torch.nn.functional.conv1d(input=arr.to(DEVICE)[:,None,:], weight=torch.real(filters.T).to(DEVICE).T[:,None,:], padding=self.padding),
torch.nn.functional.conv1d(input=arr.to(DEVICE)[:,None,:], weight=-torch.imag(filters.T).to(DEVICE).T[:,None,:], padding=self.padding)
)
if take_abs:
return torch.abs(out)
else:
return out
def __call__(self, X):
    """
    Computes the variable-Q spectrogram of ``X``.

    Args:
        X (torch.Tensor):
            Input signal. shape: *(n_channels, n_samples)*. A 1-D input is
            treated as a single channel. Non-tensor inputs are converted
            to a float32 tensor on ``self.DEVICE_compute``.

    Returns:
        (tuple): tuple containing:
            specs (torch.Tensor):
                Spectrogram. shape: *(n_channels, n_freq_bins, n_samples_ds)*.
            x_axis (torch.Tensor):
                New x-axis for the spectrogram in samples. Divide by
                ``self.Fs_sample`` to get time.
            freqs (np.ndarray):
                Frequencies corresponding to the spectrogram bins
                (``self.freqs``).
    """
    if type(X) is not torch.Tensor:
        X = torch.as_tensor(X, dtype=torch.float32, device=self.DEVICE_compute)
    if X.ndim==1:
        X = X[None,:]

    ## Make iterator for batches (batched over channels)
    n_channels = X.shape[0]
    batches = make_batches(X, batch_size=self.batch_size, length=n_channels)
    n_batches = int(np.ceil(n_channels/self.batch_size))

    ## Make spectrograms: filter, downsample, then move to the return device
    specs = []
    for arr in tqdm(batches, disable=(self.progressBar==False), leave=True, total=n_batches):
        filtered = self._helper_conv(
            arr=arr,
            filters=self.filters,
            take_abs=(self.return_complex==False),
            DEVICE=self.DEVICE_compute,
        )
        downsampled = self._helper_ds(
            X=filtered,
            ds_factor=self.downsample_factor,
            return_complex=self.return_complex,
        )
        specs.append(downsampled.to(self.DEVICE_return))
    specs = torch.cat(specs, dim=0)

    ## Make x_axis: push the sample indices through the same
    ##  (box-filter) convolution and pooling as the data so that they
    ##  line up with the spectrogram columns.
    win_len = self.filters.shape[-1]
    sample_idx = torch.arange(0, X.shape[-1], dtype=torch.float32)[None,None,:]
    box_filter = torch.ones(1,1,win_len, dtype=torch.float32) / win_len
    idx_filtered = torch.nn.functional.conv1d(
        input=sample_idx,
        weight=box_filter,
        padding=self.padding
    )
    x_axis = torch.nn.functional.avg_pool1d(
        idx_filtered,
        kernel_size=[int(self.downsample_factor)],
        stride=self.downsample_factor, ceil_mode=True,
    ).squeeze()

    return specs, x_axis, self.freqs
def __repr__(self):
"""Returns a debug string summarizing the VQT configuration."""
if self.using_custom_filters:
return f"VQT with custom filters"
else:
return f"VQT object with parameters: Fs_sample={self.Fs_sample}, Q_lowF={self.Q_lowF}, Q_highF={self.Q_highF}, F_min={self.F_min}, F_max={self.F_max}, n_freq_bins={self.n_freq_bins}, win_size={self.win_size}, downsample_factor={self.downsample_factor}, DEVICE_compute={self.DEVICE_compute}, DEVICE_return={self.DEVICE_return}, batch_size={self.batch_size}, return_complex={self.return_complex}, plot_pref={self.plot_pref}"
[docs]
def generate_multiphasic_sinewave(
    n_samples: int = 10000,
    n_periods: float = 1.0,
    n_waves: int = 3,
    return_x: bool = False,
    return_phases: bool = False,
):
    """
    Generates ``n_waves`` cosine waves whose phase offsets are evenly spaced
    over one full cycle. RH 2024

    Args:
        n_samples (int):
            Number of samples per wave. (Default is ``10000``)
        n_periods (float):
            Number of full periods spanned by ``n_samples``. (Default is
            ``1.0``)
        n_waves (int):
            Number of phase-shifted cosine waves to return. Wave ``ii`` is
            delayed by ``ii * 2*pi / n_waves`` radians. (Default is ``3``)
        return_x (bool):
            If ``True``, also returns the x positions. (Default is ``False``)
        return_phases (bool):
            If ``True``, also returns the per-wave phase arrays. (Default is
            ``False``)

    Returns:
        (Union[np.ndarray, tuple]):
            output (Union[np.ndarray, tuple]):
                ``waves`` alone, or a tuple starting with ``waves`` and
                followed by the requested extras, in this order: \n
                * ``waves`` (np.ndarray): cosine waves. shape: *(n_waves,
                  n_samples)*.
                * ``x`` (np.ndarray): x positions in radians, if
                  ``return_x``. shape: *(n_samples,)*.
                * ``phases`` (np.ndarray): per-wave phases, if
                  ``return_phases``. shape: *(n_waves, n_samples)*.
    """
    x = np.linspace(0, n_periods * np.pi*2, n_samples)

    ## One row per wave: the shared x axis minus that wave's phase offset
    shifted = [x - ii * np.pi * (2 / n_waves) for ii in range(n_waves)]
    phases = np.stack(shifted)
    waves = np.cos(phases)

    outputs = [waves]
    if return_x:
        outputs.append(x)
    if return_phases:
        outputs.append(phases)

    if len(outputs) == 1:
        return waves
    return tuple(outputs)
############################################################################################################################################################################
#################################################################### TORCH HELPERS #########################################################################################
############################################################################################################################################################################
[docs]
def set_device(
    use_GPU: bool = True,
    device_num: int = 0,
    verbose: bool = True
) -> str:
    """
    Chooses the PyTorch device string. If **use_GPU** is ``True`` and CUDA is
    available, the GPU with index **device_num** is selected. Otherwise the
    CPU is selected.
    RH 2022

    Args:
        use_GPU (bool):
            Determines if the GPU should be utilized: \n
            * ``True``: the function will use the GPU if one is available,
              and fall back to the CPU otherwise.
            * ``False``: the function will use the CPU. \n
            (Default is ``True``)
        device_num (int):
            Specifies the index of the GPU to use. (Default is ``0``)
        verbose (bool):
            If ``True``, prints the available devices and the device that
            was selected. (Default is ``True``)

    Returns:
        (str):
            device (str):
                A string specifying the device, either *"cpu"* or
                *"cuda:<device_num>"*.
    """
    ## CPU explicitly requested
    if not use_GPU:
        device = "cpu"
        if verbose:
            print(f"device: '{device}'")
        return device

    if verbose:
        print(f'devices available: {[torch.cuda.get_device_properties(ii) for ii in range(torch.cuda.device_count())]}')

    if torch.cuda.is_available():
        device = f"cuda:{device_num}"
        if verbose:
            print(f"Using device: '{device}': {torch.cuda.get_device_properties(device_num)}")
    else:
        device = "cpu"
        if verbose:
            print("no GPU available. Using CPU.")

    return device
[docs]
def tensorly_cp_to_device(cp, device='cpu'):
    """
    Moves the factors and weights of a tensorly CP object to ``device``. RH 2024

    The object is modified in place: each entry of ``cp.factors`` and
    ``cp.weights`` is replaced by the result of its ``.to(device)`` call.

    Args:
        cp (object):
            Tensorly ``CP`` tensor (``tensorly.cp_tensor.CP``).
        device (str):
            Target device for ``cp.factors`` and ``cp.weights``. (Default
            is ``'cpu'``)

    Returns:
        (object):
            cp (object):
                Same CP object with all factors and weights on ``device``.
    """
    for idx, factor in enumerate(cp.factors):
        cp.factors[idx] = factor.to(device)
    cp.weights = cp.weights.to(device)
    return cp
############################################################################################################################################################################
################################################################## PLOTTING HELPERS ########################################################################################
############################################################################################################################################################################
[docs]
def simple_cmap(
    colors=[
        [1,0,0],
        [1,0.6,0],
        [0.9,0.9,0],
        [0.6,1,0],
        [0,1,0],
        [0,1,0.6],
        [0,0.8,0.8],
        [0,0.6,1],
        [0,0,1],
        [0.6,0,1],
        [0.8,0,0.8],
        [1,0,0.6],
    ],
    under=[0,0,0],
    over=[0.5,0.5,0.5],
    bad=[0.9,0.9,0.9],
    name='none'):
    """
    Builds a ``LinearSegmentedColormap`` whose stops are the given colors,
    spaced evenly over ``[0, 1]``.
    Adapted from
    https://gist.github.com/ahwillia/3e022cdd1fe82627cbf1f2e9e2ad80a7e.

    Args:
        colors (list):
            Sequence of RGB triples (or matplotlib color strings) defining
            the colormap stops. At least two are required.
        under (list):
            RGB color used for values below the colormap range. (Default is
            ``[0, 0, 0]``)
        over (list):
            RGB color used for values above the colormap range. (Default is
            ``[0.5, 0.5, 0.5]``)
        bad (list):
            RGB color used for masked / NaN values. (Default is
            ``[0.9, 0.9, 0.9]``)
        name (str):
            Colormap name. (Default is ``'none'``)

    Returns:
        (matplotlib.colors.LinearSegmentedColormap):
            cmap (matplotlib.colors.LinearSegmentedColormap):
                Resulting linear-segmented colormap.

    Raises:
        ValueError: If fewer than two colors are given.

    Example:
        .. highlight:: python
        .. code-block:: python

            cmap = simple_cmap([(1, 1, 1), (1, 0, 0)]) # white to red
            cmap = simple_cmap(['w', 'r'])             # white to red
            cmap = simple_cmap(['r', 'b', 'r'])        # red to blue to red
    """
    from matplotlib.colors import LinearSegmentedColormap, colorConverter

    # check inputs
    n_colors = len(colors)
    if n_colors <= 1:
        raise ValueError('Must specify at least two colors')

    # convert colors to rgb
    rgb_stops = [colorConverter.to_rgb(c) for c in colors]

    # evenly spaced anchor positions: 0.0 for the first stop, 1.0 for the last
    positions = [0.0] + [(i + 1) / (n_colors - 1) for i in range(n_colors - 1)]

    # per-channel segment data: (position, value_left, value_right)
    segment_data = {
        channel: tuple((pos, stop[k], stop[k]) for pos, stop in zip(positions, rgb_stops))
        for k, channel in enumerate(('red', 'green', 'blue'))
    }

    cmap = LinearSegmentedColormap(name, segment_data)
    cmap.set_bad(bad)
    cmap.set_over(over)
    cmap.set_under(under)
    return cmap
[docs]
class Cmap_conjunctive:
    """
    Combines multiple colormaps by multiplying their per-channel outputs. RH 2022

    Args:
        cmaps (list):
            List of ``matplotlib.colors.LinearSegmentedColormap`` objects to
            combine.
        dtype_out (np.dtype):
            Data type of the returned color array. (Default is ``int``)
        normalize (bool):
            If ``True``, normalizes each input column to ``[0, 1]`` before
            applying the colormaps. (Default is ``False``)
        normalization_range (list):
            ``[lo, hi]`` to which the output is rescaled. (Default is
            ``[0, 255]``)
        name (str):
            Name of the resulting colormap. (Default is ``'cmap_conjunctive'``)

    Attributes:
        cmaps (list):
            Stored input colormaps.
        n_cmaps (int):
            Number of input colormaps.
        fn_conj_cmap (Callable):
            Function that maps an input array of shape *(n_samples, n_cmaps)*
            to the elementwise product of each colormap's output.
    """
    def __init__(
        self,
        cmaps,
        dtype_out=int,
        normalize=False,
        normalization_range=[0,255],
        name='cmap_conjunctive',
    ):
        """Validates the colormaps and stores normalization parameters."""
        import matplotlib.colors

        ## Check inputs
        assert isinstance(cmaps, list), 'cmaps must be a list.'
        assert all(isinstance(cmap, matplotlib.colors.LinearSegmentedColormap) for cmap in cmaps), 'All elements of cmaps must be matplotlib.colors.LinearSegmentedColormap objects.'

        self.cmaps = cmaps
        self.dtype_out = dtype_out
        self.name = name
        self.normalize = normalize
        self.normalization_range = normalization_range

        self.n_cmaps = len(self.cmaps)

    def fn_conj_cmap(self, x):
        """
        Applies colormap ``k`` to column ``k`` of ``x`` and multiplies the
        resulting color arrays elementwise.

        Args:
            x (np.ndarray):
                Input values. shape: *(n_samples, n_cmaps)*.

        Returns:
            (np.ndarray):
                colors (np.ndarray):
                    Product of the per-colormap outputs.
        """
        per_cmap = [cmap(column) for cmap, column in zip(self.cmaps, x.T)]
        return np.prod(np.stack(per_cmap, axis=0), axis=0)

    def __call__(self, x):
        """
        Applies the conjunctive colormap to the input array.

        Args:
            x (np.ndarray):
                Input values. shape: *(n_samples, n_cmaps)*. A 1-D input is
                treated as a single sample. When ``self.normalize`` is
                ``True``, each column is independently min-max rescaled to
                ``[0, 1]``, which requires more than one sample.

        Returns:
            (np.ndarray):
                colors (np.ndarray):
                    Color values rescaled to ``self.normalization_range``
                    and cast to ``self.dtype_out``. One row per sample.
        """
        assert isinstance(x, np.ndarray), 'x must be a numpy array of shape (n_samples, n_cmaps).'

        ## Make array 2D
        if x.ndim == 1:
            x = x[None,:]
        assert x.shape[1] == self.n_cmaps, 'x.shape[1] must match the number of cmaps.'

        ## Normalize x
        if self.normalize:
            ## Min-max scaling is done per column across samples, so it is the
            ##  number of ROWS (samples) that must exceed one.
            assert x.shape[0] > 1, 'x must have more than one row to normalize.'
            x_min = x.min(axis=0, keepdims=True)
            x_max = x.max(axis=0, keepdims=True)
            x = (x - x_min) / (x_max - x_min)

        ## Get colors
        colors = self.fn_conj_cmap(x)
        lo = self.normalization_range[0]
        hi = self.normalization_range[1]
        colors = (colors * (hi - lo) + lo).astype(self.dtype_out)

        return colors
[docs]
class Colorwheel:
    """
    2D colorwheel colormap (angle + magnitude) for cyclic data. RH 2024
    Useful for visualizing complex/polar values, optical flow, and other
    cyclic data.

    Args:
        rotation (float):
            Rotation of the colorwheel in radians. (Default is ``0.0``)
        saturation (float):
            Color saturation in *[0, 1]*. (Default is ``1.0``)
        center (int):
            Color value at the center of the wheel. (Default is ``0``)
        radius (int):
            Maximum color value at the rim of the wheel. (Default is ``255``)
        dtype (np.dtype):
            Output dtype of the color array. (Default is ``np.uint8``)
        bit_depth (int):
            Number of samples used to discretize the wheel:
            ``2 ** bit_depth``. (Default is ``16``)
        exponent (float):
            Exponent applied to each base color wave to sharpen transitions.
            (Default is ``10``)
        normalize (bool):
            If ``True``, normalizes the per-angle color sum to ``1`` so that
            color intensity is uniform around the wheel. (Default is
            ``True``)
        colors (List[Union[List, Tuple]]):
            Sequence of base RGB triples used to build the rainbow.
            (Default is a 12-color rainbow)

    Attributes:
        fn_interp (Callable):
            Interpolator that maps an angle (radians) to per-base-color
            weights along the wheel.
        colors (np.ndarray):
            Array of base colors. shape: *(n_colors, 3)*.
    """
    def __init__(
        self,
        rotation: float = 0.0,
        saturation: float = 1.0,
        center: int = 0,
        radius: int = 255,
        dtype: np.dtype = np.uint8,
        bit_depth: int = 16,
        exponent: float = 10,
        normalize: bool = True,
        colors: List[Union[List, Tuple]] = [
            [1  , 0  , 0  ],
            [1  , 0.5, 0  ],
            [1  , 1  , 0  ],
            [0.5, 1  , 0  ],
            [0  , 1  , 0  ],
            [0  , 1  , 0.5],
            [0  , 1  , 1  ],
            [0  , 0.5, 1  ],
            [0  , 0  , 1  ],
            [0.5, 0  , 1  ],
            [1  , 0  , 1  ],
            [1  , 0  , 0.5],
        ],
    ):
        """Builds the angle-to-color interpolator from the rainbow base waves."""
        import scipy.interpolate

        self.rotation = rotation
        self.saturation = saturation
        self.center = center
        self.radius = radius
        self.dtype = dtype
        self.bit_depth = bit_depth
        self.exponent = exponent
        self.normalize = normalize
        self.colors = np.array(colors)

        # Make a rainbow colorwheel
        # One cosine wave per base color, with phase offsets spread evenly
        #  around the circle. The x axis spans slightly more than one period.
        n_wheel_samples = int(2**bit_depth)
        waves, x = generate_multiphasic_sinewave(
            n_samples=n_wheel_samples,
            n_periods=1 + 2 / n_wheel_samples,
            n_waves=len(colors),
            return_x=True,
        )
        # Map each wave to [0, 1], sharpen it with the exponent, and rescale
        waves = ((waves + 1).astype(np.float64) / 2) ** exponent
        waves = (waves - waves.min()) / (waves.max() - waves.min())
        if normalize:
            if waves.shape[0] == 1:
                waves = np.ones_like(waves)
            else:
                waves = waves / np.sum(waves, axis=0, keepdims=True)
        # Scale to the color range; saturation < 1 raises the floor of each wave
        waves = (waves * (radius - (1-saturation) * radius) + (1-saturation) * radius)
        # Rotate the wheel by shifting the samples along the angle axis
        waves = np.roll(waves, int(rotation * 2**bit_depth / (2*np.pi)), axis=1)

        # Create interpolation function
        self.fn_interp = scipy.interpolate.interp1d(
            x=x,
            y=waves,
            kind='linear',
            axis=1,
            bounds_error=False,
            fill_value='extrapolate',
        )

    def __call__(
        self,
        angles: Union[np.ndarray, List[Union[float, int]], Tuple[Union[float, int]], float, int],
        magnitudes: Optional[Union[np.ndarray, List[Union[float, int]], Tuple[Union[float, int]], float, int]] = None,
        normalize_magnitudes: bool = True,
    ) -> np.ndarray:
        """
        Returns RGB colors for the given angles and (optional) magnitudes. RH 2024

        Args:
            angles (Union[np.ndarray, List[Union[float, int]], Tuple[Union[float, int]], float, int]):
                Angles in radians. shape: *(n_samples,)*. Scalars are
                wrapped into a length-1 array.
            magnitudes (Optional[Union[np.ndarray, List[Union[float, int]], Tuple[Union[float, int]], float, int]]):
                Magnitudes for each angle. shape: *(n_samples,)*. If
                ``None``, treated as ones. (Default is ``None``)
            normalize_magnitudes (bool):
                If ``True``, applies min-max normalization to ``magnitudes``
                before use. If all magnitudes are equal (so min-max scaling
                is undefined), the raw values are used instead. In all
                cases magnitudes are clipped to ``[0, 1]``. (Default is
                ``True``)

        Returns:
            (np.ndarray):
                rgb (np.ndarray):
                    RGB values. shape: *(n_samples, 3)*, dtype:
                    ``self.dtype``.

        Raises:
            ValueError: If ``angles`` or ``magnitudes`` is not a number,
                list, tuple or numpy array.
        """
        # Check inputs
        def check_input(arg):
            if isinstance(arg, (float, int, np.number)):
                arg = np.array([arg])
            elif isinstance(arg, (list, tuple)):
                arg = np.array(arg)
            elif not isinstance(arg, np.ndarray):
                raise ValueError("angles and magnitudes must be a numpy array, list, or tuple of ints or floats.")
            return arg
        angles = check_input(angles)
        magnitudes = check_input(magnitudes) if magnitudes is not None else None

        # Normalize the magnitudes
        if magnitudes is not None:
            if normalize_magnitudes:
                m_min, m_max = np.min(magnitudes), np.max(magnitudes)
                # Constant magnitudes would give 0/0 = NaN; leave them unscaled
                if m_max > m_min:
                    magnitudes = (magnitudes - m_min) / (m_max - m_min)
            magnitudes = np.clip(magnitudes, 0, 1)
        else:
            magnitudes = np.ones_like(angles)

        # Get the saturated color by interpolating the colorwheel
        sample_colors = self.fn_interp(angles % (2*np.pi))
        # Clip the colors
        sample_colors = np.clip(sample_colors, 0, self.radius)
        # Project to RGB: (3, n_colors) @ (n_colors, n_samples)
        rgb = self.colors.T @ sample_colors

        # Blend between the center color (magnitude 0) and the wheel color (magnitude 1)
        rgb = rgb * magnitudes[None, :] + (1 - magnitudes)[None, :] * self.center

        # Convert to dtype
        if rgb.dtype != self.dtype:
            ## Integer dtypes are clipped to their representable range first.
            ##  Float dtypes have no ``np.iinfo`` and need no clipping.
            if np.issubdtype(self.dtype, np.integer):
                info = np.iinfo(self.dtype)
                rgb = np.clip(rgb, info.min, info.max)
            rgb = rgb.astype(self.dtype)

        return rgb.T

    def plot_colorwheel(self, n_samples: int = 100000):
        """
        Renders the colorwheel as a 2D image and displays it. RH 2024

        Args:
            n_samples (int):
                Approximate number of samples used to build the wheel; the
                actual grid is ``ceil(sqrt(n_samples))`` per side. (Default
                is ``100000``)
        """
        import matplotlib.pyplot as plt

        l = int(np.ceil(n_samples**0.5))
        # Complex grid over [-1, 1] x [-1, 1]: angle -> hue, modulus -> magnitude
        grid = np.meshgrid(np.linspace(-1, 1, l), np.linspace(-1, 1, l), indexing='xy')
        grid = grid[0] + 1j*grid[1]
        grid = grid.reshape(-1)
        angles = np.angle(grid)
        # Points outside the unit circle are drawn at full magnitude
        magnitudes = np.clip(np.abs(grid), 0, 1)
        colors = self(angles, magnitudes)
        colors = np.clip(colors, 0, 255)
        im = np.zeros((l, l, 3), dtype=self.dtype)
        im[...] = colors.reshape(im.shape[:2] + (3,))

        fig, axs = plt.subplots(2, 1, figsize=(5, 10))
        axs[0].imshow(im)
        x = np.linspace(0, 2*np.pi, l)
        # NOTE(review): this pairs the three RGB output channels with the
        #  first three base colors for the line colors -- confirm intended.
        for channel, line_color in zip(self(x).T, self.colors):
            axs[1].plot(x, channel, color=line_color)
        axs[1].set_ylabel('Channel magnitude')
        axs[1].set_xlabel('Phase (rads)')

    def __repr__(self) -> str:
        """Returns a debug string with the colorwheel construction parameters."""
        return (f"ColorwheelColormap(rotation={self.rotation}, "
                f"saturation={self.saturation}, center={self.center}, "
                f"radius={self.radius}, dtype={self.dtype}, "
                f"bit_depth={self.bit_depth}, exponent={self.exponent}, "
                f"normalize={self.normalize})")
##########################################################################################################################################
############################################################ IMAGE PROCESSING ############################################################
##########################################################################################################################################
[docs]
def clahe(im, grid_size=50, clipLimit=0, normalize=True):
    """
    Applies Contrast Limited Adaptive Histogram Equalization to an image. RH 2022

    Args:
        im (np.ndarray):
            Input image.
        grid_size (int):
            Tile grid size passed to ``cv2.createCLAHE`` (used for both
            grid dimensions). (Default is ``50``)
        clipLimit (int):
            Contrast clip limit passed to ``cv2.createCLAHE``. (Default is
            ``0``)
        normalize (bool):
            If ``True``, the input is divided by its maximum and multiplied
            by ``2**16`` before further processing. (Default is ``True``)

    Returns:
        (np.ndarray):
            im_c (np.ndarray):
                CLAHE-enhanced image. dtype: *uint16*.
    """
    import cv2

    if normalize:
        scaled = (im / im.max())*(2**16)
    else:
        scaled = im
    ## The image is divided by 10 before the uint16 cast (so a normalized
    ##  image peaks at 6553 rather than overflowing at 65536).
    scaled = scaled/10

    equalizer = cv2.createCLAHE(clipLimit=clipLimit, tileGridSize=(grid_size, grid_size))
    return equalizer.apply(scaled.astype(np.uint16))
[docs]
def add_text_to_images(images, text, position=(10,10), font_size=1, color=(255,255,255), line_width=1, font=None, show=False, frameRate=30):
    """
    Overlays multi-line text onto each frame using ``cv2.putText``. RH 2022

    The input is deep-copied first; the text is drawn onto the copy, which
    is returned. The input itself is not modified.

    Args:
        images (np.ndarray):
            Frames of video or images. shape: *(n_frames, H, W, C)*.
            2-D frames *(H, W)* are also accepted, either as an array of
            shape *(n_frames, H, W)* or as a list of arrays.
        text (List[List[str]]):
            Text per frame. Outer list has one element per frame; each
            inner list holds the lines of text drawn on that frame.
        position (Tuple[int, int]):
            ``(x, y)`` origin passed to ``cv2.putText`` for the first line
            of text. Each following line is placed ``font_size * 30``
            pixels lower. (Default is ``(10, 10)``)
        font_size (int):
            Font scale passed to ``cv2.putText``. May be fractional.
            (Default is ``1``)
        color (Tuple[int, int, int]):
            Text color passed to ``cv2.putText``. (Default is
            ``(255, 255, 255)``)
        line_width (int):
            Line thickness passed to ``cv2.putText``. (Default is ``1``)
        font (Optional[int]):
            OpenCV font constant. If ``None``, uses
            ``cv2.FONT_HERSHEY_SIMPLEX``. (Default is ``None``)
        show (bool):
            If ``True``, displays each annotated frame using
            ``cv2.imshow``. (Default is ``False``)
        frameRate (float):
            Display frame rate when ``show`` is ``True``. (Default is ``30``)

    Returns:
        (np.ndarray):
            images_with_text (np.ndarray):
                Copy of the frames with text overlays applied. Same
                container type and frame shapes as the input.
    """
    import cv2
    import copy

    if font is None:
        font = cv2.FONT_HERSHEY_SIMPLEX

    images_cp = copy.deepcopy(images)

    for i_f, frame in enumerate(images_cp):
        ## Draw through a (H, W, 1) view for 2-D frames. The view shares
        ##  memory with ``frame``, so the text lands in ``images_cp``
        ##  without reassigning a differently-shaped array into it.
        canvas = frame[:, :, None] if frame.ndim == 2 else frame

        for i_t, line in enumerate(text[i_f]):
            ## cv2.putText requires integer pixel coordinates; a fractional
            ##  font_size would otherwise make the y offset a float.
            origin = (
                int(position[0]),
                int(round(position[1] + i_t * font_size * 30)),
            )
            cv2.putText(canvas, line, origin, font, font_size, color, line_width)

        if show:
            cv2.imshow('add_text_to_images', canvas)
            cv2.waitKey(int(1000/frameRate))

    if show:
        cv2.destroyWindow('add_text_to_images')
    return images_cp
[docs]
def mask_image_border(
    im: np.ndarray,
    border_outer: Optional[Union[int, Tuple[int, int, int, int]]] = None,
    border_inner: Optional[int] = None,
    mask_value: float = 0,
) -> np.ndarray:
    """
    Masks an image within specified outer and inner borders. RH 2022

    Note: ``im`` is modified IN PLACE and the same array is returned. Pass
    a copy if the original must be preserved.

    Args:
        im (np.ndarray):
            Input image of shape: *(height, width)* or *(height, width,
            channels)*.
        border_outer (Union[int, tuple[int, int, int, int], None]):
            Number of pixels along the border to mask. If ``None``, the border
            is not masked. If an int is provided, all borders are equally
            masked. If a tuple of ints is provided, borders are masked in the
            order: *(top, bottom, left, right)*. A value of ``0`` leaves that
            side unmasked. (Default is ``None``)
        border_inner (int, Optional):
            Number of pixels in the center to mask. Will be a square with side
            length ``2 * ceil(border_inner / 2)`` centered on the image.
            (Default is ``None``)
        mask_value (float):
            Value to replace the masked pixels with. (Default is *0*)

    Returns:
        (np.ndarray):
            im_out (np.ndarray):
                Masked output image (the same object as ``im``).
    """
    height, width = im.shape[:2]

    ## Mask the center
    if border_inner is not None:
        cy = int(np.floor(height/2))
        cx = int(np.floor(width/2))
        cel = int(np.ceil(border_inner/2))  ## half edge length of the central square
        ## Clamp the start indices: a negative start would wrap around and
        ##  select the wrong (typically empty) region.
        im[max(cy-cel, 0):cy+cel, max(cx-cel, 0):cx+cel] = mask_value

    ## Mask the border
    if border_outer is not None:
        if isinstance(border_outer, (int, np.integer)):
            border_outer = (border_outer, border_outer, border_outer, border_outer)
        top, bottom, left, right = border_outer[0], border_outer[1], border_outer[2], border_outer[3]

        im[:top, :] = mask_value
        im[:, :left] = mask_value
        ## Index from the far edge explicitly: ``im[-0:]`` would select the
        ##  WHOLE image, so a zero-width bottom/right border must not be
        ##  expressed as a negative slice start.
        im[max(height - bottom, 0):, :] = mask_value
        im[:, max(width - right, 0):] = mask_value

    return im
[docs]
def warp_matrix_to_remappingIdx(
    warp_matrix: Union[np.ndarray, torch.Tensor],
    x: int,
    y: int
) -> Union[np.ndarray, torch.Tensor]:
    """
    Converts a warp matrix (2x3 or 3x3) into a 2D remapping index field. RH 2023

    Every pixel coordinate ``(col, row)`` of a *(y, x)* grid is written in
    homogeneous form ``(col, row, 1)`` and multiplied by the warp matrix.
    For a 3x3 (homography) matrix the result is divided by its third
    component.

    Args:
        warp_matrix (Union[np.ndarray, torch.Tensor]):
            Warp matrix of shape *(2, 3)* for affine transformations, or
            *(3, 3)* for homography.
        x (int):
            Width of the output remapping field.
        y (int):
            Height of the output remapping field.

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            remapIdx (Union[np.ndarray, torch.Tensor]):
                Remapping indices. shape: *(y, x, 2)*. The last axis stores
                the pixel coordinate ``(x, y)`` to sample from. Same
                array type (and, for tensors, device) as ``warp_matrix``.

    Raises:
        ValueError: If ``warp_matrix`` is neither a ``torch.Tensor`` nor a
            ``np.ndarray``.
    """
    assert warp_matrix.shape in [(2, 3), (3, 3)], f"warp_matrix.shape {warp_matrix.shape} not recognized. Must be (2, 3) or (3, 3)"
    assert isinstance(x, int) and isinstance(y, int), f"x and y must be integers"
    assert x > 0 and y > 0, f"x and y must be positive"

    # create the grid of homogeneous pixel coordinates, shape (x*y, 3),
    #  ordered row by row (row-major over the (y, x) image)
    if isinstance(warp_matrix, torch.Tensor):
        ## Build the grid with the warp matrix's device and float dtype so
        ##  the matmul below is valid.
        dtype = warp_matrix.dtype if warp_matrix.is_floating_point() else torch.float32
        warp = warp_matrix.to(dtype)
        xs = torch.arange(x, dtype=dtype, device=warp.device)
        ys = torch.arange(y, dtype=dtype, device=warp.device)
        ## torch.meshgrid defaults to 'ij' indexing while np.meshgrid defaults
        ##  to 'xy'. Request 'xy' explicitly so both backends agree.
        grid_x, grid_y = torch.meshgrid(xs, ys, indexing='xy')
        mesh_coords = torch.stack(
            [grid_x.reshape(-1), grid_y.reshape(-1), torch.ones(x*y, dtype=dtype, device=warp.device)],
            dim=1,
        )
    elif isinstance(warp_matrix, np.ndarray):
        warp = warp_matrix
        grid_x, grid_y = np.meshgrid(
            np.arange(x, dtype=np.float32),
            np.arange(y, dtype=np.float32),
            indexing='xy',
        )
        mesh_coords = np.stack(
            [grid_x.reshape(-1), grid_y.reshape(-1), np.ones(x*y, dtype=np.float32)],
            axis=1,
        )
    else:
        raise ValueError(f"warp_matrix must be a torch.Tensor or np.ndarray")

    # warp the grid
    mesh_coords_warped = mesh_coords @ warp.T
    if warp_matrix.shape == (3, 3):
        ## homography: divide by the homogeneous (z) component
        mesh_coords_warped = mesh_coords_warped[:, :2] / mesh_coords_warped[:, 2:3]

    # reshape the warped grid to (y, x, 2)
    remapIdx = mesh_coords_warped.reshape(y, x, 2)
    return remapIdx
[docs]
def remap_images(
    images: Union[np.ndarray, torch.Tensor],
    remappingIdx: Union[np.ndarray, torch.Tensor],
    backend: str = "torch",
    interpolation_method: str = 'linear',
    border_mode: str = 'constant',
    border_value: float = 0,
    device: str = 'cpu',
) -> Union[np.ndarray, torch.Tensor]:
    """
    Applies remapping indices to a set of images. Remapping indices, similar to
    flow fields, describe the index of the pixel to sample from rather than the
    displacement of each pixel. RH 2023

    Args:
        images (Union[np.ndarray, torch.Tensor]):
            The images to be warped. Shapes can be *(N, C, H, W)*, *(C, H, W)*,
            or *(H, W)*.
        remappingIdx (Union[np.ndarray, torch.Tensor]):
            The remapping indices, describing the index of the pixel to sample
            from. Shape is *(H, W, 2)*.
        backend (str):
            The backend to use. Can be either ``'torch'`` or ``'cv2'``. The
            ``'cv2'`` backend requires numpy inputs. (Default is
            ``'torch'``)
        interpolation_method (str):
            The interpolation method to use. Options are ``'linear'``,
            ``'nearest'``, ``'cubic'``, and ``'lanczos'``. ``'lanczos'`` is
            only available with ``backend='cv2'``. Refer to `cv2.remap`
            or `torch.nn.functional.grid_sample` for more details. (Default is
            ``'linear'``)
        border_mode (str):
            The border mode to use. Options include ``'constant'``,
            ``'reflect'``, ``'replicate'``, and ``'wrap'``. ``'wrap'`` is
            only available with ``backend='cv2'``. Refer to `cv2.remap`
            for more details. (Default is ``'constant'``)
        border_value (float):
            The border value to use with ``backend='cv2'``. Refer to
            `cv2.remap` for more details. It is not used by the
            ``'torch'`` backend, whose ``'constant'`` border is always
            zero. (Default is ``0``)
        device (str):
            The device to use for computations with ``backend='torch'``,
            e.g. ``'cpu'`` or ``'cuda'``. (Default is ``'cpu'``)

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            warped_images (Union[np.ndarray, torch.Tensor]):
                The warped images, with all singleton dimensions squeezed
                out. For *(C, H, W)* and *(H, W)* inputs this restores the
                input shape (unless the input itself has size-1
                dimensions).

    Raises:
        ValueError: If ``images`` is not 2-D, 3-D or 4-D, if ``backend`` is
            unknown, or if the ``'torch'`` backend is asked for an option
            it does not support (``'lanczos'`` / ``'wrap'``).
        KeyError: If ``interpolation_method`` or ``border_mode`` is not one
            of the documented options.
    """
    # Check inputs
    assert isinstance(images, (np.ndarray, torch.Tensor)), f"images must be a np.ndarray or torch.Tensor"
    assert isinstance(remappingIdx, (np.ndarray, torch.Tensor)), f"remappingIdx must be a np.ndarray or torch.Tensor"
    ## Bring images to (N, C, H, W)
    if images.ndim == 2:
        images = images[None, None, :, :]
    elif images.ndim == 3:
        images = images[None, :, :, :]
    elif images.ndim != 4:
        raise ValueError(f"images must be a 2D, 3D, or 4D array. Got shape {images.shape}")
    assert remappingIdx.ndim == 3, f"remappingIdx must be a 3D array of shape (H, W, 2). Got shape {remappingIdx.shape}"
    assert images.shape[-2] == remappingIdx.shape[0], f"images H ({images.shape[-2]}) must match remappingIdx H ({remappingIdx.shape[0]})"
    assert images.shape[-1] == remappingIdx.shape[1], f"images W ({images.shape[-1]}) must match remappingIdx W ({remappingIdx.shape[1]})"

    # Check backend
    if backend not in ["torch", "cv2"]:
        raise ValueError("Invalid backend. Supported backends are 'torch' and 'cv2'.")

    if backend == 'torch':
        ## grid_sample only offers bilinear / nearest / bicubic sampling and
        ##  zeros / border / reflection padding. Fail early and clearly for
        ##  the cv2-only options instead of deep inside torch.
        if interpolation_method == 'lanczos':
            raise ValueError("interpolation_method='lanczos' is not supported by backend='torch'. Use backend='cv2' instead.")
        if border_mode == 'wrap':
            raise ValueError("border_mode='wrap' is not supported by backend='torch'. Use backend='cv2' instead.")
        interpolation = {
            'linear': 'bilinear',
            'nearest': 'nearest',
            'cubic': 'bicubic',
        }[interpolation_method]
        border = {
            'constant': 'zeros',
            'reflect': 'reflection',
            'replicate': 'border',  ## grid_sample's name for edge replication
        }[border_mode]

        if isinstance(images, np.ndarray):
            images = torch.as_tensor(images, device=device, dtype=torch.float32)
        elif isinstance(images, torch.Tensor):
            images = images.to(device=device).type(torch.float32)

        if isinstance(remappingIdx, np.ndarray):
            remappingIdx = torch.as_tensor(remappingIdx, device=device, dtype=torch.float32)
        elif isinstance(remappingIdx, torch.Tensor):
            remappingIdx = remappingIdx.to(device=device).type(torch.float32)

        ## Convert remappingIdx to normalized grid
        normgrid = cv2RemappingIdx_to_pytorchFlowField(remappingIdx)

        # Apply remappingIdx
        warped_images = torch.nn.functional.grid_sample(
            images,
            normgrid[None,...],
            mode=interpolation,
            padding_mode=border,
            align_corners=True,  ## align_corners=True is the default in cv2.remap. See documentation for details.
        )

    elif backend == 'cv2':
        assert isinstance(images, np.ndarray), f"images must be a np.ndarray when using backend='cv2'"
        assert isinstance(remappingIdx, np.ndarray), f"remappingIdx must be a np.ndarray when using backend='cv2'"
        ## convert images to float32 if not uint8
        images = images.astype(np.float32) if images.dtype != np.uint8 else images
        ## The maps are always cast to float32, whatever their input dtype
        map_x = np.ascontiguousarray(remappingIdx[..., 0], dtype=np.float32)
        map_y = np.ascontiguousarray(remappingIdx[..., 1], dtype=np.float32)

        interpolation = {
            'linear': cv2.INTER_LINEAR,
            'nearest': cv2.INTER_NEAREST,
            'cubic': cv2.INTER_CUBIC,
            'lanczos': cv2.INTER_LANCZOS4,
        }[interpolation_method]
        borderMode = {
            'constant': cv2.BORDER_CONSTANT,
            'reflect': cv2.BORDER_REFLECT,
            'replicate': cv2.BORDER_REPLICATE,
            'wrap': cv2.BORDER_WRAP,
        }[border_mode]

        # Apply remappingIdx to every channel of every image
        def remap_channels(channels):
            return np.stack([cv2.remap(
                channel,
                map_x,
                map_y,
                interpolation=interpolation,
                borderMode=borderMode,
                borderValue=border_value,
            ) for channel in channels], axis=0)
        warped_images = np.stack([remap_channels(im) for im in images], axis=0)

    return warped_images.squeeze()
[docs]
def invert_remappingIdx(
    remappingIdx: np.ndarray,
    method: str = 'linear',
    fill_value: Optional[float] = np.nan
) -> np.ndarray:
    """
    Inverts a remapping index field.

    Requires the assumption that the remapping index field is invertible or
    bijective/one-to-one and non-occluding.

    Define 'remap_AB' as a remapping index field that warps image A onto
    image B; then 'remap_BA' is the remapping index field that warps image B
    onto image A. This function computes 'remap_BA' given 'remap_AB' by
    scattered-data interpolation: the sampled-from coordinates are used as
    the known points, and the regular pixel grid as both their values and
    the query locations.
    RH 2023

    Args:
        remappingIdx (np.ndarray):
            An array of shape *(H, W, 2)* representing the remap field.
        method (str):
            Interpolation method to use. See ``scipy.interpolate.griddata``. Options are:
            \n
            * ``'linear'``
            * ``'nearest'``
            * ``'cubic'`` \n
            (Default is ``'linear'``)
        fill_value (Optional[float]):
            Value used to fill points outside the convex hull.
            (Default is ``np.nan``)

    Returns:
        (np.ndarray):
            An array of shape *(H, W, 2)* representing the inverse remap field.
    """
    ## Import the submodule explicitly: a bare ``import scipy`` does not
    ##  guarantee that ``scipy.interpolate`` is loaded.
    import scipy.interpolate

    H, W, _ = remappingIdx.shape

    # Regular pixel grid as a flat list of (x, y) coordinates, row by row
    grid = np.mgrid[:H, :W][::-1].transpose(1,2,0).reshape(-1, 2)

    # Flatten the remappingIdx to the matching list of (x, y) sample points
    remapIdx_flat = remappingIdx.reshape(-1, 2)

    # Interpolate the inverse mapping using griddata
    map_BA = scipy.interpolate.griddata(
        points=remapIdx_flat,
        values=grid,
        xi=grid,
        method=method,
        fill_value=fill_value,
    ).reshape(H,W,2)

    return map_BA
[docs]
def invert_warp_matrix(
    warp_matrix: np.ndarray
) -> np.ndarray:
    """
    Returns the warp matrix for B->A given the warp matrix for A->B. RH 2023

    A *(2, 3)* affine matrix is promoted to homogeneous *(3, 3)* form by
    appending the row ``[0, 0, 1]``, inverted, and then cropped back to
    *(2, 3)*. A *(3, 3)* matrix is inverted directly.

    Args:
        warp_matrix (np.ndarray):
            Warp matrix to invert. Shape: *(2, 3)* or *(3, 3)*.

    Returns:
        (np.ndarray):
            inverted_warp_matrix (np.ndarray):
                The inverted warp matrix. Shape: same as input.

    Raises:
        ValueError:
            If ``warp_matrix`` is neither *(2, 3)* nor *(3, 3)*.
    """
    is_affine = warp_matrix.shape == (2, 3)
    if not is_affine and warp_matrix.shape != (3, 3):
        raise ValueError("Input warp_matrix must be of shape (2, 3) or (3, 3)")

    ## Work in homogeneous coordinates so a plain matrix inverse applies
    if is_affine:
        homogeneous = np.vstack((warp_matrix, np.array([0, 0, 1])))
    else:
        homogeneous = warp_matrix

    inverse = np.linalg.inv(homogeneous)

    ## Drop the homogeneous row again when the input was affine
    return inverse[:2, :] if is_affine else inverse
[docs]
def compose_remappingIdx(
    remap_AB: np.ndarray,
    remap_BC: np.ndarray,
    method: str = 'linear',
    fill_value: Optional[float] = np.nan,
    bounds_error: bool = False,
) -> np.ndarray:
    """
    Chains two remapping index fields into one. RH 2023

    Given 'remap_AB' (warps image A onto image B) and 'remap_BC' (warps image
    B onto image C), this samples ``remap_AB`` at the coordinates stored in
    ``remap_BC`` using ``scipy.interpolate.interpn`` to obtain 'remap_AC'.

    Args:
        remap_AB (np.ndarray):
            Remap field from image A to image B. shape: *(H, W, 2)*. Last
            dimension is *(x, y)*.
        remap_BC (np.ndarray):
            Remap field from image B to image C. shape: *(H, W, 2)*. Last
            dimension is *(x, y)*.
        method (str):
            Interpolation method forwarded to ``interpn``. Either \n
            * ``'linear'``: Use linear interpolation (default).
            * ``'nearest'``: Use nearest interpolation.
            * ``'cubic'``: Use cubic interpolation.
        fill_value (Optional[float]):
            Value used for query points outside the interpolation domain.
            (Default is ``np.nan``)
        bounds_error (bool):
            If ``True``, a ValueError is raised when a query point lies
            outside the domain of ``remap_AB``. (Default is ``False``)

    Returns:
        (np.ndarray):
            remap_AC (np.ndarray):
                Remap field from image A to image C. shape: *(H, W, 2)*.
    """
    n_rows, n_cols, _ = remap_AB.shape

    ## Pack (x, y) into one complex number per pixel so that a single
    ##  interpolation call handles both components at once.
    packed_AB = remap_AB[:, :, 0] + remap_AB[:, :, 1] * 1j

    ## interpn's grid axes are (row, col) = (y, x), so flip the (x, y) queries.
    query_yx = remap_BC.reshape(-1, 2)[:, ::-1]
    grid_axes = (np.arange(n_rows), np.arange(n_cols))

    packed_AC = scipy.interpolate.interpn(
        grid_axes,
        packed_AB,
        query_yx,
        method=method,
        bounds_error=bounds_error,
        fill_value=fill_value,
    ).reshape(n_rows, n_cols)

    ## Unpack: real part is x, imaginary part is y
    return np.stack((packed_AC.real, packed_AC.imag), axis=-1)
def _make_idx_grid(
im: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
"""
Helper function to make a grid of indices for an image. Used in
``flowField_to_remappingIdx`` and ``remappingIdx_to_flowField``.
Args:
im (Union[np.ndarray, object]):
An image represented as a numpy ndarray or torch Tensor.
Returns:
(Union[np.ndarray, object]):
idx_grid (Union[np.ndarray, object]):
Index grid for the given image.
"""
if isinstance(im, torch.Tensor):
stack, meshgrid, arange = partial(torch.stack, dim=-1), partial(torch.meshgrid, indexing='xy'), partial(torch.arange, device=im.device, dtype=im.dtype)
elif isinstance(im, np.ndarray):
stack, meshgrid, arange = partial(np.stack, axis=-1), partial(np.meshgrid, indexing='xy'), partial(np.arange, dtype=im.dtype)
return stack(meshgrid(arange(im.shape[1]), arange(im.shape[0]))) # (H, W, 2). Last dimension is (x, y).
[docs]
def flowField_to_remappingIdx(
    ff: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
    """
    Turns a flow field into a remapping index. RH 2023

    The result is the element-wise sum of ``ff`` and the pixel index grid
    produced by ``_make_idx_grid`` for an array of the same height / width.

    **WARNING**: A flow field (per-pixel displacement) and a remapping index
    (per-pixel sampling coordinate) are conceptually different objects; this
    helper only performs the addition of the pixel grid.

    Args:
        ff (Union[np.ndarray, torch.Tensor]):
            Flow field describing the displacement of each pixel. shape:
            *(H, W, 2)*. Last dimension is *(x, y)*.

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            ri (Union[np.ndarray, torch.Tensor]):
                Remapping index of source pixel coordinates. shape:
                *(H, W, 2)*.
    """
    return ff + _make_idx_grid(ff)
[docs]
def remappingIdx_to_flowField(
    ri: Union[np.ndarray, object],
) -> Union[np.ndarray, object]:
    """
    Turns a remapping index into a flow field. RH 2023

    The result is ``ri`` minus the pixel index grid produced by
    ``_make_idx_grid`` for an array of the same height / width.

    **WARNING**: A remapping index (per-pixel sampling coordinate) and a
    flow field (per-pixel displacement) are conceptually different objects;
    this helper only performs the subtraction of the pixel grid.

    Args:
        ri (Union[np.ndarray, torch.Tensor]):
            Remapping index. shape: *(H, W, 2)*. Last dimension is
            *(x, y)*.

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            ff (Union[np.ndarray, torch.Tensor]):
                Flow field. shape: *(H, W, 2)*.
    """
    return ri - _make_idx_grid(ri)
[docs]
def cv2RemappingIdx_to_pytorchFlowField(
    ri: Union[np.ndarray, torch.Tensor]
) -> Union[np.ndarray, torch.Tensor]:
    """
    Converts remapping indices from the OpenCV format to the PyTorch
    ``grid_sample`` format. RH 2023

    In the OpenCV format, each entry is an absolute pixel coordinate measured
    from the top-left pixel of the image (x in ``[0, W-1]``, y in
    ``[0, H-1]``). In the PyTorch format, each entry is a normalized
    coordinate in ``[-1, 1]`` measured from the image center. The conversion
    is ``(ri / (size - 1) - 0.5) * 2`` applied per axis, which corresponds to
    ``align_corners=True`` in ``torch.nn.functional.grid_sample``.

    Args:
        ri (Union[np.ndarray, torch.Tensor]):
            Remapping indices. Each pixel describes the index of the pixel in
            the original image that should be mapped to the new pixel. Shape:
            *(H, W, 2)*. The last dimension is (x, y).

    Returns:
        (Union[np.ndarray, torch.Tensor]):
            normgrid (Union[np.ndarray, torch.Tensor]):
                "Flow field", in the PyTorch format. Technically not a flow
                field, since it doesn't describe displacement. Rather, it is
                a normalized remapping index relative to the center of the
                image. Shape: *(H, W, 2)*. The last dimension is (x, y).
                Same container type as ``ri``.
    """
    assert isinstance(ri, (np.ndarray, torch.Tensor)), f"ri must be a np.ndarray or torch.Tensor. Got {type(ri)}"

    ## (W, H): the last axis of ri is (x, y), so width must come first.
    size_wh = ri.shape[1::-1]
    if isinstance(ri, torch.Tensor):
        size_wh = torch.as_tensor(size_wh, dtype=torch.float32, device=ri.device)
    else:
        size_wh = np.asarray(size_wh, dtype=np.float32)

    ## Map pixel coordinates [0, size-1] -> [0, 1] -> [-0.5, 0.5] -> [-1, 1].
    ## NOTE: a dimension of size 1 divides by zero here.
    normgrid = ((ri / (size_wh - 1)) - 0.5) * 2
    return normgrid
[docs]
def remap_points(
    points: np.ndarray,
    remappingIdx: np.ndarray,
    interpolation: str = 'linear',
    fill_value: float = None,
) -> np.ndarray:
    """
    Looks up where a set of 2D points land under an image-warping index map.

    ``remappingIdx`` is wrapped in a
    ``scipy.interpolate.RegularGridInterpolator`` defined over the pixel grid
    *(row, column)* and evaluated at each point's *(y, x)* location.

    Args:
        points (np.ndarray):
            Points to be remapped. shape: *(n_points, 2)*, dtype:
            *floating*. Each row is an ``(x, y)`` coordinate.
        remappingIdx (np.ndarray):
            Index map describing the warp. shape: *(height, width, 2)*,
            dtype: *floating*.
        interpolation (str):
            Interpolation method passed to
            ``scipy.interpolate.RegularGridInterpolator``. One of \n
            * ``'linear'``
            * ``'nearest'``
            * ``'slinear'``
            * ``'cubic'``
            * ``'quintic'``
            * ``'pchip'`` \n
            (Default is ``'linear'``)
        fill_value (Optional[float]):
            Value returned for points outside the pixel grid. If ``None``,
            values outside the grid are extrapolated. (Default is ``None``)

    Returns:
        (np.ndarray):
            points_remap (np.ndarray):
                Remapped points. shape: *(n_points, 2)*.
    """
    ## Input validation: container type, rank, trailing size, and float dtype
    assert isinstance(points, np.ndarray), 'points must be a numpy.ndarray'
    assert points.ndim == 2, 'points must be a 2D numpy.ndarray'
    assert points.shape[1] == 2, 'points must be of shape (n_points, 2)'
    assert np.issubdtype(points.dtype, np.floating), 'points must be a float subtype'

    assert isinstance(remappingIdx, np.ndarray), 'remappingIdx must be a numpy.ndarray'
    assert remappingIdx.ndim == 3, 'remappingIdx must be a 3D numpy.ndarray'
    assert remappingIdx.shape[2] == 2, 'remappingIdx must be of shape (height, width, 2)'
    assert np.issubdtype(remappingIdx.dtype, np.floating), 'remappingIdx must be a float subtype'

    ## Pixel-grid axes of the index map (rows = y, columns = x)
    height, width = remappingIdx.shape[0], remappingIdx.shape[1]
    rows = np.arange(0., height).astype(np.float32)
    cols = np.arange(0., width).astype(np.float32)

    ## The interpolator is indexed (y, x), while points are stored as (x, y)
    interpolator = scipy.interpolate.RegularGridInterpolator(
        points=(rows, cols),
        values=remappingIdx,
        method=interpolation,
        bounds_error=False,
        fill_value=fill_value,
    )
    query_y, query_x = points[:, 1], points[:, 0]
    return interpolator(xi=(query_y, query_x))
##########################################################################################################################################
########################################################### RESOURCE TRACKING ############################################################
##########################################################################################################################################
import datetime
from threading import Timer
from pathlib import Path
import psutil
class _Device_Checker_Base():
    """
    Base class for periodic resource utilization tracking.

    Subclasses must implement ``check_utilization`` returning a dict of
    measurement values keyed by metric name.

    Args:
        verbose (int):
            Verbosity level. ``0`` silences output, ``1`` prints basic
            statements and warnings. (Default is ``1``)

    Attributes:
        log (Dict[str, Dict[int, Any]]):
            Per-metric history populated by ``log_utilization``. Created on
            the first call to ``log_utilization``.
        fn_timer (_RepeatTimer):
            Background repeating timer created by ``track_utilization``.
    """
    def __init__(self, verbose=1):
        """Stores the verbosity level for the checker."""
        self._verbose = int(verbose)

    def log_utilization(self, path_save=None):
        """
        Records one utilization snapshot, optionally appending to a CSV file.

        Args:
            path_save (Optional[Union[str, pathlib.Path]]):
                If provided, appends the snapshot as a row to a ``.csv``
                file at this path (creating it with a header if needed).
                (Default is ``None``)

        Returns:
            (Dict[str, Dict[int, Any]]):
                log (Dict[str, Dict[int, Any]]):
                    Updated history of all collected metrics. Outer keys are
                    metric names, inner keys are the snapshot counter.
        """
        info_changing = self.check_utilization()

        if not hasattr(self, 'log'):
            ## First snapshot: create one empty history per metric
            self.log = {key: {} for key in info_changing}
            self._iter_log = 0
            if self._verbose > 0:
                print(f'Created self.log with keys: {self.log.keys()}')
        else:
            assert hasattr(self, '_iter_log'), 'self.log exists but self._iter_log does not'
            self._iter_log += 1

        ## Populate with values
        for key, value in info_changing.items():
            self.log[key][self._iter_log] = value

        ## Save
        if path_save is not None:
            ## str() lets pathlib.Path objects through the suffix check
            path_save = str(path_save)
            assert path_save.endswith('.csv'), 'path_save must be a .csv file'
            columns = list(self.log.keys())
            needs_header = not Path(path_save).exists()
            ## Append mode creates the file if missing; header goes first
            with open(path_save, 'a') as f:
                if needs_header:
                    f.write(','.join(columns) + '\n')
                f.write(','.join([str(info_changing[key]) for key in columns]) + '\n')

        return self.log

    def track_utilization(
        self,
        interval=0.2,
        path_save=None,
    ):
        """
        Begins periodic logging on a background thread.

        Replaces any existing tracker and stores the new one as
        ``self.fn_timer``. ``log_utilization`` is then called every
        ``interval`` seconds until ``stop_tracking`` is called.

        Args:
            interval (float):
                Polling interval in seconds. ``0.2`` is the minimum useful
                value. (Default is ``0.2``)
            path_save (Optional[Union[str, pathlib.Path]]):
                ``.csv`` path to which each snapshot is appended. ``None``
                disables saving. (Default is ``None``)
        """
        self.stop_tracking()

        ## Closure so the timer callback carries path_save with it
        def log_utilization_thread():
            self.log_utilization(path_save=path_save)

        self.fn_timer = _RepeatTimer(interval, log_utilization_thread)
        self.fn_timer.start()

    def stop_tracking(self):
        """Cancels the background ``fn_timer`` thread (if one is running)."""
        if hasattr(self, 'fn_timer'):
            self.fn_timer.cancel()

    def __del__(self):
        """Stops the background tracker before the checker is garbage-collected."""
        self.stop_tracking()
[docs]
class NVIDIA_Device_Checker(_Device_Checker_Base):
    """
    Resource utilization checker for an NVIDIA GPU.
    Requires the ``nvidia-ml-py3`` package.

    Args:
        device_index (Optional[int]):
            Index of the GPU to monitor. If exactly one device is present,
            that device is used and this argument is ignored; otherwise it
            must be a valid index. (Default is ``None``)
        verbose (int):
            Verbosity level passed to the base class. (Default is ``1``)

    Attributes:
        info_static (Dict[str, Any]):
            Static device info (name, index, total memory, power limit).
        handle (object):
            ``nvidia_smi`` device handle for the monitored GPU.
        device_index (int):
            Index of the monitored GPU.
    """
    def __init__(self, device_index=None, verbose=1):
        """Initializes NVML, looks up device handles, and captures static device info."""
        try:
            import nvidia_smi
        except ImportError as e:
            raise ImportError('nvidia_smi package not found. Install with "pip install nvidia-ml-py3"') from e
        self.nvidia_smi = nvidia_smi

        super().__init__(verbose=verbose)

        ## Initialize
        nvidia_smi.nvmlInit()  ## This is needed to get device info

        ## Get device handles
        self._handles_allDevices = self.get_device_handles()
        n_device = len(self._handles_allDevices)
        assert n_device > 0, 'No NVIDIA devices were found by NVML'
        if n_device == 1:
            self.handle = self._handles_allDevices[0]
            self.device_index = 0
            if self._verbose > 0:
                print(f'Found one device. Setting self.device_index to 0.')
        else:
            assert isinstance(device_index, int), 'Device index must be specified since multiple devices were found'
            assert 0 <= device_index < n_device, f'Device index {device_index} is out of range for the number of devices found: {n_device}'
            ## Previously nothing was assigned on this path, so every later
            ##  use of self.handle failed with an AttributeError.
            self.device_index = device_index
            self.handle = self._handles_allDevices[device_index]

        ## Get static info
        self.info_static = {}
        self.info_static['device_name'] = nvidia_smi.nvmlDeviceGetName(self.handle)
        self.info_static['device_index'] = nvidia_smi.nvmlDeviceGetIndex(self.handle)
        self.info_static['memory_total'] = nvidia_smi.nvmlDeviceGetMemoryInfo(self.handle).total
        self.info_static['power_limit'] = nvidia_smi.nvmlDeviceGetPowerManagementLimit(self.handle)

    def get_device_handles(self):
        """Returns one ``nvmlDeviceGetHandleByIndex`` handle per GPU detected by NVML."""
        nvidia_smi = self.nvidia_smi
        return [nvidia_smi.nvmlDeviceGetHandleByIndex(i_device) for i_device in range(nvidia_smi.nvmlDeviceGetCount())]

    def check_utilization(self):
        """
        Returns a snapshot of the current GPU utilization metrics.

        Returns:
            (Dict[str, Any]):
                info_changing (Dict[str, Any]):
                    Includes ``time``, ``memory_free``, ``memory_used``,
                    ``memory_used_percentage``, ``power_used``,
                    ``power_used_percentage``, ``processor_used_percentage``,
                    ``temperature``, and ``fan_speed``.
        """
        nvidia_smi = self.nvidia_smi
        h = self.handle

        info_mem = nvidia_smi.nvmlDeviceGetMemoryInfo(h)

        info_changing = {}
        info_changing['time'] = datetime.datetime.now()
        info_changing['memory_free'] = info_mem.free
        info_changing['memory_used'] = info_mem.used
        info_changing['memory_used_percentage'] = 100 * info_mem.used / info_mem.total
        info_changing['power_used'] = nvidia_smi.nvmlDeviceGetPowerUsage(h)
        info_changing['power_used_percentage'] = 100 * info_changing['power_used'] / nvidia_smi.nvmlDeviceGetPowerManagementLimit(h)
        info_changing['processor_used_percentage'] = nvidia_smi.nvmlDeviceGetUtilizationRates(h).gpu
        info_changing['temperature'] = nvidia_smi.nvmlDeviceGetTemperature(h, nvidia_smi.NVML_TEMPERATURE_GPU)
        info_changing['fan_speed'] = nvidia_smi.nvmlDeviceGetFanSpeed(h)

        return info_changing

    def __del__(self):
        """Cancels any background tracker, then shuts down NVML, before destruction."""
        ## Stop the polling timer first so it is not left querying NVML
        ##  after shutdown.
        super().__del__()
        ## __init__ may have failed before self.nvidia_smi was assigned
        nvidia_smi = getattr(self, 'nvidia_smi', None)
        if nvidia_smi is not None:
            try:
                nvidia_smi.nvmlShutdown()  ## This stops the ability to get device info
            except Exception:
                ## Best-effort cleanup: NVML may never have been initialized
                ##  (failed __init__); a destructor must not raise.
                pass
[docs]
class CPU_Device_Checker(_Device_Checker_Base):
    """
    Resource utilization checker for the host CPU and disk.

    Args:
        verbose (int):
            Verbosity level passed to the base class. (Default is ``1``)

    Attributes:
        info_static (Dict[str, Any]):
            Static info (CPU count, frequency, total RAM, total disk).
    """
    def __init__(self, verbose=1):
        """Captures static CPU, memory, and disk info."""
        super().__init__(verbose=verbose)

        self.info_static = {}

        self.info_static['cpu_count'] = psutil.cpu_count()
        self.info_static['cpu_freq'] = psutil.cpu_freq()
        self.info_static['memory_total'] = psutil.virtual_memory().total
        self.info_static['disk_total'] = psutil.disk_usage('/').total

    def check_utilization(self):
        """
        Returns a snapshot of CPU, memory, network, and disk utilization.

        Each psutil source is queried once per snapshot so that related
        fields (e.g. memory used / free) come from the same reading. Fields
        that psutil does not expose on the current platform, and counters
        psutil cannot provide (it returns ``None`` on diskless machines or
        machines without network interfaces), are reported as ``None``
        rather than raising.

        Returns:
            (Dict[str, Any]):
                info_changing (Dict[str, Any]):
                    Per-snapshot metrics including memory, network I/O, disk
                    free/used, disk read/write throughput, and overall +
                    per-core CPU usage percentages.
        """
        info_changing = {}

        info_changing['time'] = datetime.datetime.now()

        ## Memory: single snapshot for all fields
        mem = psutil.virtual_memory()
        info_changing['memory_used_percentage'] = mem.percent
        info_changing['memory_used'] = mem.used
        info_changing['memory_free'] = mem.free
        info_changing['memory_available'] = mem.available
        ## These fields only exist on some platforms (e.g. buffers / cached /
        ##  shared are Linux/BSD only), so fall back to None elsewhere.
        for field in ('active', 'inactive', 'buffers', 'cached', 'shared'):
            info_changing[f'memory_{field}'] = getattr(mem, field, None)

        ## Network: cumulative bytes sent and received
        net = psutil.net_io_counters()
        info_changing['network_sent'] = net.bytes_sent if net is not None else None
        info_changing['network_received'] = net.bytes_recv if net is not None else None

        ## Disk usage of the root partition
        disk = psutil.disk_usage('/')
        info_changing['disk_free'] = disk.free
        info_changing['disk_used'] = disk.used
        info_changing['disk_used_percentage'] = disk.percent

        ## Disk I/O: cumulative bytes read and written
        disk_io = psutil.disk_io_counters()
        info_changing['disk_read'] = disk_io.read_bytes if disk_io is not None else None
        info_changing['disk_write'] = disk_io.write_bytes if disk_io is not None else None

        ## Processor utilization (overall and per core)
        info_changing['processor_used_percentage'] = psutil.cpu_percent()
        for i_core, val in enumerate(psutil.cpu_percent(percpu=True)):
            info_changing[f'cpu_{i_core}'] = val

        return info_changing
class _RepeatTimer(Timer):
    """
    ``threading.Timer`` variant that keeps invoking its callback.

    Instead of firing once after ``interval`` seconds, the callback is
    invoked after every ``interval``-second wait until ``cancel()`` sets the
    ``finished`` event.
    """
    def run(self):
        """Waits ``interval`` seconds, calls ``self.function``, and repeats until cancelled."""
        while True:
            ## wait() returns True as soon as the finished event is set
            cancelled = self.finished.wait(self.interval)
            if cancelled:
                break
            self.function(*self.args, **self.kwargs)
##########################################################################################################################################
########################################################### TESTING ######################################################################
##########################################################################################################################################
[docs]
class Equivalence_checker():
    """
    Class for checking if all items are equivalent or allclose (almost equal) in
    two complex data structures. Can check nested lists, dicts, and other data
    structures. Can also optionally assert (raise errors) if all items are not
    equivalent.
    RH 2023

    Attributes:
        _kwargs_allclose (dict):
            Keyword arguments for the `numpy.allclose` function.
        _assert_mode (bool):
            Whether to raise an assertion error if items are not close.
        _verbose (Union[bool, int]):
            Verbosity level.

    Args:
        kwargs_allclose (Optional[dict]):
            Keyword arguments for the `numpy.allclose` function. If ``None``,
            ``{'rtol': 1e-7, 'equal_nan': True}`` is used. (Default is
            ``None``)
        assert_mode (bool):
            Whether to raise an assertion error if items are not close.
        verbose (bool):
            How much information to print out:
                * ``False`` / ``0``: No information printed out.
                * ``True`` / ``1``: Mismatched items only.
                * ``2``: All items printed out.
    """
    def __init__(
        self,
        kwargs_allclose: Optional[dict] = None,
        assert_mode=False,
        verbose=False,
    ) -> None:
        """Stores comparison settings for later use by ``__call__``."""
        ## Build the default per instance so that instances never share (and
        ##  cannot accidentally mutate) one module-level dict.
        if kwargs_allclose is None:
            kwargs_allclose = {'rtol': 1e-7, 'equal_nan': True}
        self._kwargs_allclose = kwargs_allclose
        self._assert_mode = assert_mode
        self._verbose = verbose

    def _checker(
        self,
        test: Any,
        true: Any,
        path: Optional[List[str]] = None,
    ) -> Optional[bool]:
        """
        Compares the test and true values using numpy's allclose function.

        String, byte-string, and object arrays are compared with exact
        equality instead, since ``allclose`` cannot handle them.

        Args:
            test (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
            str, bool, None]):
                Test value to compare. Expected to expose a ``dtype``.
            true (Union[dict, list, tuple, set, np.ndarray, int, float, complex,
            str, bool, None]):
                True value to compare.
            path (Optional[List[str]]):
                The path of the data structure that is currently being compared.
                (Default is ``None``)

        Returns:
            (Optional[bool]):
                result (Optional[bool]):
                    ``True`` if all elements in test and true are close,
                    ``False`` if they are not, and ``None`` if the comparison
                    itself raised an error.

        Raises:
            AssertionError:
                If the comparison fails and ``assert_mode`` is enabled.
        """
        try:
            ## If the dtype is a kind of string (or byte string) or object, then allclose will raise an error. In this case, just check if the values are equal.
            if np.issubdtype(test.dtype, np.str_) or np.issubdtype(test.dtype, np.bytes_) or test.dtype == np.object_:
                out = bool(np.all(test == true))
                if self._verbose > 1:
                    print(f"Equivalence check {'passed' if out else 'failed'}. Path: {path}.")
            else:
                out = np.allclose(test, true, **self._kwargs_allclose)
                ## Only announce a pass when the check really passed; failures
                ##  are reported with details further below.
                if out and self._verbose > 1:
                    print(f"Equivalence check passed. Path: {path}")
        except Exception as e:
            out = None  ## This is not False because sometimes allclose will raise an error if the arrays have a weird dtype among other reasons.
            if self._verbose:
                warnings.warn(f"WARNING. Equivalence check failed. Path: {path}. Error: {e}")

        if out == False:
            if self._assert_mode:
                raise AssertionError(f"Equivalence check failed. Path: {path}.")
            if self._verbose:
                ## Describe the difference between the two values.
                ### IF the arrays are numeric, then calculate the relative difference
                dtypes_numeric = (np.number, np.bool_, np.integer, np.floating, np.complexfloating)
                if any([np.issubdtype(test.dtype, dtype) and np.issubdtype(true.dtype, dtype) for dtype in dtypes_numeric]):
                    ## numpy 1.25+ removed the `-` operator on bool arrays; cast first.
                    test_n = test.astype(np.int8) if test.dtype == bool else test
                    true_n = true.astype(np.int8) if true.dtype == bool else true
                    diff = np.abs(test_n - true_n)
                    r_diff = diff / np.abs(true_n)
                    r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff))
                    print(f"Equivalence check failed. Path: {path}. Relative difference: mean={r_diff_mean}, max={r_diff_max}, any_nan={any_nan}")
                else:
                    print(f"Equivalence check failed. Path: {path}. Value is non-numerical.")
        return out

    def __call__(
        self,
        test: Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None],
        true: Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None],
        path: Optional[List[str]] = None,
    ) -> Union[Tuple[Optional[bool], str], Dict[str, Any]]:
        """
        Recursively compares ``test`` to ``true`` and returns the result.

        Handles dictionaries, iterables, ``np.ndarray``, scalars, strings,
        numbers, booleans, and ``None``. Entries whose last path element
        starts with ``'_'`` are skipped.

        Args:
            test (Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None]):
                Test value to compare.
            true (Union[dict, list, tuple, set, np.ndarray, int, float, complex, str, bool, None]):
                Reference value. Its type selects the comparison strategy.
            path (Optional[List[str]]):
                Hierarchical path within the data structure currently being
                compared (used in recursion). (Default is ``None``)

        Returns:
            (Union[Tuple[Optional[bool], str], Dict[str, Any]]):
                result (Union[Tuple[Optional[bool], str], Dict[str, Any]]):
                    For leaf values, a ``(passed, reason)`` tuple. For
                    container values, a nested dictionary keyed by the
                    container's keys / indices.
        """
        if path is None:
            path = ['']

        if len(path) > 0:
            if path[-1].startswith('_'):
                return (None, 'excluded from testing')

        ## NP.NDARRAY
        if isinstance(true, np.ndarray):
            r = self._checker(test, true, path)
            result = (r, 'equivalence')
        ## NP.SCALAR
        ##  NOTE: np.isscalar is also True for python int / float / complex /
        ##  str / bytes / bool, so those values are handled here and never
        ##  reach the NUMBER / STRING / BOOL branches below.
        elif np.isscalar(true):
            if isinstance(test, (int, float, complex, np.number)):
                r = self._checker(np.array(test), np.array(true), path)
                result = (r, 'equivalence')
            else:
                result = (test == true, 'equivalence')
        ## NUMBER
        elif isinstance(true, (int, float, complex)):
            ## Wrap in arrays so _checker can read a dtype, and report the
            ##  checker's outcome (this branch used to reference an
            ##  undefined variable).
            r = self._checker(np.array(test), np.array(true), path)
            result = (r, 'equivalence')
        ## DICT
        elif isinstance(true, dict):
            result = {}
            for key in true:
                if key not in test:
                    result[str(key)] = (False, 'key not found')
                    if self._verbose > 0:
                        print(f"Equivalence check failed. Path: {path}. Key {key} not found.")
                else:
                    result[str(key)] = self.__call__(test[key], true[key], path=path + [str(key)])
        ## ITERATABLE
        elif isinstance(true, (list, tuple, set)):
            if len(true) != len(test):
                result = (False, 'length_mismatch')
                if self._verbose > 0:
                    print(f"Equivalence check failed. Path: {path}. Length mismatch.")
            else:
                result = {}
                for idx, (i, j) in enumerate(zip(test, true)):
                    result[str(idx)] = self.__call__(i, j, path=path + [str(idx)])
        ## STRING
        elif isinstance(true, str):
            result = (test == true, 'equivalence')
            if self._verbose > 0:
                print(f"Equivalence check {'passed' if result[0] else 'failed'}. Path: {path}.")
        ## BOOL
        elif isinstance(true, bool):
            result = (test == true, 'equivalence')
            if self._verbose > 0:
                print(f"Equivalence check {'passed' if result[0] else 'failed'}. Path: {path}.")
        ## NONE
        elif true is None:
            result = (test is None, 'equivalence')
            if self._verbose > 0:
                print(f"Equivalence check {'passed' if result[0] else 'failed'}. Path: {path}.")
        ## N/A
        else:
            result = (None, 'not tested')
            if self._verbose > 0:
                print(f"Equivalence check not performed. Path: {path}.")

        return result
##########################################################################################################################################
######################################################### SIMILARITY #####################################################################
##########################################################################################################################################
[docs]
def order_cp_factors_by_EVR(
    tensor_dense: Union[np.ndarray, torch.Tensor],
    cp_factors: Union[list, object],
    cp_weights: Optional[Union[np.ndarray, torch.Tensor]] = None,
    orthogonalizable_EVR: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Sorts CP factors by descending explained variance ratio. RH 2024

    Each rank-1 component is reconstructed on its own and scored against
    ``tensor_dense``. ``tensor_dense`` is not modified.

    Args:
        tensor_dense (Union[np.ndarray, torch.Tensor]):
            Dense tensor to be reconstructed.
        cp_factors (Union[list, object]):
            CP factors. Either a list of 2D factor matrices of shape
            *(n_samples, rank)* or a ``tensorly.CPTensor`` object.
        cp_weights (Optional[Union[np.ndarray, torch.Tensor]]):
            Per-rank weights of length *(rank,)*. Only used when
            ``cp_factors`` is a list; a ``CPTensor`` supplies its own
            weights. (Default is ``None``)
        orthogonalizable_EVR (bool):
            If ``True``, optimizes each factor's scaling to maximize EVR by
            OLS-orthogonalizing the dense tensor against each factor.
            If ``False``, ``cp_reconstruction_EVR`` is used for each
            component. (Default is ``True``)

    Returns:
        (tuple): tuple containing:
            order (np.ndarray):
                Indices that sort the factors by descending EVR.
            evrs (np.ndarray):
                Sorted explained variance ratios.

    Raises:
        ValueError:
            If ``cp_factors`` is neither a list nor a tensorly ``CPTensor``.
    """
    if isinstance(cp_factors, list):
        factors = list(cp_factors)
        weights = cp_weights
    elif isinstance(cp_factors, tl.cp_tensor.CPTensor):
        factors = list(cp_factors.factors)
        weights = cp_factors.weights
    else:
        raise ValueError('cp_factors must be a list of factors or a tensorly CPTensor object')

    if weights is not None:
        ## Weights are per rank (one per column) and scale each rank-1
        ##  component exactly once, so fold them into a single factor only.
        ##  Folding them into every factor would scale component r by
        ##  weights[r] ** n_modes.
        factors[0] = factors[0] * weights[None, :]

    if orthogonalizable_EVR:
        ## Flatten and mean-center out of place: reshape can return a view,
        ##  so an in-place subtraction would modify the caller's tensor.
        tensor_dense = tensor_dense.reshape(-1)
        tensor_dense = tensor_dense - tensor_dense.mean()
        tensor_dense_var = tensor_dense.var()

    ## Rank is the number of columns of any factor matrix
    rank = factors[0].shape[1]
    evrs = []
    for ii in range(rank):
        ## ii-th column of every factor, kept 2D: (n_samples, 1)
        component = [f[:, ii][:, None] for f in factors]
        if orthogonalizable_EVR:
            ## Dense rank-1 reconstruction, flattened and mean-centered
            v2 = cp_to_dense(component).reshape(-1)
            v2 = v2 - v2.mean()
            ## Residual of tensor_dense after OLS projection onto v2.
            ## NOTE: an all-constant component gives v2 == 0 and an EVR of nan.
            v1_orth = tensor_dense - ((tensor_dense * v2).sum() / (v2 * v2).sum()) * v2
            evr = 1 - (v1_orth.var() / tensor_dense_var)
        else:
            evr = cp_reconstruction_EVR(tensor_dense, component)
        ## Plain floats so numpy can sort results from either backend
        evrs.append(float(evr))

    ## Descending order of EVR
    order = np.argsort(evrs)[::-1]
    return order, np.array(evrs)[order]
[docs]
def cp_reconstruction_EVR(tensor_dense, tensor_CP):
    """
    Explained variance ratio of a CP-reconstructed tensor. RH 2023

    Args:
        tensor_dense (Union[np.ndarray, torch.Tensor]):
            Dense reference tensor. shape: *(n_samples, n_features)*.
        tensor_CP (Union[list, object]):
            CP tensor. Either a list of 2D factor matrices of shape
            *(n_samples, rank)* or a ``tensorly.CPTensor`` object.

    Returns:
        (Union[float, torch.Tensor]):
            ev (Union[float, torch.Tensor]):
                Explained variance ratio
                ``1 - var(tensor_dense - tensor_rec) / var(tensor_dense)``.

    Raises:
        ImportError:
            If ``tensor_CP`` is not a list of factors and tensorly is not
            installed.
        TypeError:
            If ``tensor_dense`` is neither a ``np.ndarray`` nor a
            ``torch.Tensor``.
    """
    ## tensorly is only required for CPTensor inputs; a plain list of
    ##  factors must keep working without it.
    try:
        import tensorly as tl
    except ImportError:
        tl = None

    tensor_rec = None
    if tl is not None and isinstance(tensor_CP, tl.cp_tensor.CPTensor):
        tensor_rec = tl.cp_to_tensor(tensor_CP)

    if tensor_rec is None:
        if tl is None and not isinstance(tensor_CP, list):
            raise ImportError('tensorly not installed. Please install tensorly or provide a list of factors as the tensor_CP argument.')
        assert isinstance(tensor_CP, list), 'tensor_CP must be a list of factors'
        assert all([isinstance(f, (np.ndarray, torch.Tensor)) for f in tensor_CP]), 'tensor_CP must be a list of factors'
        tensor_rec = cp_to_dense(tensor_CP)

    ## Pick the variance function matching the container type. torch.var is
    ##  unbiased and np.var is biased by default, but the correction factor
    ##  cancels in the ratio below.
    if isinstance(tensor_dense, torch.Tensor):
        var = torch.var
    elif isinstance(tensor_dense, np.ndarray):
        var = np.var
    else:
        raise TypeError(f'tensor_dense must be a np.ndarray or torch.Tensor. Got {type(tensor_dense)}')

    ev = 1 - (var(tensor_dense - tensor_rec) / var(tensor_dense))
    return ev
#####################################################################################################################################
######################################################### OTHER #####################################################################
#####################################################################################################################################
import torch
from typing import Union
[docs]
def rolling_mean(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Computes the mean along ``dim`` with an incremental running update. RH 2025

    The tensor is traversed one slice at a time along ``dim`` and the mean is
    refined with ``mean += (x - mean) / count``, so only a single slice-sized
    accumulator is kept.

    Args:
        tensor (torch.Tensor):
            Input tensor.
        dim (int):
            Dimension along which the running mean is accumulated. Negative
            values count from the last dimension.

    Returns:
        (torch.Tensor):
            mean (torch.Tensor):
                Final mean across ``dim`` (last accumulated value). ``None``
                if ``tensor`` has no entries along ``dim``.

    Raises:
        ValueError:
            If ``dim`` is out of range for ``tensor``.
    """
    n_dims = tensor.dim()

    ## Resolve a negative dim, then validate
    axis = dim + n_dims if dim < 0 else dim
    if not (0 <= axis < n_dims):
        raise ValueError(f"Invalid dimension {axis} for tensor with {n_dims} dimensions.")

    ## Bring the target axis to the front so iteration yields its slices
    axes_order = [axis] + [d for d in range(n_dims) if d != axis]

    running = None
    for count, chunk in enumerate(tensor.permute(axes_order), start=1):
        if running is None:
            ## First slice: the mean of one element is the element itself
            running = chunk
        else:
            running = running + (chunk - running) / count
    return running
## Video helpers — ported from bnpm.video (bnpm 0.7.1, RH 2021/2024).
## Used by the demo_event_alignment notebook for tiled trial playback.
## bnpm is a personal-utility repo; these are the only pieces face-rhythm needs.
[docs]
def play_video_cv2(
    array=None,
    path_video=None,
    frameRate=30,
    path_save=None,
    show=True,
    fourcc_code='MJPG',
    text=None,
    kwargs_text=None,
):
    """
    Plays or saves a video using OpenCV. RH 2021/2024

    Frames come either from ``array`` or, when ``array`` is ``None``, from
    ``path_video`` read with ``decord``. Each frame is optionally annotated
    with text, shown in a ``cv2`` window, and / or written to ``path_save``.
    The input ``array`` is never modified.

    Args:
        array (Optional[np.ndarray]):
            3D ``(frames, H, W)`` or 4D ``(frames, H, W, channels)`` array.
            Non-``uint8`` data is clipped to ``[0, 255]`` and cast to
            ``uint8`` frame by frame. 4D input is treated as RGB and
            converted to BGR for OpenCV. If ``None``, ``path_video`` must be
            supplied and ``decord`` is used to read it. (Default is ``None``)
        path_video (Optional[Union[str, pathlib.Path]]):
            Path to a video file. Used only when ``array`` is ``None``.
            (Default is ``None``)
        frameRate (float):
            Playback / output frame rate in Hz. (Default is ``30``)
        path_save (Optional[Union[str, pathlib.Path]]):
            Destination path for the saved video. ``None`` disables saving.
            (Default is ``None``)
        show (bool):
            If ``True``, displays the video in a ``cv2`` window. (Default
            is ``True``)
        fourcc_code (str):
            FourCC codec string passed to ``cv2.VideoWriter_fourcc``.
            (Default is ``'MJPG'``)
        text (Optional[Union[str, List[str]]]):
            Text overlay. If a list, element ``i`` is drawn on frame ``i``,
            so it needs at least one entry per frame. (Default is ``None``)
        kwargs_text (Optional[dict]):
            Keyword arguments forwarded to ``cv2.putText``. Entries given
            here override the built-in defaults (``org=(5, 15)``,
            ``fontFace=1``, ``fontScale=1``, ``color=(255, 255, 255)``,
            ``thickness=1``). ``None`` or ``{}`` uses the defaults only.
            (Default is ``None``)

    Raises:
        ValueError:
            If neither ``array`` nor ``path_video`` is given, or if
            ``array`` is not 3D or 4D.
        ImportError:
            If ``array`` is ``None`` and ``decord`` is not installed.
    """
    if array is None and path_video is None:
        raise ValueError("Either array or path_video must be provided.")

    ## cv2.waitKey takes whole milliseconds and treats 0 as 'wait forever'.
    wait_frames = max(int((1 / frameRate) * 1000), 1)

    ## Defaults are always present so that putText receives its required
    ##  arguments; caller-supplied entries take precedence.
    kwargs_text = {
        'org': (5, 15),
        'fontFace': 1,
        'fontScale': 1,
        'color': (255, 255, 255),
        'thickness': 1,
        **(kwargs_text or {}),
    }

    if array is not None:
        if array.ndim not in (3, 4):
            raise ValueError('Unsupported number of channels, check array shape')
        movie = array
        ## 4D input carries a channel axis that is interpreted as RGB.
        flag_rgb_to_bgr = (array.ndim == 4)
    else:
        try:
            import decord
        except ImportError as e:
            _pkg = "eva_decord" if sys.platform.startswith("win") else "decord2"
            raise ImportError(
                "decord is required when array=None. "
                f"Install with: pip install {_pkg}"
            ) from e
        movie = decord.VideoReader(str(path_video))
        ## decord decodes to RGB, whereas cv2 display / writing expects BGR.
        flag_rgb_to_bgr = True

    writer = None
    window_open = False
    try:
        for i_frame, frame in enumerate(tqdm(movie)):
            if array is None:
                frame = frame.asnumpy()
            elif frame.dtype != np.uint8:
                ## Per-frame clip + cast: avoids mutating the caller's array
                ##  and avoids a full-size temporary copy.
                frame = np.clip(frame, 0, 255).astype(np.uint8)

            if flag_rgb_to_bgr:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

            if text is not None:
                text_frame = text[i_frame] if isinstance(text, list) else text
                ## putText draws in place; copy so a view into ``array`` is
                ##  never written to.
                frame = cv2.putText(frame.copy(), text_frame, **kwargs_text)

            if show:
                cv2.imshow('handle', frame)
                cv2.waitKey(wait_frames)
                window_open = True

            if path_save is not None:
                if writer is None:
                    ## Created lazily so the frame size is known for both the
                    ##  array and the decord input paths. cv2 wants (W, H).
                    size = (int(frame.shape[1]), int(frame.shape[0]))
                    fourcc = cv2.VideoWriter_fourcc(*fourcc_code)
                    print(f'saving to file {path_save}')
                    writer = cv2.VideoWriter(str(path_save), fourcc, frameRate, size)
                    if not writer.isOpened():
                        warnings.warn(
                            f"cv2.VideoWriter could not open '{path_save}' with "
                            f"fourcc '{fourcc_code}'; no frames will be written."
                        )
                ## The writer is opened in colour mode (cv2 default), which
                ##  expects 3-channel frames.
                if frame.ndim == 2:
                    frame_out = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
                else:
                    frame_out = frame
                writer.write(frame_out)
    finally:
        ## Always release OS resources, even if a frame failed mid-loop.
        if writer is not None:
            writer.release()
        if window_open:
            cv2.destroyWindow('handle')

    if writer is not None:
        print('Video saved')
[docs]
def make_tiled_video_array(
    videos: List[np.ndarray],
    shape: Optional[Tuple[int, int]] = None,
    verbose: bool = True,
):
    """
    Tiles a list of videos into a single grid video array. RH 2021/2024

    Videos are placed top-to-bottom and then left-to-right. Every grid row is
    as tall as its tallest video and every grid column as wide as its widest
    video; unused area and frames past the end of a shorter video stay zero.
    The input list and its arrays are not modified.

    Args:
        videos (List[np.ndarray]):
            List of video arrays with shape *(frames, H, W, channels)* or
            *(frames, H, W)*. All videos must share the same dtype. 3D
            videos are treated as having a single channel.
        shape (Optional[Tuple[int, int]]):
            Grid layout ``(n_rows, n_cols)``. If ``None``, uses
            ``n_cols = ceil(sqrt(n_videos))`` and the fewest rows that fit
            all videos. (Default is ``None``)
        verbose (bool):
            If ``True``, prints progress messages. (Default is ``True``)

    Returns:
        (np.ndarray):
            video_array (np.ndarray):
                Tiled video. shape: *(max_frames, total_H, total_W,
                channels)*, where ``channels`` is taken from the first
                video.

    Raises:
        ValueError:
            If ``videos`` is empty.
        AssertionError:
            If the inputs fail type / dtype / dimensionality checks, or if
            ``shape`` has fewer cells than there are videos.
    """
    assert isinstance(videos, list), f"videos must be a list. Got {type(videos)}"
    if len(videos) == 0:
        raise ValueError("videos must contain at least one video.")
    assert all(isinstance(v, np.ndarray) for v in videos), \
        f"All elements of videos must be numpy arrays. Got {[type(v) for v in videos]}"
    assert all(v.dtype == videos[0].dtype for v in videos), \
        f"All videos must have the same dtype. Got {[v.dtype for v in videos]}"
    if shape is not None:
        assert isinstance(shape, tuple), f"shape must be a tuple. Got {type(shape)}"
        assert len(shape) == 2, f"shape must be a 2-tuple. Got {len(shape)} elements"
        assert all(isinstance(val, int) for val in shape), \
            f"shape must contain only integers. Got {shape}"

    n_videos = len(videos)
    if shape is None:
        ## floor(sqrt(n)) rows is too few for e.g. n=3 or n=7, so derive the
        ##  row count from the column count instead.
        n_width = int(np.ceil(np.sqrt(n_videos)))
        n_height = int(np.ceil(n_videos / n_width))
        shape = (n_height, n_width)
    assert shape[0] * shape[1] >= n_videos, \
        f"shape[0] * shape[1] must be >= number of videos. Got {shape[0] * shape[1]} < {n_videos}"
    if verbose:
        print(f"Making video array with shape: {shape} videos")

    ## Work on a new list so the caller's list is left untouched.
    videos_4d = []
    for ii, video in enumerate(videos):
        assert video.ndim in [3, 4], \
            f"videos[{ii}] must be 3D or 4D. Got {video.ndim} dimensions"
        videos_4d.append(video[..., None] if video.ndim == 3 else video)

    n_frames, heights, widths, channels = (
        np.array([video.shape[ii] for video in videos_4d]) for ii in range(4)
    )
    max_frames = max(n_frames)
    if verbose:
        print(f"Max video length and total frames in output video array: {max_frames}")
    dtype = videos_4d[0].dtype
    if verbose:
        print(f"Video dtype: {dtype}")

    ## (row, col) of each video; filling rows first within a column gives the
    ##  top-to-bottom, then left-to-right order.
    idx_videoArray_videos = np.array(
        [(ii % shape[0], ii // shape[0]) for ii in range(n_videos)]
    )
    if verbose:
        print(f"Video array indices: {idx_videoArray_videos}")

    ## Each column / row is sized by its largest member.
    widths_col = np.array([
        np.max(widths[idx_videoArray_videos[:, 1] == ii])
        for ii in np.unique(idx_videoArray_videos[:, 1])
    ])
    heights_row = np.array([
        np.max(heights[idx_videoArray_videos[:, 0] == ii])
        for ii in np.unique(idx_videoArray_videos[:, 0])
    ])

    final_shape = (max_frames, np.sum(heights_row), np.sum(widths_col), channels[0])
    video_array = np.zeros(final_shape, dtype)
    if verbose:
        print(f"Final output array shape: {video_array.shape}")

    ## Top / left pixel offset of every grid row / column.
    idx_tops_rows = np.cumsum(np.concatenate(([0], heights_row)))[:-1]
    idx_tops_cols = np.cumsum(np.concatenate(([0], widths_col)))[:-1]
    if verbose:
        print(f"idx_tops_rows: {idx_tops_rows}, idx_tops_cols: {idx_tops_cols}")

    for ii, video in enumerate(videos_4d):
        idx_top = idx_tops_rows[idx_videoArray_videos[ii, 0]]
        idx_left = idx_tops_cols[idx_videoArray_videos[ii, 1]]
        video_array[
            :video.shape[0],
            idx_top:idx_top + video.shape[1],
            idx_left:idx_left + video.shape[2],
            :,
        ] = video
    return video_array