differt.utils module
General purpose utilities.
- minimize(fun, x0, args=(), steps=1000, optimizer=None)
Minimize a scalar function of one or more variables.
The minimization computes the gradient of the objective function and performs a fixed number of iterations (i.e., `steps`).
- Parameters:
fun (Callable[[Num[Array, '*batch n'], Unpack[TypeVarTuple]], Num[Array, '*batch']]) – The objective function to be minimized.
x0 (Num[Array, '*batch n']) – The initial guess.
args (tuple[Unpack[TypeVarTuple]]) – Positional arguments passed to `fun`.
Note
These arguments are also expected to have batch dimensions similar to `x0`.
If your function has static arguments, wrap the function with functools.partial:
    fun_p = partial(fun, static_arg=static_value)
If your function has keyword-only arguments, create a wrapper function that maps positional arguments to keyword-only arguments:
    fun_p = lambda x, kw_only_value: fun(x, kw_only_arg=kw_only_value)
steps (int) – The number of iterations to perform.
optimizer (Optional[GradientTransformation]) – The optimizer to use. If not provided, uses optax.adam with a learning rate of 0.1.
- Return type:
tuple[Num[Array, '*batch n'], Num[Array, '*batch']]
- Returns:
The solution array and the corresponding loss.
Examples
The following example shows how to minimize a basic function.
>>> from differt.utils import minimize
>>> import chex
>>> import jax.numpy as jnp
>>>
>>> def f(x, offset=1.0):
...     x = x - offset
...     return jnp.dot(x, x)
>>>
>>> x, y = minimize(f, jnp.zeros(10))
>>> chex.assert_trees_all_close(x, jnp.ones(10), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-4)
>>>
>>> # It is also possible to pass positional arguments
>>> x, y = minimize(f, jnp.zeros(10), args=(2.0,))
>>> chex.assert_trees_all_close(x, 2.0 * jnp.ones(10), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-3)
>>>
>>> # You can also change the optimizer and the number of steps
>>> import optax
>>> optimizer = optax.noisy_sgd(learning_rate=0.003)
>>> x, y = minimize(f, jnp.zeros(5), args=(4.0,), steps=10000, optimizer=optimizer)
>>> chex.assert_trees_all_close(x, 4.0 * jnp.ones(5), rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-3)
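The Note on `args` recommends a wrapper for keyword-only arguments. The sketch below illustrates that pattern on a hypothetical objective `g` (not part of differt): since `minimize` only forwards positional `args`, the wrapper exposes the keyword-only argument positionally. It assumes, as the Note states, that such positional arguments are batched like `x0`, and uses `jax.vmap` only to mimic that per-element evaluation.

```python
import jax
import jax.numpy as jnp


def g(x, *, target):
    """Hypothetical objective with a keyword-only argument."""
    d = x - target
    return jnp.sum(d * d, axis=-1)


def g_p(x, target_value):
    # Same idea as `lambda x, kw_only_value: fun(x, kw_only_arg=kw_only_value)`:
    # forward a positional argument to the keyword-only one.
    return g(x, target=target_value)


x = jnp.zeros((4, 3))
target = jnp.ones((4, 3))  # batched like x, as the Note requires

# Per-element loss and gradient, mapping the leading axis of both positional arguments.
loss, grads = jax.vmap(jax.value_and_grad(g_p))(x, target)
```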
This example shows how to minimize over a batch of arrays. The objective function has the signature (*batch, n) -> (*batch), and each batch element is minimized independently.
>>> from differt.utils import minimize
>>> import chex
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> batch = (1, 2, 3)
>>> n = 10
>>> key = jax.random.PRNGKey(1234)
>>> offset = jax.random.uniform(key, (*batch, n))
>>>
>>> def f(x, offset, scale=2.0):
...     x = scale * x - offset
...     return jnp.sum(x * x, axis=-1)
>>>
>>> x0 = jnp.zeros((*batch, n))
>>> x, y = minimize(f, x0, args=(offset,), steps=1000)
>>> chex.assert_trees_all_close(x, offset / 2.0, rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-4)
>>>
>>> # By default, arguments are expected to have batch
>>> # dimensions like `x0`, so `offset` cannot be a static
>>> # value (i.e., float):
>>> offset = 10.0
>>> x, y = minimize(
...     f, x0, args=(offset,), steps=1000
... )
Traceback (most recent call last):
ValueError: vmap was requested to map its arguments along axis 0, ...
>>>
>>> # For static arguments, use functools.partial
>>> from functools import partial
>>>
>>> fp = partial(f, offset=offset)
>>> x, y = minimize(fp, x0, steps=1000)
>>> chex.assert_trees_all_close(x, offset * jnp.ones_like(x0) / 2.0, rtol=1e-2)
>>> chex.assert_trees_all_close(y, 0.0, atol=1e-2)
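One way to obtain the independent, per-element minimization described in this example is to apply `jax.vmap` to the value-and-gradient function once per leading batch axis. The helper `batched_value_and_grad` below is a hypothetical sketch of that idea, not code taken from differt. It also shows why a plain float `offset` fails: each `vmap` layer maps every positional argument along axis 0, and a 0-d value has no such axis.

```python
import jax
import jax.numpy as jnp


def batched_value_and_grad(fun, batch_ndim):
    """Hypothetical helper: vmap value_and_grad(fun) over `batch_ndim` leading axes."""
    vg = jax.value_and_grad(fun)
    for _ in range(batch_ndim):
        # Each layer maps all positional arguments along axis 0, so extra
        # arguments must carry the same batch dimensions as x.
        vg = jax.vmap(vg)
    return vg


def f(x, offset, scale=2.0):
    x = scale * x - offset
    return jnp.sum(x * x, axis=-1)


batch, n = (1, 2, 3), 10
offset = jax.random.uniform(jax.random.PRNGKey(1234), (*batch, n))
x0 = jnp.zeros((*batch, n))

# loss has shape `batch`; grads has the same shape as x0.
loss, grads = batched_value_and_grad(f, len(batch))(x0, offset)
```

Because batch elements do not interact, summing the losses and differentiating once would yield the same per-element gradients; the vmap form additionally keeps one loss value per element, matching the '*batch' loss shape.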
- sorted_array2(array)
Sort the rows of a 2D array lexicographically: by the first column, then by the second column, and so on.
- Parameters:
array (Shaped[Array, 'm n']) – The input array.
- Return type:
Shaped[Array, 'm n']
- Returns:
A sorted copy of the input array.
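The orderings in the examples are consistent with a lexicographic sort of rows. A minimal sketch of that behavior with `jnp.lexsort` follows; `sorted_rows_sketch` is a hypothetical name, and the actual implementation of `sorted_array2` may differ. The only subtle step is that NumPy and JAX `lexsort` treat the last key as the primary one, so the columns are passed in reverse order.

```python
import jax.numpy as jnp


def sorted_rows_sketch(array):
    """Hypothetical: return a copy of a 2D array with rows sorted lexicographically."""
    if array.size == 0:
        return array
    # jnp.lexsort uses the *last* key as the primary key, so pass the
    # columns in reverse order to make column 0 the primary key.
    keys = tuple(array[:, j] for j in reversed(range(array.shape[1])))
    return array[jnp.lexsort(keys)]


arr = jnp.array(
    [
        [1, 1, 1, 0, 1],
        [1, 0, 1, 1, 1],
        [1, 0, 0, 1, 1],
        [1, 0, 0, 0, 0],
        [1, 1, 0, 1, 0],
    ]
)
result = sorted_rows_sketch(arr)
```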
Examples
The following example shows how the sorting works.
>>> from differt.utils import (
...     sorted_array2,
... )
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> arr = jnp.arange(10).reshape(5, 2)
>>> key = jax.random.PRNGKey(1234)
>>> (
...     key1,
...     key2,
... ) = jax.random.split(key, 2)
>>> arr = jax.random.permutation(key1, arr)
>>> arr
Array([[4, 5],
       [8, 9],
       [0, 1],
       [2, 3],
       [6, 7]], dtype=int32)
>>>
>>> sorted_array2(arr)
Array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]], dtype=int32)
>>>
>>> arr = jax.random.randint(
...     key2,
...     (5, 5),
...     0,
...     2,
... )
>>> arr
Array([[1, 1, 1, 0, 1],
       [1, 0, 1, 1, 1],
       [1, 0, 0, 1, 1],
       [1, 0, 0, 0, 0],
       [1, 1, 0, 1, 0]], dtype=int32)
>>>
>>> sorted_array2(arr)
Array([[1, 0, 0, 0, 0],
       [1, 0, 0, 1, 1],
       [1, 0, 1, 1, 1],
       [1, 1, 0, 1, 0],
       [1, 1, 1, 0, 1]], dtype=int32)