differt.utils module

General purpose utilities.

minimize(fun, x0, args=(), steps=1000, optimizer=None)

Minimize a scalar function of one or more variables.

The minimization computes the gradient of the objective function and performs a fixed number of optimizer iterations, given by steps.
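As a rough illustration of "compute the gradient, then run a fixed number of iterations", here is a minimal sketch using jax.grad and jax.lax.fori_loop. It is not the differt implementation: the name minimize_gd and the learning_rate parameter are hypothetical, and plain gradient descent stands in for the default optax.adam optimizer.

```python
import jax
import jax.numpy as jnp


def minimize_gd(fun, x0, args=(), steps=1000, learning_rate=0.1):
    """Hypothetical sketch: fixed-step gradient descent on a scalar objective.

    Uses plain gradient descent, not optax.adam, so its convergence
    behaviour differs from differt.utils.minimize.
    """
    # Gradient of the objective with respect to its first argument, x.
    grad_fun = jax.grad(fun)

    def body(_, x):
        # One update: move against the gradient by a fixed step size.
        return x - learning_rate * grad_fun(x, *args)

    # Perform exactly `steps` iterations; there is no convergence test.
    x = jax.lax.fori_loop(0, steps, body, x0)
    return x, fun(x, *args)


def f(x, offset=1.0):
    x = x - offset
    return jnp.dot(x, x)


x, loss = minimize_gd(f, jnp.zeros(10), args=(2.0,), steps=200)
```

Because the loop length is fixed, the cost is predictable, but the result is only as good as the chosen number of steps and step size.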

Parameters:
  • fun (Callable[[Num[Array, '*batch n'], Unpack[TypeVarTuple]], Num[Array, '*batch']]) – The objective function to be minimized.

  • x0 (Num[Array, '*batch n']) – The initial guess.

  • args (tuple[Unpack[TypeVarTuple]]) –

    Positional arguments passed to fun.

    Note

    These arguments are also expected to have batch dimensions matching those of x0.

    If your function has static arguments, please wrap the function with functools.partial:

    fun_p = partial(fun, static_arg=static_value)
    

    If your function has keyword-only arguments, create a wrapper function that maps positional arguments to keyword-only arguments:

    fun_p = lambda x, kw_only_value: fun(x, kw_only_arg=kw_only_value)
    

  • steps (int) – The number of steps to perform.

  • optimizer (Optional[GradientTransformation]) – The optimizer to use. If not provided, uses optax.adam with a learning rate of 0.1.

Return type:

tuple[Num[Array, '*batch n'], Num[Array, '*batch']]

Returns:

The solution array and the corresponding loss.
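The note on args can be illustrated directly with jax.vmap, assuming (as the traceback in the batched example suggests) that fun and its positional arguments are mapped over the batch dimensions. In this sketch, the objective f, its keyword-only parameter shift, and the wrapper f_kw are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x, offset, scale=2.0, *, shift=0.0):
    # Hypothetical objective with a positional, a default and a keyword-only argument.
    y = scale * x - offset + shift
    return jnp.sum(y * y, axis=-1)


x = jnp.zeros((3, 4))  # batch of 3 vectors of size 4
offsets = jnp.ones((3, 4))  # batched argument: one offset vector per batch element

# Batched arguments are mapped along axis 0 together with x.
batched = jax.vmap(f)(x, offsets)

# A Python float has no axis 0, so mapping it like x raises a ValueError.
try:
    jax.vmap(f)(x, 1.0)
    float_failed = False
except ValueError:
    float_failed = True

# Static arguments: bind them with functools.partial so only x is mapped.
f_static = partial(f, offset=1.0)
static = jax.vmap(f_static)(x)


# Keyword-only arguments: expose them positionally through a small wrapper.
def f_kw(x, shift_value):
    return f(x, 1.0, shift=shift_value)


kw = jax.vmap(f_kw)(x, jnp.full((3,), 0.5))
```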

Examples

The following example shows how to minimize a basic function.

>>> import chex
>>> import jax.numpy as jnp
>>> from differt.utils import minimize
>>>
>>> def f(x, offset=1.0):
...     x = x - offset
...     return jnp.dot(x, x)
>>>
>>> x, y = minimize(f, jnp.zeros(10))
>>> chex.assert_trees_all_close(x, jnp.ones(10), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-4)
>>>
>>> # It is also possible to pass positional arguments
>>> x, y = minimize(f, jnp.zeros(10), args=(2.0,))
>>> chex.assert_trees_all_close(x, 2.0 * jnp.ones(10), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-3)
>>>
>>> # You can also change the optimizer and the number of steps
>>> import optax
>>> optimizer = optax.noisy_sgd(learning_rate=0.003)
>>> x, y = minimize(f, jnp.zeros(5), args=(4.0,), steps=10000, optimizer=optimizer)
>>> chex.assert_trees_all_close(x, 4.0 * jnp.ones(5), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-3)

This example shows how to minimize over a batch of arrays. The objective function has signature (*batch, n) -> (*batch), and each batch element is minimized independently.

>>> import chex
>>> import jax
>>> import jax.numpy as jnp
>>> from differt.utils import minimize
>>>
>>> batch = (1, 2, 3)
>>> n = 10
>>> key = jax.random.PRNGKey(1234)
>>> offset = jax.random.uniform(key, (*batch, n))
>>>
>>> def f(x, offset, scale=2.0):
...     x = scale * x - offset
...     return jnp.sum(x * x, axis=-1)
>>>
>>> x0 = jnp.zeros((*batch, n))
>>> x, y = minimize(f, x0, args=(offset,), steps=1000)
>>> chex.assert_trees_all_close(x, offset / 2.0, rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-4)
>>>
>>> # By default, arguments are expected to have batch
>>> # dimensions like `x0`, so `offset` cannot be a static
>>> # value (e.g., a float):
>>> offset = 10.0
>>> x, y = minimize(
...     f, x0, args=(offset,), steps=1000
... )  
Traceback (most recent call last):
ValueError: vmap was requested to map its arguments along axis 0, ...
>>>
>>> # For static arguments, use functools.partial
>>> from functools import partial
>>>
>>> fp = partial(f, offset=offset)
>>> x, y = minimize(fp, x0, steps=1000)
>>> chex.assert_trees_all_close(x, offset * jnp.ones_like(x0) / 2.0, rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-2)
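
In the batched example above, each batch element converges to its own solution. One way to obtain this behaviour, shown here as a sketch under assumptions, is to differentiate the sum of the per-batch losses: since batch entries do not interact, the gradient with respect to each x[i] depends only on fun(x[i]). The helper batched_gd is hypothetical, uses plain gradient descent instead of optax.adam, and is not necessarily how differt handles batches.

```python
import jax
import jax.numpy as jnp


def f(x, offset, scale=2.0):
    # Same objective as in the batched example: (*batch, n) -> (*batch).
    x = scale * x - offset
    return jnp.sum(x * x, axis=-1)


def batched_gd(fun, x0, args=(), steps=500, learning_rate=0.05):
    # Summing the per-batch losses yields a scalar whose gradient with
    # respect to x[i] only involves fun(x[i]), so every batch element
    # follows its own independent gradient-descent trajectory.
    grad_fun = jax.grad(lambda x: jnp.sum(fun(x, *args)))

    def body(_, x):
        return x - learning_rate * grad_fun(x)

    x = jax.lax.fori_loop(0, steps, body, x0)
    return x, fun(x, *args)


batch, n = (2, 3), 4
offset = jax.random.uniform(jax.random.PRNGKey(0), (*batch, n))
x, loss = batched_gd(f, jnp.zeros((*batch, n)), args=(offset,))
```

As in the doctest, each solution approaches offset / 2 and the returned loss keeps the batch shape (2, 3).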

sorted_array2(array)

Sort the rows of a 2D array in lexicographic order, i.e., by the first column, then by the second column, and so on.
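One possible way to obtain this row ordering is sketched below with jnp.lexsort. The helper name sorted_rows_lexicographic is hypothetical, and this is not necessarily how sorted_array2 is implemented.

```python
import jax.numpy as jnp


def sorted_rows_lexicographic(array):
    """Hypothetical sketch: sort the rows of a 2D array lexicographically.

    jnp.lexsort uses its *last* key as the primary key, so the columns are
    passed in reverse order: column 0 becomes the primary key, column 1
    breaks ties, and so on.
    """
    keys = array[:, ::-1].T  # shape (n, m): one key per column, reversed
    indices = jnp.lexsort(keys)  # row permutation of length m
    return array[indices]


arr = jnp.array([[4, 5], [8, 9], [0, 1], [2, 3], [6, 7]])
result = sorted_rows_lexicographic(arr)
```

Applied to the two input arrays from the examples below, this sketch returns the same sorted arrays as those shown for sorted_array2.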

Parameters:

array (Shaped[Array, 'm n']) – The input array.

Return type:

Shaped[Array, 'm n']

Returns:

A sorted copy of the input array.

Examples

The following example shows how the sorting works.

>>> import jax
>>> import jax.numpy as jnp
>>> from differt.utils import sorted_array2
>>>
>>> arr = jnp.arange(10).reshape(5, 2)
>>> key = jax.random.PRNGKey(1234)
>>> key1, key2 = jax.random.split(key, 2)
>>> arr = jax.random.permutation(key1, arr)
>>> arr
Array([[4, 5],
       [8, 9],
       [0, 1],
       [2, 3],
       [6, 7]], dtype=int32)
>>>
>>> sorted_array2(arr)
Array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)
>>>
>>> arr = jax.random.randint(key2, (5, 5), 0, 2)
>>> arr
Array([[1, 1, 1, 0, 1],
       [1, 0, 1, 1, 1],
       [1, 0, 0, 1, 1],
       [1, 0, 0, 0, 0],
       [1, 1, 0, 1, 0]], dtype=int32)
>>>
>>> sorted_array2(arr)
Array([[1, 0, 0, 0, 0],
       [1, 0, 0, 1, 1],
       [1, 0, 1, 1, 1],
       [1, 1, 0, 1, 0],
       [1, 1, 1, 0, 1]], dtype=int32)