Source code for hswfs.zernike

"""
This module provides methods related to Zernike polynomials and their
derivatives.
"""

# -----------------------------------------------------------------------------
# IMPORTS
# -----------------------------------------------------------------------------

from copy import deepcopy
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
import sympy as sy


# -----------------------------------------------------------------------------
# AUXILIARY FUNCTIONS
# -----------------------------------------------------------------------------

[docs]def mn_to_j(m: int, n: int) -> int: r""" Map the indices :math:`m, n` from the double-indexing scheme to the corresponding index :math:`j` of the single-indexing scheme. Basically, we are just counting the Zernike polynomials in a well-defined fashion. Mathematically, the mapping is given by: .. math:: j = \frac{n \cdot (n + 2) + m}{2} Args: m: Index :math:`m` of :math:`Z^m_n`. n: Index :math:`m` of :math:`Z^m_n`. Returns: The single index :math:`j` which corresponds to :math:`m, n`. """ return int((n * (n + 2) + m) / 2)
[docs]def j_to_mn(j: int) -> Tuple[int, int]: r""" Map the index :math:`j` of the single-indexing scheme to the corresponding indices :math:`m, n` from the double-indexing scheme. Mathematically, the mapping is given by: .. math:: n = \left\lceil (-3 + \sqrt{9 + 8 j})\, /\, 2 \right\rceil \quad \text{and} \quad m = 2 j - n \cdot (n + 2) Args: j: Index :math:`j` of :math:`Z_j`. Returns: The pair of indices :math:`m, n` which correspond to :math:`j`. """ n = int(np.ceil((-3 + np.sqrt(9 + 8 * j)) / 2)) m = int(2 * j - n * (n + 2)) return m, n
[docs]def polar_to_cartesian(expression: sy.Expr) -> sy.Expr: r""" Convert a sympy expression (`sy.Expr`) from polar coordinates to Cartesian coordinates by substituting :math:`\rho` and :math:`\phi` by the appropriate functions of :math:`x` and :math:`y`. Mathematically, the coordinate transformation is given by the following substitutions: .. math:: \rho = \sqrt{x^2 + y^2} \quad \text{and} \quad \phi = \arctan\left( \frac{y}{x} \right) Args: expression: A sympy expression in polar coordinates, i.e., it must contain two free symbols named "rho" and "phi". Returns: The original `expression`, converted to Cartesian coordinates. """ # Define symbols for polar and cartesian coordinates rho, phi = sy.symbols("rho"), sy.symbols("phi") x, y = sy.symbols("x"), sy.symbols("y") # Define coordinate transformation between polar and cartesian substitute_rho = sy.sqrt(x**2 + y**2) substitute_phi = sy.atan2(y, x) # Substitute rho and phi by the respective functions of x and y result = deepcopy(expression) result = result.subs(rho, substitute_rho) result = result.subs(phi, substitute_phi) return result
[docs]def derive(expression: sy.Expr, wrt: Union[str, sy.Symbol]) -> sy.Expr: """ Compute the derivative of `expression` with respect to `wrt`. Args: expression: A sympy expression. wrt: The variable with respect to which to take the derivative. Can either be a `sy.Symbol` or a string containing the name of the variable. Returns: The derivative of `expression` with respect to `wrt`. """ # Define type hint for the result derivative: sy.Expr # If wrt is a sympy symbol, and is part of the expression's free symbols, # we can directly take the derivative of the expression w.r.t. wrt if isinstance(wrt, sy.Symbol) and (wrt in expression.free_symbols): derivative = sy.diff(expression, wrt) # If wrt is a string, we check if there is a symbol in the free symbols of # the expression whose name matches wrt, and then we take the derivative # w.r.t. to this symbol elif isinstance(wrt, str): for symbol in expression.free_symbols: if wrt == symbol.name: # type: ignore derivative = sy.diff(expression, symbol) break else: raise ValueError( f"Can't differentiate {expression} w.r.t. {wrt} - not found!" ) # In every other case, the derivative is simply 0 else: derivative = sy.sympify(0) return derivative
[docs]def is_cartesian(expression: sy.Expr) -> bool: """ Check if a given `expression` is in Cartesian coordinates, that is, if the names of its free symbols are a subset of `{"x", "y"}`. Args: expression: A sympy expression. Returns: True if `expression` is in Cartesian coordinates; else False. """ # noinspection PyUnresolvedReferences symbols = {_.name for _ in expression.free_symbols} return symbols.issubset({"x", "y"})
[docs]def is_polar(expression: sy.Expr) -> bool: """ Check if a given `expression` is in polar coordinates, that is, if the names of its free symbols are a subset of `{"rho", "phi"}`. Args: expression: A sympy expression. Returns: True if `expression` is in polar coordinates; else False. """ # noinspection PyUnresolvedReferences symbols = {_.name for _ in expression.free_symbols} return symbols.issubset({"rho", "phi"})
[docs]def eval_cartesian( expression: sy.Expr, x_0: Union[float, np.ndarray], y_0: Union[float, np.ndarray], ) -> Union[float, np.ndarray]: """ Evaluate an expression that is in in Cartesian coordinates, either at a single position or on a grid of positions. Args: expression: A sympy expression. x_0: The value(s) of :math:`x` at which to evaluate the given `expression`. This can either be a single float, or an array of arbitrary size (its shape, however, must match `y_0`). y_0: The value(s) of :math:`y` at which to evaluate the given `expression`. This can either be a single float, or an array of arbitrary size (its shape, however, must match `x_0`). Returns: The value of `expression` at the given position(s). The type and shape of the output matches the one of the input: for `x_0`, `y_0` as floats, a float is returned; for numpy array inputs, a numpy array is returned. """ # Make sure that expression is a function of Cartesian coordinates assert is_cartesian( expression ), '"expression" is not in Cartesian coordinates!' # Make sure that x_0 and y_0 have compatible shapes assert (isinstance(x_0, float) and isinstance(y_0, float)) or ( isinstance(x_0, np.ndarray) and isinstance(y_0, np.ndarray) and x_0.shape == y_0.shape ), ( '"x_0" and "y_0" must be either both float, or both numpy array ' "with the same shape!" ) # If the expression is not constant, we can use sympy.lambdify() to # generate a numpy version of the expression, which can be used to # evaluate the function efficiently: if not expression.is_constant(): numpy_func: Callable[ ..., Union[float, np.ndarray] ] = sy.utilities.lambdify( args=sy.symbols("x, y"), expr=expression, modules="numpy" ) # Otherwise, that is, if the expression is constant, we need to define # the evaluation function manually because the result of sympy.lambdify() # does not behave as desired (it does not vectorize properly). else: # The multiplication with _ / _ makes sure that everything that is NaN # in the input also is NaN in the output; non-NaN values are unchanged def numpy_func(_: float, __: float) -> float: return float(expression) * _ / _ * __ / __ numpy_func = np.vectorize(numpy_func) return numpy_func(x_0, y_0)
# ----------------------------------------------------------------------------- # CLASSES # -----------------------------------------------------------------------------
[docs]class ZernikePolynomial: """ Implements the Zernike polynomial :math:`Z^m_n` (in double-index notation), or :math:`Z_j` (in single-index notation). """ def __init__( self, m: Optional[int] = None, n: Optional[int] = None, j: Optional[int] = None, ): # Make sure that we have received *either* (m, n) *or* j error_msg = ( "ZernikePolynomial must be instantiated either with " "double indices (m, n) *or* a single index j!" ) if j is not None: assert m is None, error_msg assert n is None, error_msg self.j = j self.m, self.n = j_to_mn(self.j) else: assert m is not None, error_msg assert n is not None, error_msg self.m, self.n = m, n self.j = mn_to_j(self.m, self.n) # Run basic sanity checks on inputs assert ( -self.n <= self.m <= self.n ), "Zernike polynomials are only defined for -n <= m <= n!" assert self.j >= 0, "Zernike polynomials are only defined for j >= 0!" def __repr__(self) -> str: return f"Z^{self.m}_{self.n}" @property def radial_part(self) -> sy.Expr: r""" The radial polynomial :math:`R^m_n`, which is given by: .. math:: R^m_n ( \rho ) = \sum_{k=0}^{\frac{n-m}{2}} (-1)^{k} \, {{n - k} \choose {k}} \, {{n - 2k} \choose {\frac{n - m}{2} - k}} \, \rho^{n - 2k} Returns: The radial polynomial :math:`R^m_n`. """ # Define a symbol for the radius (rho) rho = sy.Symbol("rho") # Define type hint for the result result: sy.Expr # If n - m is odd, the radial polynomial is simply 0 if (self.n - self.m) % 2 == 1: result = sy.sympify(0) # Otherwise, things are a little more complicated else: result = sum( sy.Pow(-1, k) * sy.binomial(int(self.n - k), int(k)) * sy.binomial( int(self.n - 2 * k), int((self.n - self.m) / 2 - k) ) * sy.Pow(rho, self.n - 2 * k) for k in range(0, int((self.n - self.m) / 2) + 1) ) return result @property def azimuthal_part(self) -> sy.Expr: r""" The azimuthal component of :math:`Z^m_n`, which is given by: .. math:: \Phi_m ( \phi ) = \begin{cases} \cos( m \, \phi ) & \text{for}\ m \geq 0 \\ \sin( m \, \phi ) & \text{for}\ m < 0 \end{cases} Returns: The azimuthal part of :math:`Z^m_n`. """ # Define a symbol for the azimuthal angle (phi) phi = sy.Symbol("phi") # Define type hint for the result result: sy.Expr # Return the azimuthal part, which depends only on the value of m if self.m > 0: result = sy.cos(self.m * phi) elif self.m < 0: result = sy.sin(-self.m * phi) else: result = sy.sympify(1) return result @property def normalization(self) -> sy.Expr: r""" The normalization factor of :math:`Z^m_n`. Zernike polynomials are normalized such that: .. math:: \int_0^1 d\rho \int_0^{2\pi}d\phi \ Z^2(\rho, \phi) \, \rho = \pi .. note:: Note that this choice of normalization is not universal, and that some authors choose different normalizations; for example, they normalize the above integral to 1 instead of :math:`\pi`. Returns: The normalization factor of :math:`Z^m_n`. """ # Define type hint for the result result: sy.Expr if self.m == 0: result = sy.sqrt(self.n + 1) else: result = sy.sqrt(self.n + 1) * sy.sqrt(2) return result @property def polar(self) -> sy.Expr: r""" :math:`Z^m_n(\rho, \phi)`, that is, the full Zernike polynomial in polar coordinates :math:`\rho, \phi`. Returns: The Zernike polynomial :math:`Z^m_n(\rho, \phi)`. """ result: sy.Expr = ( self.normalization * self.radial_part * self.azimuthal_part ) return result @property def cartesian(self) -> sy.Expr: r""" :math:`Z^m_n(x, y)`, that is, the full Zernike polynomial in Cartesian coordinates :math:`x, y`. Returns: The Zernike polynomial :math:`Z^m_n(x, y)`. """ return polar_to_cartesian(self.polar) @property def fourier_transform(self) -> sy.Expr: r""" The 2D Fourier transform of :math:`Z^m_n`. This function essentially implements eq. (7) of [Tatulli_2013]_. .. note:: Compared to [Tatulli_2013]_, we are using a slightly different notation. Instead of indexing the dimensions of the Fourier space (or :math:`k`-space) using :math:`\kappa` and :math:`\alpha`, we use :math:`k_1` and :math:`k_2`. Returns: The 2D Fourier transform of :math:`Z^m_n`, that is, :math:`\mathcal{F}\lbrace Z^m_n \rbrace (k_1, k_2)`. """ # Define symbols for k1 and k2 k1 = sy.Symbol("k1") k2 = sy.Symbol("k2") # Define the first factor, which only depends on n factor_1 = ( sy.Pow(-1, self.n) * sy.sqrt(self.n + 1) / (sy.pi * k1) * sy.besselj(2 * sy.pi * k1, self.n + 1) ) # Define the second factor that also depends on m if self.m == 0: factor_2 = sy.Pow(-1, self.n / 2) elif self.m > 0: factor_2 = ( sy.sqrt(2) * sy.Pow(-1, (self.n - self.m) / 2) * sy.Pow(sy.I, self.m) * sy.cos(self.m * k2) ) else: factor_2 = ( sy.sqrt(2) * sy.Pow(-1, (self.n + self.m) / 2) * sy.Pow(sy.I, -self.m) * sy.sin(-self.m * k2) ) result: sy.Expr = sy.nsimplify(sy.simplify(factor_1 * factor_2)) return result
[docs]class Wavefront: r""" A wavefront, expressed as a weighted sum of Zernike polynomials. The wavefront is returned as a `sy.Expr`. This is useful if we, for example, also want to compute the derivative of the wavefront. Both the polar and the Cartesian representation of the wavefront are available. Args: coefficients: The coefficients to be used as weights for the Zernike polynomials. There are two ways to specify this: 1. As a sequence of floats. In this case, the :math:`j`-th entry of the sequence will be used as the weight for the :math:`j`-th Zernike polynomial. That means: >>> coefficients = [0, 1, 2, 0, 4] will produce the following wavefront: .. math:: \text{WF} = Z_1 + 2 \cdot Z_2 + 4 \cdot Z_4 2. As a dictionary with entries of the form `(j, weight)`. To reproduce the previous example, we could therefore also write: >>> coefficients = {1: 1, 2: 2, 4: 4} Note that there is per se no limit of the number of Zernike polynomials that can be used for a wavefront; however, things of course will get slower for higher orders. """ def __init__(self, coefficients: Union[Sequence[float], Dict[int, float]]): # Store constructor arguments self.coefficients = coefficients @property def polar(self) -> sy.Expr: """ Get the polar representation of the wavefront. """ # Define type hint for the result result: sy.Expr if isinstance(self.coefficients, dict): result = sum( coefficient * ZernikePolynomial(j=j).polar for j, coefficient in self.coefficients.items() ) else: result = sum( coefficient * ZernikePolynomial(j=j).polar for j, coefficient in enumerate(self.coefficients) ) return result @property def cartesian(self) -> sy.Expr: """ Get the Cartesian representation of the wavefront. """ return polar_to_cartesian(self.polar)