differt.em.special module
Special functions.
This module extends the jax.scipy.special module by adding missing functions from scipy.special, or by extending already implemented functions to the complex domain.
These new implementations are needed to preserve the ability to differentiate code; otherwise, we could simply call the SciPy functions and wrap their outputs with jnp.asarray.
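As a minimal sketch of why differentiability matters here (using only JAX and SciPy, not differt): jax.grad can trace through the pure-JAX jax.scipy.special.erf and recover the analytic derivative \(\frac{2}{\sqrt{\pi}} e^{-x^2}\), whereas a SciPy function wrapped with jnp.asarray fails under tracing, because SciPy cannot operate on JAX tracers. The helper names erf_jax and erf_scipy_wrapped are hypothetical.

```python
import math

import jax
import jax.numpy as jnp
import jax.scipy.special
import scipy.special


def erf_jax(x):
    # Hypothetical helper: pure-JAX error function, traceable and hence differentiable.
    return jax.scipy.special.erf(x)


def erf_scipy_wrapped(x):
    # Hypothetical helper: SciPy function wrapped with jnp.asarray.
    # Works on concrete arrays, but SciPy cannot consume JAX tracers.
    return jnp.asarray(scipy.special.erf(x))


x0 = 0.5
grad_jax = jax.grad(erf_jax)(x0)
# Analytic derivative: d/dx erf(x) = 2 / sqrt(pi) * exp(-x**2)
grad_exact = 2.0 / math.sqrt(math.pi) * math.exp(-x0**2)

try:
    jax.grad(erf_scipy_wrapped)(x0)
    scipy_is_differentiable = True
except Exception:  # JAX refuses to convert the tracer into a NumPy array
    scipy_is_differentiable = False
```

In current JAX versions the failure surfaces as a TracerArrayConversionError.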
- erf(z)
Evaluate the error function at the given points.
The current implementation is written using the real-valued error function jax.scipy.special.erf and the approximation detailed in [Leu08]. The output type (real or complex) is determined by the input type.
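The [Leu08] approximation itself is not reproduced here. As a hedged illustration of how a complex-valued erf can be built from differentiable jax.numpy operations, the sketch below swaps in a different, simpler technique: the truncated Maclaurin series \(\mathrm{erf}(z) = \frac{2}{\sqrt{\pi}} \sum_{n=0}^{\infty} \frac{(-1)^n z^{2n+1}}{n!\,(2n+1)}\). It is not differt's implementation, erf_taylor and num_terms are hypothetical names, and it is only trustworthy for moderate \(|z|\), since the alternating terms grow large and cancel as \(|z|\) increases.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np
import scipy.special


def erf_taylor(z, num_terms=40):
    """Hypothetical helper: Maclaurin-series error function for real or complex z.

    A sketch only (NOT the [Leu08] method): accurate for moderate |z|, but
    cancellation between large alternating terms makes it unreliable as |z| grows.
    """
    z = jnp.asarray(z)
    minus_z2 = -z * z
    term = z  # a_0 = z, where a_n = (-1)**n * z**(2n + 1) / n!
    total = jnp.zeros_like(z)
    for n in range(num_terms):
        total = total + term / (2 * n + 1)
        term = term * minus_z2 / (n + 1)  # a_{n+1} obtained from a_n
    return (2.0 / math.sqrt(math.pi)) * total


# Usage: compare with SciPy's complex-capable erf at moderate |z|.
pts = np.array([0.3, -1.2, 0.5 + 0.5j, -1.0 + 1.5j, 1.8j])
max_err = np.max(np.abs(np.asarray(erf_taylor(pts)) - scipy.special.erf(pts)))

# Differentiable for real inputs, like any pure jax.numpy code.
slope = jax.grad(erf_taylor)(0.5)
```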
Warning
Currently, we observe that this function and scipy.special.erf start to diverge for \(|z| > 6\). If you know how to avoid this problem, please contact us!
- Parameters:
z (Inexact[Array, '*batch']) – The array of real or complex points to evaluate.
- Return type:
Inexact[Array, '*batch']
- Returns:
The values of the error function at the given points.
Notes
Regarding performance, there are two cases:
- If z is real, this function compiles to jax.scipy.special.erf and therefore has the same performance (once JIT compilation is done). Compared to the SciPy equivalent, we measured our implementation to be about 10 times faster.
- If z is complex, our implementation is about 3 times faster than scipy.special.erf.
These results were measured on centered random uniform arrays with \(10^5\) elements.
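A rough sketch of how such a measurement could be reproduced for the real-valued case. It assumes that "centered random uniform" means float32 samples drawn uniformly from \([-1, 1]\), compares a JIT-compiled jax.scipy.special.erf (what the real-valued path is said to compile to) with scipy.special.erf, and does not call differt. The resulting ratio depends on hardware and is not guaranteed to match the figures above; time_it is a hypothetical helper.

```python
import timeit

import jax
import jax.numpy as jnp
import jax.scipy.special
import numpy as np
import scipy.special

# Assumption: "centered random uniform" = uniform float32 samples on [-1, 1].
rng = np.random.default_rng(seed=0)
x_np = rng.uniform(-1.0, 1.0, size=10**5).astype(np.float32)
x_jax = jnp.asarray(x_np)

erf_jit = jax.jit(jax.scipy.special.erf)
erf_jit(x_jax).block_until_ready()  # warm-up: trigger compilation once


def time_it(fn, number=20):
    """Hypothetical helper: best average time per call over 5 repeats, in seconds."""
    return min(timeit.repeat(fn, number=number, repeat=5)) / number


# block_until_ready() makes the timing include the actual (asynchronous) computation.
t_jax = time_it(lambda: erf_jit(x_jax).block_until_ready())
t_scipy = time_it(lambda: scipy.special.erf(x_np))
print(f"JAX (jit): {t_jax * 1e6:.1f} us, SciPy: {t_scipy * 1e6:.1f} us")
```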
Examples
The following plots the error function for real-valued inputs.
>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> from differt.em.special import erf
>>>
>>> x = jnp.linspace(-3.0, +3.0)
>>> y = erf(x)
>>> plt.plot(x, y.real)
>>> plt.xlabel("$x$")
>>> plt.ylabel(r"$\text{erf}(x)$")
The following plots the error function for complex-valued inputs.
>>> import jax.numpy as jnp
>>> import plotly.graph_objects as go
>>> from differt.em.special import erf
>>>
>>> x = y = jnp.linspace(-2.0, +2.0, 200)
>>> a, b = jnp.meshgrid(x, y)
>>> z = erf(a + 1j * b)
>>> fig = go.Figure(
...     data=[
...         go.Surface(
...             x=x,
...             y=y,
...             z=jnp.abs(z),
...             colorscale="phase",
...             surfacecolor=jnp.angle(z),
...             colorbar=dict(title="Arg(erf(z))"),
...         )
...     ]
... )
>>> fig.update_layout(
...     scene=dict(
...         xaxis=dict(title="Re(z)"),
...         yaxis=dict(title="Im(z)"),
...         zaxis=dict(title="Abs(erf(z))"),
...     )
... )
>>> fig
- erfc(z)
Evaluate the complementary error function at the given points.
The output type (real or complex) is determined by the input type.
See erf for more details.
- Parameters:
z (Inexact[Array, '*batch']) – The array of real or complex points to evaluate.
- Return type:
Inexact[Array, '*batch']
- Returns:
The values of the complementary error function at the given points.
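For real inputs, the identity \(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x)\) can be checked with JAX's real-valued functions (not differt). The sketch also shows why evaluating \(1 - \mathrm{erf}(x)\) literally is a poor way to obtain erfc for large positive \(x\): \(\mathrm{erf}(x)\) is then indistinguishable from 1 in float32, so the subtraction loses all significant digits. Whether differt's erfc is computed this way is not stated here.

```python
import math

import jax.numpy as jnp
import jax.scipy.special

x = jnp.linspace(-3.0, 3.0, 101)
# Identity erfc(x) = 1 - erf(x): the absolute gap stays at float32 rounding level here.
identity_gap = jnp.max(
    jnp.abs(jax.scipy.special.erfc(x) - (1.0 - jax.scipy.special.erf(x)))
)

# For large x, erfc(x) ~ exp(-x**2) / (x * sqrt(pi)) is far below float32's
# resolution near 1, so 1 - erf(x) is dominated by rounding error.
x_big = 5.0
naive = float(1.0 - jax.scipy.special.erf(jnp.float32(x_big)))
reference = math.erfc(x_big)  # float64 reference from the standard library
naive_rel_error = abs(naive - reference) / reference
```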
Examples
The following plots the complementary error function for real-valued inputs.
>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> from differt.em.special import erfc
>>>
>>> x = jnp.linspace(-3.0, +3.0)
>>> y = erfc(x)
>>> plt.plot(x, y.real)
>>> plt.xlabel("$x$")
>>> plt.ylabel(r"$\text{erfc}(x)$")
- fresnel(z)
Evaluate the two Fresnel integrals at the given points.
The current implementation is written using the error function erf; see [Wikipediacontributors24b]. The output type (real or complex) is determined by the input type.
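For the normalized Fresnel integrals \(S(z) = \int_0^z \sin(\pi t^2 / 2)\,dt\) and \(C(z) = \int_0^z \cos(\pi t^2 / 2)\,dt\) (the convention used by scipy.special.fresnel, assumed here to match this module's), the standard erf-based expressions are \(S(z) = \frac{1+i}{4}\left[\mathrm{erf}\left(\frac{1+i}{2}\sqrt{\pi} z\right) - i\,\mathrm{erf}\left(\frac{1-i}{2}\sqrt{\pi} z\right)\right]\) and \(C(z) = \frac{1-i}{4}\left[\mathrm{erf}\left(\frac{1+i}{2}\sqrt{\pi} z\right) + i\,\mathrm{erf}\left(\frac{1-i}{2}\sqrt{\pi} z\right)\right]\). The sketch below only checks these formulas numerically with SciPy, whose erf accepts complex arguments; it is not differt's code, and fresnel_via_erf is a hypothetical name.

```python
import numpy as np
import scipy.special


def fresnel_via_erf(z):
    """Hypothetical helper: Fresnel integrals (S, C) from the complex erf.

    Accepts real or complex input and always returns complex arrays
    (for real z, the imaginary parts are only rounding noise).
    """
    z = np.asarray(z, dtype=complex)
    scale = np.sqrt(np.pi) / 2
    e_plus = scipy.special.erf((1 + 1j) * scale * z)
    e_minus = scipy.special.erf((1 - 1j) * scale * z)
    s = (1 + 1j) / 4 * (e_plus - 1j * e_minus)
    c = (1 - 1j) / 4 * (e_plus + 1j * e_minus)
    return s, c


# Compare against SciPy's own Fresnel integrals, which return (S, C) in that order.
z = np.array([0.0, 0.7, -1.5, 2.0, 0.5 + 0.5j, -1.0 + 0.3j])
s_ref, c_ref = scipy.special.fresnel(z)
s_erf, c_erf = fresnel_via_erf(z)
```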
- Parameters:
z (Inexact[Array, '*batch']) – The array of real or complex points to evaluate.
- Return type:
tuple[Inexact[Array, '*batch'], Inexact[Array, '*batch']]
- Returns:
A tuple of two arrays, one for each of the Fresnel integrals, in the order (S, C).
Examples
The following plots the Fresnel integrals for real-valued inputs.
>>> import jax.numpy as jnp
>>> import matplotlib.pyplot as plt
>>> from differt.em.special import fresnel
>>>
>>> t = jnp.linspace(0.0, 5.0, 200)
>>> s, c = fresnel(t)
>>> plt.plot(t, s.real, label=r"$y=S(x)$")
>>> plt.plot(t, c.real, "--", label=r"$y=C(x)$")
>>> plt.xlabel("$x$")
>>> plt.ylabel("$y$")
>>> plt.legend()