Ranking Losses (rax.*_loss)
Implementations of common ranking losses in JAX.
A ranking loss is a differentiable function that expresses the cost of a ranking
induced by item scores compared to a ranking induced from relevance labels. Rax
provides a number of ranking losses as JAX functions that are implemented
according to the LossFn interface.
Loss functions are designed to operate on the last dimension of their inputs. The
leading dimensions are considered batch dimensions. To compute per-list losses,
for example to apply per-list weighting or for distributed computing of losses
across devices, please use standard JAX transformations such as jax.vmap()
or jax.pmap().
Standalone usage:
>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([1., 0., 0.])
>>> print(rax.softmax_loss(scores, labels))
1.4076059
Usage with a batch of data and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(rax.pairwise_hinge_loss(
... scores, labels, where=where, reduce_fn=jnp.mean))
0.16666667
To compute gradients of each loss function, please use standard JAX
transformations such as jax.grad() or jax.value_and_grad():
>>> scores = jnp.asarray([[0., 1., 3.], [1., 2., 0.]])
>>> labels = jnp.asarray([[0., 0., 1.], [1., 0., 0.]])
>>> print(jax.grad(rax.softmax_loss)(scores, labels, reduce_fn=jnp.mean))
[[ 0.02100503 0.0570976 -0.07810265]
[-0.37763578 0.33262047 0.04501529]]
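rax.pointwise_mse_loss – Mean squared error loss.
rax.pointwise_sigmoid_loss – Sigmoid cross entropy loss.
rax.pairwise_hinge_loss – Pairwise hinge loss.
rax.pairwise_logistic_loss – Pairwise logistic loss.
rax.pairwise_mse_loss – Pairwise mean squared error loss.
rax.softmax_loss – Softmax loss.
rax.listmle_loss – ListMLE loss.
rax.poly1_softmax_loss – Poly1 softmax loss.
rax.unique_softmax_loss – Unique softmax loss.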
- rax.pointwise_mse_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function mean>)
Mean squared error loss.
Definition:
\[\ell(s, y) = \sum_i (y_i - s_i)^2 \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The mean squared error loss.
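A plain-Python sketch of the definition for a single list can help when checking values by hand. It is not Rax's JAX implementation. `pointwise_mse_sketch` is a hypothetical name, and its handling of `where` (masked items are skipped) and `weights` (each squared error is scaled by its weight) is an assumption. It returns the per-list sum from the formula; with the default reduce_fn the library may report an average instead.

```python
def pointwise_mse_sketch(scores, labels, where=None, weights=None):
    """Per-list sum of w_i * (y_i - s_i)^2 over valid items (plain-Python sketch)."""
    n = len(scores)
    where = where if where is not None else [True] * n      # all items valid
    weights = weights if weights is not None else [1.0] * n  # unit weights
    total = 0.0
    for s, y, valid, w in zip(scores, labels, where, weights):
        if valid:
            total += w * (y - s) ** 2
    return total


print(pointwise_mse_sketch([1.0, 2.0, 0.5], [0.0, 2.0, 3.0]))  # 1 + 0 + 6.25 = 7.25
print(pointwise_mse_sketch([1.0, 2.0, 0.5], [0.0, 2.0, 3.0],
                           where=[True, True, False]))         # last item masked: 1.0
```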
- rax.pointwise_sigmoid_loss(scores, labels, *, where=None, weights=None, reduce_fn=<function mean>)
Sigmoid cross entropy loss.
Definition:
\[\ell(s, y) = -\sum_i \left[ y_i \log(\operatorname{sigmoid}(s_i)) + (1 - y_i) \log(1 - \operatorname{sigmoid}(s_i)) \right] \]
This loss converts graded relevance to binary relevance by considering items with label >= 1 as relevant and items with label < 1 as non-relevant.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The sigmoid cross entropy loss.
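The sketch below evaluates the definition for one list in plain Python, including the `label >= 1` binarization. It uses the standard numerically stable rewrite of sigmoid cross entropy, `max(s, 0) - s*y + log(1 + exp(-|s|))`. Whether Rax uses the same rewrite internally is not stated here, and `pointwise_sigmoid_sketch` is a hypothetical name.

```python
import math


def pointwise_sigmoid_sketch(scores, labels):
    """Sum of per-item sigmoid cross entropy, with labels binarized at 1."""
    total = 0.0
    for s, y in zip(scores, labels):
        y_bin = 1.0 if y >= 1 else 0.0  # graded -> binary relevance
        # Stable form of -[y log(sigmoid(s)) + (1 - y) log(1 - sigmoid(s))].
        total += max(s, 0.0) - s * y_bin + math.log1p(math.exp(-abs(s)))
    return total


print(pointwise_sigmoid_sketch([0.0], [1.0]))  # log(2) ~ 0.6931
print(pointwise_sigmoid_sketch([2.0], [3.0]))  # label 3 is treated as relevant (1)
```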
- rax.pairwise_hinge_loss(scores, labels, *, where=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise hinge loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j I[y_i > y_j] \max(0, 1 - (s_i - s_j)) \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The pairwise hinge loss.
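To see how the mask and the mean reduction interact, here is a plain-Python sketch that lists the hinge loss of every valid pair (both items unmasked, `y_i > y_j`). Averaging those pair losses over the whole batch reproduces the 0.16666667 from the batched pairwise_hinge_loss example at the top of this page. That the mean runs over valid pairs is inferred from that one example rather than documented. `weights` and `lambdaweight_fn` are not modeled, and the helper name is hypothetical.

```python
def pairwise_hinge_pairs(scores, labels, where=None):
    """Hinge loss max(0, 1 - (s_i - s_j)) for each valid pair with y_i > y_j."""
    n = len(scores)
    where = where if where is not None else [True] * n
    losses = []
    for i in range(n):
        for j in range(n):
            if where[i] and where[j] and labels[i] > labels[j]:
                losses.append(max(0.0, 1.0 - (scores[i] - scores[j])))
    return losses


# Same data as the batched example at the top of the page.
batch_scores = [[2., 1., 0.], [1., 0.5, 1.5]]
batch_labels = [[1., 0., 0.], [0., 0., 1.]]
batch_where = [[True, True, False], [True, True, True]]

all_pairs = []
for s, y, w in zip(batch_scores, batch_labels, batch_where):
    all_pairs.extend(pairwise_hinge_pairs(s, y, w))
mean_loss = sum(all_pairs) / len(all_pairs)  # 0.5 spread over 3 valid pairs
print(mean_loss)
```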
- rax.pairwise_logistic_loss(scores, labels, *, where=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise logistic loss.
Definition [Burges et al., 2005]:
\[\ell(s, y) = \sum_i \sum_j I[y_i > y_j] \log(1 + \exp(-(s_i - s_j))) \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The pairwise logistic loss.
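The per-pair term `log(1 + exp(-(s_i - s_j)))` is the softplus of the negative score margin. The plain-Python sketch below sums it over pairs with `y_i > y_j`, as in the definition. Masking, weights and lambdaweights are omitted, and the function names are hypothetical.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_logistic_sketch(scores, labels):
    """Sum over pairs with y_i > y_j of log(1 + exp(-(s_i - s_j)))."""
    n = len(scores)
    return sum(
        softplus(-(scores[i] - scores[j]))
        for i in range(n)
        for j in range(n)
        if labels[i] > labels[j]
    )


print(pairwise_logistic_sketch([0.0, 0.0], [1.0, 0.0]))  # log(2): no margin yet
print(pairwise_logistic_sketch([3.0, 0.0], [1.0, 0.0]))  # small: correct order
print(pairwise_logistic_sketch([0.0, 3.0], [1.0, 0.0]))  # large: wrong order
```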
- rax.pairwise_mse_loss(scores, labels, *, where=None, weights=None, lambdaweight_fn=None, reduce_fn=<function mean>)
Pairwise mean squared error loss.
Definition:
\[\ell(s, y) = \sum_i \sum_j ((y_i - y_j) - (s_i - s_j))^2 \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
lambdaweight_fn (Optional[LambdaweightFn]) – An optional function that outputs lambdaweights.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The pairwise mean squared error loss.
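Unlike the hinge and logistic losses, this definition sums over all ordered pairs and compares score differences with label differences. As a result, adding a constant to every score leaves the loss unchanged. The plain-Python sketch below (hypothetical name, no masking, weights or lambdaweights) makes that visible.

```python
def pairwise_mse_sketch(scores, labels):
    """Sum over all ordered pairs (i, j) of ((y_i - y_j) - (s_i - s_j))^2."""
    n = len(scores)
    return sum(
        ((labels[i] - labels[j]) - (scores[i] - scores[j])) ** 2
        for i in range(n)
        for j in range(n)
    )


print(pairwise_mse_sketch([0.0, 0.0], [1.0, 0.0]))  # (0,1) and (1,0) each add 1: 2.0
print(pairwise_mse_sketch([5.0, 4.0], [1.0, 0.0]))  # differences match labels: 0.0
```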
- rax.softmax_loss(scores, labels, *, where=None, weights=None, label_fn=<function <lambda>>, reduce_fn=<function mean>)
Softmax loss.
Definition:
\[\ell(s, y) = - \sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
label_fn (Callable[..., Array]) – A label function that maps labels to probabilities. Default keeps labels as-is.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The softmax loss.
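The definition can be evaluated directly with a log-sum-exp, which avoids overflow for large scores. The plain-Python sketch below applies the formula literally to one list. It does not model label_fn, weights or where, and `softmax_loss_sketch` is a hypothetical name. It reproduces the 1.4076059 from the standalone example at the top of this page.

```python
import math


def softmax_loss_sketch(scores, labels):
    """-sum_i y_i * log(exp(s_i) / sum_j exp(s_j)), via a stable log-sum-exp."""
    m = max(scores)
    log_normalizer = m + math.log(sum(math.exp(s - m) for s in scores))
    return -sum(y * (s - log_normalizer) for s, y in zip(scores, labels))


print(softmax_loss_sketch([2., 1., 3.], [1., 0., 0.]))  # ~1.4076
```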
- rax.listmle_loss(scores, labels, *, key=None, where=None, reduce_fn=<function mean>)
ListMLE Loss.
Note
This loss performs sorting using the given labels. If the labels contain multiple identical values, you should provide a PRNGKey() to the key argument to make sure ties are broken randomly during the sorting operation.
Definition [Xia et al., 2008]:
\[\ell(s, y) = - \sum_i \log \frac{\exp(s_i)}{\sum_j I[\operatorname{rank}(y_j) \ge \operatorname{rank}(y_i)] \exp(s_j)} \]
where \(\operatorname{rank}(y_i)\) indicates the rank of item \(i\) after sorting all labels \(y\).
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
key (Optional[Array]) – An optional PRNGKey() to perform random tie-breaking.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The listmle loss.
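In the definition, each term's denominator sums over the item itself and every item ranked below it in the label order, which is the Plackett-Luce likelihood of that order. The plain-Python sketch below (hypothetical name) computes it for one list. It breaks label ties deterministically by input position instead of with a PRNGKey, so it shows exactly the bias the note above warns about.

```python
import math


def listmle_sketch(scores, labels):
    """Plackett-Luce negative log-likelihood of the order given by the labels."""
    # Highest label first. Python's sort is stable, so tied labels keep their
    # input order here; the note above recommends random tie-breaking instead.
    order = sorted(range(len(scores)), key=lambda i: -labels[i])
    ordered_scores = [scores[i] for i in order]
    loss = 0.0
    for k, s_k in enumerate(ordered_scores):
        remaining = ordered_scores[k:]  # the item itself and all items below it
        m = max(remaining)
        log_denominator = m + math.log(sum(math.exp(v - m) for v in remaining))
        loss -= s_k - log_denominator
    return loss


print(listmle_sketch([0., 0.], [1., 0.]))   # log(2): top slot is a coin flip
print(listmle_sketch([10., 0.], [1., 0.]))  # close to 0: order already right
```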
- rax.poly1_softmax_loss(scores, labels, *, epsilon=1.0, where=None, weights=None, reduce_fn=<function mean>)
Poly1 softmax loss.
Definition [Leng et al., 2022]:
\[\ell(s, y) = \operatorname{softmax}(s, y) + \epsilon \cdot (1 - \operatorname{pt}) \]
where \(\operatorname{softmax}\) is the standard softmax loss as implemented in softmax_loss() and \(\operatorname{pt}\) is the target softmax probability defined as:
\[\operatorname{pt} = \sum_i \frac{y_i}{\sum_j y_j} \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
epsilon (float) – A float hyperparameter indicating the weight of the leading polynomial coefficient in the poly loss.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The poly1 softmax loss.
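The two equations above combine into a short computation: the softmax loss plus epsilon times one minus the label-weighted softmax probability. The plain-Python sketch below applies them literally to one list. It assumes `sum(labels) > 0` and applies the softmax formula without any extra label processing. `poly1_softmax_sketch` is a hypothetical name. With `epsilon=0` it reduces to the plain softmax loss.

```python
import math


def poly1_softmax_sketch(scores, labels, epsilon=1.0):
    """softmax loss + epsilon * (1 - pt); assumes sum(labels) > 0."""
    m = max(scores)
    log_normalizer = m + math.log(sum(math.exp(s - m) for s in scores))
    softmax = -sum(y * (s - log_normalizer) for s, y in zip(scores, labels))
    label_total = sum(labels)
    # pt: softmax probabilities weighted by the normalized labels.
    pt = sum((y / label_total) * math.exp(s - log_normalizer)
             for s, y in zip(scores, labels))
    return softmax + epsilon * (1.0 - pt)


print(poly1_softmax_sketch([0., 0.], [1., 0.]))               # log(2) + 0.5
print(poly1_softmax_sketch([0., 0.], [1., 0.], epsilon=0.0))  # log(2)
```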
- rax.unique_softmax_loss(scores, labels, *, where=None, weights=None, gain_fn=<function default_gain_fn>, reduce_fn=<function mean>)
Unique softmax loss.
Definition [Zhu and Klabjan, 2020]:
\[\ell(s, y) = - \sum_i \operatorname{gain}(y_i) \log \frac{\exp(s_i)}{\exp(s_i) + \sum_{j : y_j < y_i} \exp(s_j)} \]
where \(\operatorname{gain}(y_i)\) is a user-specified gain function applied to label \(y_i\) to boost items with higher relevance.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
gain_fn (Optional[Callable[[Array], Array]]) – An optional function that maps relevance labels to gain values. If provided, the per-item losses are multiplied by gain_fn(label) to boost the importance of relevant items.
reduce_fn (ReduceFn) – An optional function that reduces the loss values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The unique softmax loss.
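In the definition, each item competes only against items with a strictly lower label, so items with the same label never push against each other. The plain-Python sketch below (hypothetical names) evaluates it for one list. It uses the gain `2^y - 1`, a common DCG convention. Whether that matches the library's default_gain_fn is an assumption. Under this gain, items with label 0 contribute nothing.

```python
import math


def gain_2pow(label):
    # Assumed gain for this sketch: 2^y - 1 (zero for label 0).
    return 2.0 ** label - 1.0


def unique_softmax_sketch(scores, labels, gain_fn=gain_2pow):
    """-sum_i gain(y_i) * log(exp(s_i) / (exp(s_i) + sum_{j: y_j < y_i} exp(s_j)))."""
    loss = 0.0
    for s_i, y_i in zip(scores, labels):
        # Competitors: item i itself plus every item with a strictly lower label.
        competitors = [s_i] + [s_j for s_j, y_j in zip(scores, labels) if y_j < y_i]
        m = max(competitors)
        log_denominator = m + math.log(sum(math.exp(v - m) for v in competitors))
        loss -= gain_fn(y_i) * (s_i - log_denominator)
    return loss


print(unique_softmax_sketch([0., 0.], [1., 0.]))          # log(2)
print(unique_softmax_sketch([0., 0., 0.], [2., 1., 0.]))  # 3*log(3) + log(2)
```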
Ranking Metrics (rax.*_metric)
Implementations of common ranking metrics in JAX.
A ranking metric expresses how well a ranking induced by item scores matches a
ranking induced from relevance labels. Rax provides a number of ranking metrics
as JAX functions that are implemented according to the
MetricFn interface.
Metric functions are designed to operate on the last dimension of their inputs.
The leading dimensions are considered batch dimensions. To compute per-list
metrics, for example to apply per-list weighting or for distributed computing of
metrics across devices, please use standard JAX transformations such as
jax.vmap() or jax.pmap().
Standalone usage of a metric:
>>> import jax
>>> import jax.numpy as jnp
>>> import rax
>>> scores = jnp.array([2., 1., 3.])
>>> labels = jnp.array([2., 0., 1.])
>>> print(rax.ndcg_metric(scores, labels))
0.79670763
Usage with a batch of data and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 3.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[2., 0., 1.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(rax.ndcg_metric(scores, labels, where=where))
1.0
Usage with jax.vmap() batching and a mask to indicate valid items:
>>> scores = jnp.array([[2., 1., 0.], [1., 0.5, 1.5]])
>>> labels = jnp.array([[1., 0., 0.], [0., 0., 1.]])
>>> where = jnp.array([[True, True, False], [True, True, True]])
>>> print(jax.vmap(rax.ndcg_metric)(scores, labels, where=where))
[1. 1.]
rax.mrr_metric – Mean Reciprocal Rank (MRR).
rax.precision_metric – Precision.
rax.recall_metric – Recall.
rax.ap_metric – Average Precision.
rax.dcg_metric – Discounted cumulative gain (DCG).
rax.ndcg_metric – Normalized discounted cumulative gain (NDCG).
- rax.mrr_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Mean Reciprocal Rank (MRR).
Note
This metric converts graded relevance to binary relevance by considering items with
label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\operatorname{mrr}(s, y) = \max_i \frac{y_i}{\operatorname{rank}(s_i)} \]
where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The MRR metric.
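The MRR definition only needs ranks and binarized labels. The plain-Python sketch below (hypothetical names) ranks by descending score with ties kept in input order, binarizes labels at 1, and applies an optional top-n cutoff. It does not model where, -inf scores or batching.

```python
def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def mrr_sketch(scores, labels, topn=None):
    """max_i y_i / rank(s_i), with labels binarized at 1 and an optional cutoff."""
    best = 0.0
    for r, y in zip(ranks_by_score(scores), labels):
        relevant = 1.0 if y >= 1 else 0.0
        if topn is None or r <= topn:
            best = max(best, relevant / r)
    return best


print(mrr_sketch([0., 1., 3., 2.], [0., 1., 0., 1.]))          # first hit at rank 2: 0.5
print(mrr_sketch([0., 1., 3., 2.], [0., 1., 0., 1.], topn=1))  # no hit in top 1: 0.0
```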
- rax.precision_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Precision.
Note
This metric converts graded relevance to binary relevance by considering items with
label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\operatorname{precision@n}(s, y) = \frac{1}{n} \sum_i y_i \cdot \mathbb{I}\left[\operatorname{rank}(s_i) \leq n\right] \]
where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The precision metric.
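A plain-Python sketch of precision@n for one list follows. Labels are binarized at 1 as in the note above, and ties in scores keep their input order. When `topn` is None the sketch uses the list size as n. That choice is an assumption, not documented behavior. Function names are hypothetical.

```python
def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def precision_sketch(scores, labels, topn=None):
    """(1/n) * number of relevant items (label >= 1) ranked within the top n."""
    n = topn if topn is not None else len(scores)  # assumed choice for topn=None
    hits = sum(1 for r, y in zip(ranks_by_score(scores), labels) if y >= 1 and r <= n)
    return hits / n


scores, labels = [0., 1., 3., 2.], [0., 1., 0., 1.]
print(precision_sketch(scores, labels, topn=2))  # 1 hit in top 2: 0.5
print(precision_sketch(scores, labels, topn=3))  # 2 hits in top 3: 0.666...
```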
- rax.recall_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Recall.
Note
This metric converts graded relevance to binary relevance by considering items with
label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\operatorname{recall@n}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \cdot \mathbb{I}\left[\operatorname{rank}(s_i) \leq n\right] \]
where \(\operatorname{rank}(s_i)\) indicates the rank of item \(i\) after sorting all scores \(s\) using rank_fn.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The recall metric.
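Recall uses the same hits as precision but divides by the number of relevant items instead of n. The plain-Python sketch below (hypothetical names) follows the formula for one list. Returning 0.0 for a list with no relevant items is the sketch's own choice, since the formula is undefined there.

```python
def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def recall_sketch(scores, labels, topn=None):
    """Fraction of relevant items (label >= 1) ranked within the top n."""
    relevant = [y >= 1 for y in labels]
    num_relevant = sum(relevant)
    if num_relevant == 0:
        return 0.0  # sketch's choice: the formula divides by zero here
    ranks = ranks_by_score(scores)
    hits = sum(1 for r, rel in zip(ranks, relevant)
               if rel and (topn is None or r <= topn))
    return hits / num_relevant


scores, labels = [0., 1., 3., 2.], [0., 1., 0., 1.]
print(recall_sketch(scores, labels, topn=2))  # 1 of 2 relevant items: 0.5
print(recall_sketch(scores, labels, topn=3))  # both relevant items: 1.0
```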
- rax.ap_metric(scores, labels, *, where=None, topn=None, key=None, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Average Precision.
Note
This metric converts graded relevance to binary relevance by considering items with
label >= 1 as relevant and items with label < 1 as non-relevant.
Definition:
\[\operatorname{ap}(s, y) = \frac{1}{\sum_i y_i} \sum_i y_i \operatorname{precision@rank}_{s_i}(s, y) \]
where \(\operatorname{precision@rank}_{s_i}(s, y)\) indicates the precision at the rank of item \(i\).
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The average precision metric.
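Average precision evaluates precision at the rank of each relevant item and averages those values. The plain-Python sketch below (hypothetical names, no top-n cutoff, where or batching) works through the definition for one list. For scores `[0, 1, 3, 2]` with relevant items at ranks 2 and 3, the result is (1/2 + 2/3) / 2 = 7/12.

```python
def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def ap_sketch(scores, labels):
    """Mean of precision@rank(i) over relevant items i (labels binarized at 1)."""
    ranks = ranks_by_score(scores)
    relevant = [y >= 1 for y in labels]
    num_relevant = sum(relevant)
    if num_relevant == 0:
        return 0.0  # sketch's choice for lists without relevant items
    total = 0.0
    for r_i, rel_i in zip(ranks, relevant):
        if rel_i:
            hits_at_r = sum(1 for r, rel in zip(ranks, relevant) if rel and r <= r_i)
            total += hits_at_r / r_i  # precision at the rank of item i
    return total / num_relevant


print(ap_sketch([0., 1., 3., 2.], [0., 1., 0., 1.]))  # 7/12 ~ 0.5833
```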
- rax.dcg_metric(scores, labels, *, where=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Discounted cumulative gain (DCG).
Definition [Järvelin and Kekäläinen, 2002]:
\[\operatorname{dcg}(s, y) = \sum_i \operatorname{gain}(y_i) \cdot \operatorname{discount}(\operatorname{rank}(s_i)) \]
where \(\operatorname{rank}(s_i)\) indicates the 1-based rank of item \(i\) as computed by rank_fn, \(\operatorname{gain}(y)\) indicates the per-item gains as computed by gain_fn, and \(\operatorname{discount}(r)\) indicates the per-item rank discounts as computed by discount_fn.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the per-item weights.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
gain_fn (Callable[[Array], Array]) – A function that maps relevance labels to gain values.
discount_fn (Callable[[Array], Array]) – A function that maps 1-based ranks to discount values.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The DCG metric.
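A plain-Python sketch of DCG for one list follows. It assumes a gain of `2^y - 1` and a discount of `1 / log2(1 + rank)`, two common conventions. These reproduce the standalone ndcg_metric value at the top of the metrics section, but treat them as assumptions about default_gain_fn and default_discount_fn. Items ranked below `topn` are dropped. Names are hypothetical.

```python
import math


def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg_sketch(scores, labels, topn=None):
    """sum_i gain(y_i) * discount(rank(s_i)) with assumed gain/discount."""
    total = 0.0
    for r, y in zip(ranks_by_score(scores), labels):
        if topn is None or r <= topn:
            total += (2.0 ** y - 1.0) / math.log2(1.0 + r)
    return total


print(dcg_sketch([2., 1., 3.], [2., 0., 1.]))          # 1 + 3 / log2(3)
print(dcg_sketch([2., 1., 3.], [2., 0., 1.], topn=1))  # only the top item: 1.0
```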
- rax.ndcg_metric(scores, labels, *, where=None, topn=None, weights=None, key=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>, rank_fn=<function ranks>, cutoff_fn=<function cutoff>, reduce_fn=<function mean>)
Normalized discounted cumulative gain (NDCG).
Definition [Järvelin and Kekäläinen, 2002]:
\[\operatorname{ndcg}(s, y) = \operatorname{dcg}(s, y) / \operatorname{dcg}(y, y) \]
where \(\operatorname{dcg}\) is the discounted cumulative gain metric.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item. Items for which the score is \(-\inf\) are treated as unranked items.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the metric.
topn (Optional[int]) – An optional integer value indicating at which rank the metric cuts off. If None, no cutoff is performed.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the per-item weights.
key (Optional[Array]) – An optional PRNGKey(). If provided, any random operations in this metric will be based on this key.
gain_fn (Callable[[Array], Array]) – A function that maps relevance labels to gain values.
discount_fn (Callable[[Array], Array]) – A function that maps 1-based ranks to discount values.
rank_fn (RankFn) – A function that maps scores to 1-based ranks.
cutoff_fn (CutoffFn) – A function that maps ranks and a cutoff integer to a binary array indicating which items fall within the cutoff.
reduce_fn (Optional[ReduceFn]) – An optional function that reduces the metric values. Can be jax.numpy.sum() or jax.numpy.mean(). If None, no reduction is performed.
- Return type
Array
- Returns
The NDCG metric.
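NDCG divides the DCG of the score-induced ranking by the DCG of the label-induced (ideal) ranking, dcg(y, y). The self-contained sketch below implements that ratio in plain Python under the same assumed gain `2^y - 1` and discount `1 / log2(1 + rank)`. It reproduces the 0.79670763 from the standalone ndcg_metric example, which suggests these defaults agree with the library at least on that input. Returning 0.0 when the ideal DCG is zero is the sketch's own choice.

```python
import math


def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg_sketch(scores, labels, topn=None):
    """DCG with assumed gain 2^y - 1 and discount 1 / log2(1 + rank)."""
    return sum((2.0 ** y - 1.0) / math.log2(1.0 + r)
               for r, y in zip(ranks_by_score(scores), labels)
               if topn is None or r <= topn)


def ndcg_sketch(scores, labels, topn=None):
    """dcg(s, y) / dcg(y, y): ranking by the labels gives the ideal DCG."""
    ideal = dcg_sketch(labels, labels, topn)
    return dcg_sketch(scores, labels, topn) / ideal if ideal > 0 else 0.0


print(ndcg_sketch([2., 1., 3.], [2., 0., 1.]))  # ~0.7967
```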
Function Transformations (rax.*_t12n)
Function transformations for ranking losses and metrics.
These function transformations can be used to transform the ranking metrics and
losses. An example is approx_t12n which transforms a given ranking metric
into a ranking loss by plugging in differentiable approximations to the rank and
cutoff functions.
Example usage:
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> approx_ndcg_loss_fn = rax.approx_t12n(rax.ndcg_metric)
>>> print(approx_ndcg_loss_fn(scores, labels))
-0.71789175
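rax.approx_t12n – Transforms metric_fn into an approximate differentiable loss.
rax.bound_t12n – Transforms metric_fn into a lower-bound differentiable loss.
rax.gumbel_t12n – Transforms loss_or_metric_fn to operate on Gumbel-sampled scores.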
- rax.approx_t12n(metric_fn, temperature=1.0)
Transforms metric_fn into an approximate differentiable loss.
This transformation uses a sigmoid approximation to compute ranks and indicators in metrics [Qin et al., 2010]. The returned approximate metric is mapped to negative values to be used as a loss.
Example usage:
>>> approx_mrr = rax.approx_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(approx_mrr(scores, labels))
-0.6965873
Example usage together with rax.gumbel_t12n():
>>> gumbel_approx_mrr = rax.gumbel_t12n(rax.approx_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(gumbel_approx_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-0.71880937
- rax.bound_t12n(metric_fn)
Transforms metric_fn into a lower-bound differentiable loss.
This transformation uses a hinge bound to compute ranks and indicators in metrics. The returned lower-bound of the metric is mapped to negative values to be used as a loss.
Example usage:
>>> bound_mrr = rax.bound_t12n(rax.mrr_metric)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> print(bound_mrr(scores, labels))
-0.33333334
Example usage together with rax.gumbel_t12n():
>>> gumbel_bound_mrr = rax.gumbel_t12n(rax.bound_t12n(rax.mrr_metric))
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 1., 0., 1.])
>>> print(gumbel_bound_mrr(scores, labels, key=jax.random.PRNGKey(42)))
-0.31619418
- Parameters
metric_fn (MetricFn) – The metric function to convert to a lower-bound loss.
- Returns
A loss function that computes the lower-bound version of
metric_fn.
- rax.gumbel_t12n(loss_or_metric_fn, *, samples=8, beta=1.0, smoothing_factor=None)
Transforms loss_or_metric_fn to operate on Gumbel-sampled scores.
This transformation changes the given loss_or_metric_fn so that it samples scores from a Gumbel distribution prior to computing the loss or metric [Bruch et al., 2020]. The returned function requires a new key keyword argument.
Example usage:
>>> loss_fn = rax.gumbel_t12n(rax.softmax_loss)
>>> scores = jnp.asarray([0., 1., 3., 2.])
>>> labels = jnp.asarray([0., 0., 1., 2.])
>>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(42)))
6.2066536
>>> print(loss_fn(scores, labels, key=jax.random.PRNGKey(79)))
5.0127797
- Parameters
loss_or_metric_fn (TypeVar(LossOrMetricFn, LossFn, MetricFn)) – A Rax loss or metric function.
samples (int) – Number of Gumbel samples to create.
beta (float) – Scale of the Gumbel distribution (default 1.0).
smoothing_factor (Optional[float]) – If supplied, this will apply an extra log(softmax(scores) + smoothing_factor) transformation to the scores. If set to 1e-20, this effectively makes the loss compatible with the TF-Ranking versions of Gumbel losses. If smoothing_factor <= 0, this may produce NaN values.
- Return type
- Returns
A new function that behaves the same as loss_or_metric_fn but which requires an additional key argument that will be used to randomly sample the scores from a Gumbel distribution.
Lambdaweights (rax.*_lambdaweight)
Implementations of lambdaweight functions for Rax pairwise losses.
Lambdaweight functions dynamically adjust the weights of a pairwise loss based
on the scores and labels. Rax provides a number of lambdaweight functions as JAX
functions that are implemented according to the
LambdaweightFn interface.
Example usage:
>>> scores = jnp.array([1.2, 0.4, 1.9])
>>> labels = jnp.array([1.0, 2.0, 0.0])
>>> loss = rax.pairwise_logistic_loss(
... scores, labels, lambdaweight_fn=rax.labeldiff_lambdaweight)
>>> print(loss)
1.8923712
rax.labeldiff_lambdaweight – Absolute label difference lambdaweights.
rax.dcg_lambdaweight – DCG lambdaweights.
rax.dcg2_lambdaweight – DCG v2 ("lambdaloss") lambdaweights.
- rax.labeldiff_lambdaweight(scores, labels, *, where=None, weights=None)
Absolute label difference lambdaweights.
Definition:
\[\lambda_{ij}(s, y) = |y_i - y_j| \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
- Return type
Array
- Returns
Absolute label difference lambdaweights.
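The sketch below shows how a lambdaweight matrix plugs into a pairwise loss. It builds `|y_i - y_j|` for every pair, scales each pairwise logistic term by it, and averages over the pairs with `y_i > y_j`. This plain-Python version reproduces the 1.8923712 from the lambdaweight example at the start of this section. That the mean divides by the number of pairs, rather than by the total lambdaweight, is inferred from that one example. The function names are hypothetical.

```python
import math


def labeldiff_lambdaweights_sketch(labels):
    """lambda_ij = |y_i - y_j| for every pair of items."""
    return [[abs(y_i - y_j) for y_j in labels] for y_i in labels]


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_pairwise_logistic_sketch(scores, labels, lambdaweights):
    """Mean over pairs with y_i > y_j of lambda_ij * log(1 + exp(-(s_i - s_j)))."""
    n = len(scores)
    pair_losses = [
        lambdaweights[i][j] * softplus(-(scores[i] - scores[j]))
        for i in range(n)
        for j in range(n)
        if labels[i] > labels[j]
    ]
    return sum(pair_losses) / len(pair_losses)


scores = [1.2, 0.4, 1.9]
labels = [1.0, 2.0, 0.0]
lam = labeldiff_lambdaweights_sketch(labels)
print(weighted_pairwise_logistic_sketch(scores, labels, lam))  # ~1.89237
```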
- rax.dcg_lambdaweight(scores, labels, *, where=None, weights=None, topn=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>)
DCG lambdaweights.
Definition [Burges et al., 2006]:
\[\lambda_{ij}(s, y) = |\operatorname{gain}(y_i) - \operatorname{gain}(y_j)| \cdot |\operatorname{discount}(\operatorname{rank}(s_i)) - \operatorname{discount}(\operatorname{rank}(s_j))| \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
topn (Optional[int]) – The topn cutoff. If None, no cutoff is performed.
gain_fn (Callable[[Array], Array]) – A function mapping labels to gain values.
discount_fn (Callable[[Array], Array]) – A function mapping ranks to discount values.
- Return type
Array
- Returns
DCG lambdaweights.
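The DCG lambdaweight for a pair is the change in gain times the change in discount between the two items' current positions. A plain-Python sketch for one list follows, under the same assumed gain `2^y - 1` and discount `1 / log2(1 + rank)`. The topn cutoff is not modeled, and the names are hypothetical. The resulting matrix is symmetric with a zero diagonal.

```python
import math


def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg_lambdaweights_sketch(scores, labels):
    """lambda_ij = |gain_i - gain_j| * |discount(r_i) - discount(r_j)|."""
    ranks = ranks_by_score(scores)
    gains = [2.0 ** y - 1.0 for y in labels]               # assumed gain
    discounts = [1.0 / math.log2(1.0 + r) for r in ranks]  # assumed discount
    n = len(scores)
    return [
        [abs(gains[i] - gains[j]) * abs(discounts[i] - discounts[j]) for j in range(n)]
        for i in range(n)
    ]


lam = dcg_lambdaweights_sketch([2., 1.], [1., 0.])
print(lam)  # lam[0][1] == lam[1][0] == 1 - 1/log2(3)
```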
- rax.dcg2_lambdaweight(scores, labels, *, where=None, weights=None, topn=None, gain_fn=<function default_gain_fn>, discount_fn=<function default_discount_fn>)
DCG v2 (“lambdaloss”) lambdaweights.
Definition [Wang et al., 2018]:
\[\lambda_{ij}(s, y) = |\operatorname{gain}(y_i) - \operatorname{gain}(y_j)| \cdot |\operatorname{discount}( |\operatorname{rank}(s_i) - \operatorname{rank}(s_j)|) - \operatorname{discount}( |\operatorname{rank}(s_i) - \operatorname{rank}(s_j)| + 1)| \]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
topn (Optional[int]) – The topn cutoff. If None, no cutoff is performed. The topn cutoff uses the method described in [Jagerman et al., 2022].
gain_fn (Callable[[Array], Array]) – A function mapping labels to gain values.
discount_fn (Callable[[Array], Array]) – A function mapping ranks to discount values.
- Return type
Array
- Returns
DCG v2 (“lambdaloss”) lambdaweights.
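The v2 weight depends on the rank distance `|r_i - r_j|` rather than on the absolute positions. The plain-Python sketch below (hypothetical names, assumed gain `2^y - 1` and discount `1 / log2(1 + r)`, no topn cutoff) makes that concrete. The diagonal is set to 0 because a zero rank distance would require discount(0) = 1/log2(1), which is infinite. How the library treats the diagonal is not stated above.

```python
import math


def ranks_by_score(scores):
    """1-based ranks, highest score first; ties keep their input order."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


def dcg2_lambdaweights_sketch(scores, labels):
    """lambda_ij = |gain_i - gain_j| * |discount(d) - discount(d + 1)|, d = |r_i - r_j|."""
    ranks = ranks_by_score(scores)
    gains = [2.0 ** y - 1.0 for y in labels]  # assumed gain

    def discount(r):
        return 1.0 / math.log2(1.0 + r)       # assumed discount

    n = len(scores)
    lam = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if i == j:
                continue  # d = 0 would need discount(0), which is infinite
            d = abs(ranks[i] - ranks[j])
            lam[i][j] = abs(gains[i] - gains[j]) * abs(discount(d) - discount(d + 1))
    return lam


lam = dcg2_lambdaweights_sketch([3., 2., 1.], [1., 0., 0.])
print(lam[0][1], lam[0][2])  # rank distance 1 vs. rank distance 2
```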
Utilities
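rax.utils.ranks – Computes the ranks for given scores.
rax.utils.cutoff – Computes a binary array to select the largest n values of a.
rax.utils.approx_ranks – Computes approximate ranks.
rax.utils.approx_cutoff – Computes a binary array to select the largest n values of a.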
- rax.utils.ranks(scores, *, where=None, axis=-1, key=None)
Computes the ranks for given scores.
Note that the ranks returned by this function are not differentiable due to the sort operation having no gradients.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid.
axis (int) – The axis to sort on; by default this is the last axis.
key (Optional[Array]) – An optional jax.random.PRNGKey(). If provided, ties will be broken randomly using this key. If not provided, ties will retain the order of their appearance in the scores array.
- Return type
Array
- Returns
A tensor with the same shape as scores that indicates the 1-based rank of each item.
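Both tie-breaking behaviours described for `key` can be reproduced with a stable sort. Sorting by descending score keeps ties in their order of appearance, and shuffling the indices before the sort breaks ties randomly. The plain-Python sketch below uses a hypothetical integer `seed` as a stand-in for the PRNGKey; it does not model where or axis.

```python
import random


def ranks_sketch(scores, seed=None):
    """1-based ranks (highest score = rank 1) with optional random tie-breaking."""
    indices = list(range(len(scores)))
    if seed is not None:
        random.Random(seed).shuffle(indices)  # randomize the order among ties
    # sorted() is stable, so tied scores keep the (possibly shuffled) order.
    order = sorted(indices, key=lambda i: -scores[i])
    ranks = [0] * len(scores)
    for position, i in enumerate(order, start=1):
        ranks[i] = position
    return ranks


print(ranks_sketch([1., 3., 1., 2.]))          # [3, 1, 4, 2]
print(ranks_sketch([1., 3., 1., 2.], seed=0))  # items 0 and 2 take ranks 3 and 4
```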
- rax.utils.cutoff(a, n=None, *, where=None, step_fn=<function <lambda>>)
Computes a binary array to select the largest n values of a.
This function computes a binary jax.numpy.ndarray that selects the n largest values of a across its last dimension.
Note that the returned indicator may select more than n items if a has ties.
- Parameters
a (Array) – The jax.numpy.ndarray to select the topn from.
n (Optional[int]) – The cutoff value. If None, no cutoff is performed.
where (Optional[Array]) – A mask to indicate which values to include in the topn calculation.
step_fn (Callable[[Array], Array]) – A function that computes x >= 0 or an approximation.
- Return type
Array
- Returns
A jax.numpy.ndarray of the same shape as a, where the n largest values are set to 1, and the smaller values are set to 0.
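A minimal sketch of the selection rule, assuming the hard step x >= 0 and a threshold equal to the n-th largest valid value. topn_cutoff_sketch is a hypothetical name, and the handling of where (invalid items are never selected) is an assumption. Because every item that reaches the threshold is selected, ties at the threshold produce more than n ones, matching the note above.

```python
import jax.numpy as jnp


def topn_cutoff_sketch(a, n=None, *, where=None):
    """Hypothetical helper: 1 for the n largest values of `a`, 0 otherwise."""
    a = jnp.asarray(a)
    # Invalid items get -inf so they never raise the threshold.
    masked = a if where is None else jnp.where(where, a, -jnp.inf)
    if n is None:
        # No cutoff: select every item.
        indicator = jnp.ones_like(a)
    elif n <= 0:
        indicator = jnp.zeros_like(a)
    else:
        n = min(n, a.shape[-1])
        # The n-th largest value along the last axis acts as the threshold.
        threshold = jnp.sort(masked, axis=-1)[..., -n]
        # Hard step x >= 0 applied to (value - threshold).
        indicator = (masked >= jnp.expand_dims(threshold, -1)).astype(a.dtype)
    if where is not None:
        indicator = jnp.where(where, indicator, 0)
    return indicator
```

topn_cutoff_sketch(jnp.array([3., 3., 1.]), 1) returns [1. 1. 0.]: two items are selected for n=1 because of the tie.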
- rax.utils.approx_ranks(scores, *, where=None, key=None, step_fn=<CompiledFunction of <function sigmoid>>)
Computes approximate ranks.
This can be used to construct differentiable approximations of metrics. For example:
>>> import functools
>>> approx_ndcg = functools.partial(
...     rax.ndcg_metric, rank_fn=rax.utils.approx_ranks)
>>> scores = jnp.asarray([-1., 1., 0.])
>>> labels = jnp.asarray([0., 0., 1.])
>>> print(approx_ndcg(scores, labels))
0.63092977
>>> print(jax.grad(approx_ndcg)(scores, labels))
[-0.03763788 -0.03763788 0.07527576]
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid.
key (Optional[Array]) – An optional jax.random.PRNGKey(). Unused by approx_ranks.
step_fn (Callable[[Array], Array]) – A callable that approximates the step function x >= 0.
- Return type
Array
- Returns
An ndarray of the same shape as scores, indicating the 1-based approximate rank of each item.
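A common way to obtain differentiable ranks, in the spirit of [Qin et al., 2010], replaces the indicator [s_j > s_i] in rank_i = 1 + Σ_{j≠i} [s_j > s_i] with a sigmoid. The sketch below assumes that formulation with jax.nn.sigmoid as the step approximation and omits where. approx_ranks_sketch and approx_dcg_sketch are hypothetical names, and the gain 2^y - 1 and discount 1 / log2(1 + r) are assumptions.

```python
import jax
import jax.numpy as jnp


def approx_ranks_sketch(scores):
    """Hypothetical helper: smooth 1-based ranks along the last axis."""
    # pairwise[..., i, j] = sigmoid(s_j - s_i), a smooth version of [s_j > s_i].
    s_i = jnp.expand_dims(scores, -1)
    s_j = jnp.expand_dims(scores, -2)
    pairwise = jax.nn.sigmoid(s_j - s_i)
    # Drop the j == i term, which would otherwise add sigmoid(0) = 0.5.
    n = scores.shape[-1]
    pairwise = jnp.where(jnp.eye(n, dtype=bool), 0.0, pairwise)
    return 1.0 + jnp.sum(pairwise, axis=-1)


def approx_dcg_sketch(scores, labels):
    """Hypothetical differentiable DCG built on the approximate ranks."""
    # Assumed gain 2^y - 1 and discount 1 / log2(1 + rank).
    gains = 2.0 ** labels - 1.0
    return jnp.sum(gains / jnp.log2(1.0 + approx_ranks_sketch(scores)), axis=-1)
```

For the scores [-1., 1., 0.] and labels [0., 0., 1.] from the doctest above, the relevant item gets an approximate rank of exactly 2 (since sigmoid(-1) + sigmoid(1) = 1), approx_dcg_sketch returns about 0.6309, and its gradient is about [-0.0376, -0.0376, 0.0753]. These agree with the printed approx_ndcg values in this example because the ideal DCG of this list is 1 under the assumed gain and discount.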
- rax.utils.approx_cutoff(a, n=None, *, where=None, step_fn=<function <lambda>>)
Computes a binary array to select the largest n values of a.
This function computes a binary jax.numpy.ndarray that selects the n largest values of a across its last dimension.
Note that the returned indicator may select more than n items if a has ties.
- Parameters
a (Array) – The jax.numpy.ndarray to select the topn from.
n (Optional[int]) – The cutoff value. If None, no cutoff is performed.
where (Optional[Array]) – A mask to indicate which values to include in the topn calculation.
step_fn (Callable[[Array], Array]) – A function that computes x >= 0 or an approximation.
- Return type
Array
- Returns
A jax.numpy.ndarray of the same shape as a, where the n largest values are set to 1, and the smaller values are set to 0.
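Since step_fn may approximate x >= 0, a smooth choice turns the hard indicator into a soft one with useful gradients. The sketch below is an illustration under assumptions, not necessarily how rax.utils.approx_cutoff works: it places the threshold halfway between the n-th and (n+1)-th largest values and uses a sigmoid with a hypothetical temperature parameter. soft_topn_cutoff_sketch is a hypothetical name, and where is omitted.

```python
import jax
import jax.numpy as jnp


def soft_topn_cutoff_sketch(a, n=None, temperature=0.1):
    """Hypothetical helper: smooth indicator of the n largest values of `a`."""
    a = jnp.asarray(a)
    size = a.shape[-1]
    if n is None or n >= size:
        return jnp.ones_like(a)
    if n <= 0:
        return jnp.zeros_like(a)
    s = jnp.sort(a, axis=-1)
    # Threshold halfway between the n-th and (n+1)-th largest values, so the
    # n largest items land above 0.5 and the rest below it (absent ties).
    t = 0.5 * (s[..., -n] + s[..., -n - 1])
    # Sigmoid as a smooth stand-in for the step function x >= 0.
    return jax.nn.sigmoid((a - jnp.expand_dims(t, -1)) / temperature)
```

For a = [3., 1., 2.] and n = 2 the threshold is 1.5, giving roughly [1.0, 0.007, 0.993]. Lowering temperature moves the output closer to the hard cutoff at the cost of sharper gradients.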
Types
Rax-specific types and protocols.
Note
Types and protocols are provided for type-checking convenience only. You do not need to instantiate, subclass or extend them.
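| rax.types.CutoffFn | typing.Protocol for cutoff functions. |
| rax.types.LambdaweightFn | typing.Protocol for lambdaweight functions. |
| rax.types.LossFn | typing.Protocol for loss functions. |
| rax.types.MetricFn | typing.Protocol for metric functions. |
| rax.types.RankFn | typing.Protocol for rank functions. |
| rax.types.ReduceFn | typing.Protocol for reduce functions. |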
- class rax.types.CutoffFn(*args, **kwds)
typing.Protocol for cutoff functions.
- __call__(a, n)
Computes cutoffs based on the given array.
- class rax.types.LambdaweightFn(*args, **kwds)
typing.Protocol for lambdaweight functions.
- __call__(scores, labels, *, where, weights, **kwargs)
Computes lambdaweights.
- Parameters
scores (Array) – A [..., list_size]-ndarray, indicating the score of each item.
labels (Array) – A [..., list_size]-ndarray, indicating the relevance label for each item.
where (Optional[Array]) – An optional [..., list_size]-ndarray, indicating which items are valid for computing the lambdaweights. Items for which this is False will be ignored when computing the lambdaweights.
weights (Optional[Array]) – An optional [..., list_size]-ndarray, indicating the weight for each item.
**kwargs – Optional lambdaweight-specific keyword arguments.
- Return type
Array
- Returns
A jax.numpy.ndarray that represents the lambdaweights.
- class rax.types.LossFn(*args, **kwds)
typing.Protocol for loss functions.
- __call__(scores, labels, *, where, **kwargs)
Computes a loss.
- Parameters
scores (Array) – The score of each item.
labels (Array) – The label of each item.
where (Optional[Array]) – An optional jax.numpy.ndarray of the same shape as scores that indicates which elements to include in the loss.
**kwargs – Optional loss-specific keyword arguments.
- Return type
Array
- Returns
A jax.numpy.ndarray that represents the loss computed on the given scores and labels.
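Because LossFn is a typing.Protocol, any callable with a matching signature satisfies it structurally. The sketch below illustrates this with a locally defined protocol (LossFnLike, a hypothetical mirror of the signature documented above, used so the example does not depend on importing rax) and a hypothetical masked_squared_error loss.

```python
from typing import Any, Optional, Protocol

import jax
import jax.numpy as jnp


class LossFnLike(Protocol):
    """Hypothetical local stand-in mirroring the rax.types.LossFn call signature."""

    def __call__(self, scores: jax.Array, labels: jax.Array, *,
                 where: Optional[jax.Array], **kwargs: Any) -> jax.Array:
        ...


def masked_squared_error(scores, labels, *, where=None, **kwargs):
    """Hypothetical custom loss: sum of squared errors over valid items."""
    errors = (labels - scores) ** 2
    if where is not None:
        # Masked-out items contribute nothing to the loss.
        errors = jnp.where(where, errors, 0.0)
    return jnp.sum(errors)


# A plain function satisfies the protocol structurally; no subclassing needed.
loss_fn: LossFnLike = masked_squared_error
```

A static type checker accepts the assignment to loss_fn without any subclassing, and the function composes with jax.grad like the built-in losses.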
- class rax.types.MetricFn(*args, **kwds)
typing.Protocol for metric functions.
- __call__(scores, labels, *, where, **kwargs)
Computes a metric.
- Parameters
scores (Array) – The score of each item.
labels (Array) – The label of each item.
where (Optional[Array]) – An optional jax.numpy.ndarray of the same shape as scores that indicates which elements to include in the metric.
**kwargs – Optional metric-specific keyword arguments.
- Return type
Array
- Returns
A jax.numpy.ndarray that represents the metric computed on the given scores and labels.
- class rax.types.RankFn(*args, **kwds)
typing.Protocol for rank functions.
- __call__(scores, where, key)
Computes 1-based ranks based on the given scores.
- Parameters
- Return type
Array
- Returns
A jax.numpy.ndarray of the same shape as scores that represents the 1-based ranks.
- class rax.types.ReduceFn(*args, **kwds)
typing.Protocol for reduce functions.
- __call__(a, where, axis)
Reduces an array across one or more dimensions.
- Parameters
a (Array) – The array to reduce.
where (Optional[Array]) – An optional jax.numpy.ndarray of the same shape as a that indicates which elements to include in the reduction.
axis (Union[int, Tuple[int, ...], None]) – One or more axes to use for the reduction. If None, this reduces across all available axes.
- Return type
Array
- Returns
A jax.numpy.ndarray that represents the reduced result of a over the given axis.
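jnp.mean, used as reduce_fn in the batch example earlier on this page, fits this protocol because it accepts where and axis. It returns NaN when every element is masked out, though, so the hypothetical safe_masked_mean below sketches a ReduceFn-compatible alternative. Returning 0 for fully masked inputs is an assumption about the desired behaviour.

```python
import jax.numpy as jnp


def safe_masked_mean(a, where=None, axis=None):
    """Hypothetical ReduceFn-style reduction: mean over valid elements, 0 if none."""
    if where is None:
        where = jnp.ones_like(a, dtype=bool)
    total = jnp.sum(a, where=where, axis=axis)
    # Number of valid elements along the reduced axes.
    count = jnp.sum(where, axis=axis)
    # Guard the division so fully masked inputs give 0 instead of NaN.
    return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)
```

With the mask [True, False, True], safe_masked_mean(jnp.array([1., 2., 3.]), where=...) gives 2.0, the same as jnp.mean; the two differ only when no element is valid.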
References
- BHBN20
Sebastian Bruch, Shuguang Han, Michael Bendersky, and Marc Najork. A stochastic treatment of learning to rank scoring functions. In Proceedings of the 13th International Conference on Web Search and Data Mining, 61–69. 2020.
- BSR+05
Chris Burges, Tal Shaked, Erin Renshaw, Ari Lazier, Matt Deeds, Nicole Hamilton, and Greg Hullender. Learning to rank using gradient descent. In Proceedings of the 22nd international conference on Machine learning, 89–96. 2005.
- BRL06
Christopher Burges, Robert Ragno, and Quoc Le. Learning to rank with nonsmooth cost functions. In Advances in Neural Information Processing Systems, volume 19, 193–200. 2006.
- JQW+22
Rolf Jagerman, Zhen Qin, Xuanhui Wang, Mike Bendersky, and Marc Najork. On optimizing top-k metrics for neural ranking models. In Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval, 2303–2307. 2022.
- JarvelinKekalainen02
Kalervo Järvelin and Jaana Kekäläinen. Cumulated gain-based evaluation of IR techniques. ACM Transactions on Information Systems (TOIS), 20(4):422–446, 2002.
- LTL+22
Zhaoqi Leng, Mingxing Tan, Chenxi Liu, Ekin Dogus Cubuk, Jay Shi, Shuyang Cheng, and Dragomir Anguelov. PolyLoss: a polynomial expansion perspective of classification loss functions. In International Conference on Learning Representations. 2022.
- QLL10
Tao Qin, Tie-Yan Liu, and Hang Li. A general approximation framework for direct optimization of information retrieval measures. Information Retrieval, 13(4):375–397, 2010.
- WLG+18
Xuanhui Wang, Cheng Li, Nadav Golbandi, Mike Bendersky, and Marc Najork. The LambdaLoss framework for ranking metric optimization. In Proceedings of The 27th ACM International Conference on Information and Knowledge Management, 1313–1322. 2018.
- XLW+08
Fen Xia, Tie-Yan Liu, Jue Wang, Wensheng Zhang, and Hang Li. Listwise approach to learning to rank: theory and algorithm. In Proceedings of the 25th international conference on Machine learning, 1192–1199. 2008.
- ZK20
Xiaofeng Zhu and Diego Klabjan. Listwise learning to rank by exploring unique ratings. In Proceedings of the 13th international conference on web search and data mining, 798–806. 2020.