What is the difference between sparse_categorical_crossentropy and categorical_crossentropy?

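What is the difference between sparse_categorical_crossentropy and categorical_crossentropy? When should one loss be used rather than the other? For example, are these losses suitable for linear regression?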


Simply:

  • categorical_crossentropy (cce) expects the target as a one-hot array with one entry per category,
  • sparse_categorical_crossentropy (scce) expects the target as the integer index of the correct category.

In both cases the model outputs the same thing: a probability for each category.

Consider a classification problem with 5 categories (or classes).

  • In the case of cce, the one-hot target may be [0, 1, 0, 0, 0] and the model may predict [.2, .5, .1, .1, .1] (probably right)

  • In the case of scce, the target index would be [1] and the model would output the same [.2, .5, .1, .1, .1].

Consider now a classification problem with 3 classes.

  • In the case of cce, the one-hot target might be [0, 0, 1] and the model may predict [.5, .1, .4] (probably inaccurate, given that it gives more probability to the first class)
  • In the case of scce, the target index would be [2], and the model would output the same [.5, .1, .4].
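The two examples above can be checked with a minimal pure-Python sketch. cce_loss and scce_loss are hypothetical helpers written for illustration, not the Keras functions; they take the same predicted probabilities and differ only in how the target is given.

```python
import math


def cce_loss(one_hot_target, probs):
    """Cross-entropy with a one-hot target vector (hypothetical helper)."""
    # Only the entry where the target is 1 contributes to the sum.
    return -sum(t * math.log(p) for t, p in zip(one_hot_target, probs) if t)


def scce_loss(class_index, probs):
    """Cross-entropy with an integer class index as target (hypothetical helper)."""
    return -math.log(probs[class_index])


# 5-class example, true class 1
probs5 = [0.2, 0.5, 0.1, 0.1, 0.1]
print(cce_loss([0, 1, 0, 0, 0], probs5))  # about 0.693
print(scce_loss(1, probs5))               # same value

# 3-class example, true class 2
probs3 = [0.5, 0.1, 0.4]
print(cce_loss([0, 0, 1], probs3))  # about 0.916
print(scce_loss(2, probs3))         # same value
```

Only the probability assigned to the true class matters, so the one-hot form and the index form always give the same number.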

Storing targets in scce form saves space, and no information is lost as long as each sample has exactly one correct class: the model still outputs the full probability array (in the 2nd example, the .4 for index 2 is still there), and both losses give the same value for the same prediction.

Situations where scce is a good fit include:

  • when your classes are mutually exclusive, i.e. each sample has exactly one correct class (a single integer cannot encode several),
  • when the number of categories is so large that one-hot targets become overwhelming to store.
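A single integer can name only one class, which is why scce needs mutually exclusive classes. The helper one_hot_to_index below is hypothetical (an illustration, not a library function); it converts a one-hot row to the index scce expects and refuses rows that mark several classes.

```python
def one_hot_to_index(row):
    """Return the class index that a one-hot target row marks.

    Raises ValueError when the row does not mark exactly one class, since a
    single integer target cannot describe a sample in several classes.
    """
    if any(v not in (0, 1) for v in row) or sum(row) != 1:
        raise ValueError(f"not a one-hot row: {row}")
    return row.index(1)


print(one_hot_to_index([0, 0, 1, 0]))  # 2

try:
    one_hot_to_index([1, 0, 1, 0])  # two classes at once: no single index
except ValueError as err:
    print(err)  # not a one-hot row: [1, 0, 1, 0]
```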

Update 220405: response to comments about one-hot encoding:

One-hot encoding is commonly used for a categorical input feature (e.g. male versus female): the feature becomes a vector that is 0 for every category except the given one, so multiplying it by the weights selects only that category's weights. It is also the target format that cce expects.

cce and scce are loss functions, not model outputs. In both cases the model OUTPUT is a probability array over the categories, totaling 1.0; the most likely category is simply the index with the highest probability. The loss compares that array with the target and returns a single number.

The only difference is the target: scce takes an integer index, cce takes the equivalent one-hot array, and both compute the same loss.
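The distinction between the model output and the loss can be sketched in pure Python, assuming a single sample with the 3-class output [.5, .1, .4] from the earlier example and true class 2. The argmax gives the most likely category; each loss reduces output and target to one number, and the two target formats give the same value.

```python
import math

probs = [0.5, 0.1, 0.4]  # model output for one sample; sums to 1.0

# The most likely category is an argmax over the output,
# independent of which loss was used for training.
predicted = max(range(len(probs)), key=lambda i: probs[i])

# Each loss turns (target, output) into a single number.
loss_scce = -math.log(probs[2])  # integer target: 2
loss_cce = -sum(t * math.log(p) for t, p in zip([0, 0, 1], probs))  # one-hot target

print(predicted)            # 0
print(loss_scce, loss_cce)  # both about 0.916
```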

From the TensorFlow source code, the sparse_categorical_crossentropy is defined as categorical crossentropy with integer targets:

```python
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Categorical crossentropy with integer targets.

    Arguments:
        target: An integer tensor.
        output: A tensor resulting from a softmax
            (unless `from_logits` is True, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.
        axis: Int specifying the channels axis. `axis=-1` corresponds to data
            format `channels_last`, and `axis=1` corresponds to data format
            `channels_first`.

    Returns:
        Output tensor.

    Raises:
        ValueError: if `axis` is neither -1 nor one of the axes of `output`.
    """
```

From the TensorFlow source code, the categorical_crossentropy is defined as categorical cross-entropy between an output tensor and a target tensor.

```python
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Categorical crossentropy between an output tensor and a target tensor.

    Arguments:
        target: A tensor of the same shape as `output`.
        output: A tensor resulting from a softmax
            (unless `from_logits` is True, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.
        axis: Int specifying the channels axis. `axis=-1` corresponds to data
            format `channels_last`, and `axis=1` corresponds to data format
            `channels_first`.

    Returns:
        Output tensor.

    Raises:
        ValueError: if `axis` is neither -1 nor one of the axes of `output`.
    """
```


Integer targets means that the target labels are given as a list of integers, each the index of a class, for example:

  • For sparse_categorical_crossentropy, in a 5-class classification problem, targets of class 1 and class 2 should be given as [1, 2]. The targets must be in integer form to use sparse_categorical_crossentropy. The name "sparse" reflects that this representation needs much less space than one-hot encoding: a batch with b targets and k classes needs b * k entries in one-hot form but only b entries in integer form.

  • For categorical_crossentropy, in a 5-class classification problem, targets of class 1 and class 2 should be given as [[0,1,0,0,0], [0,0,1,0,0]]. The targets must be in one-hot form to use categorical_crossentropy.
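The two target formats from this list, and their sizes, in a short pure-Python sketch (assuming, as in the example, b = 2 targets and k = 5 classes):

```python
sparse_targets = [1, 2]  # class 1 and class 2 in a 5-class problem
num_classes = 5

# Equivalent one-hot targets for categorical_crossentropy.
one_hot_targets = [
    [1 if c == label else 0 for c in range(num_classes)] for label in sparse_targets
]
print(one_hot_targets)  # [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0]]

# Storage: b integers in sparse form versus b * k entries in one-hot form.
b, k = len(sparse_targets), num_classes
print(b, b * k)  # 2 10
```

The one-hot rows can be turned back into the original indices, so the two formats carry the same information for single-label data.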

The representation of the targets is the only difference; the results are the same, since both compute categorical crossentropy.

I was also confused by this one. Fortunately, the excellent Keras documentation came to the rescue. Both compute the same loss function and ultimately do the same thing; the only difference is the representation of the true labels.

  • Categorical Cross Entropy (Keras documentation):

Use this crossentropy loss function when there are two or more label classes. We expect labels to be provided in a one_hot representation.

```python
>>> import tensorflow as tf
>>> y_true = [[0, 1, 0], [0, 0, 1]]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> cce = tf.keras.losses.CategoricalCrossentropy()
>>> cce(y_true, y_pred).numpy()
1.177
```

  • Sparse Categorical Cross Entropy (Keras documentation):

Use this crossentropy loss function when there are two or more label classes. We expect labels to be provided as integers.

```python
>>> import tensorflow as tf
>>> y_true = [1, 2]
>>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
>>> # Using 'auto'/'sum_over_batch_size' reduction type.
>>> scce = tf.keras.losses.SparseCategoricalCrossentropy()
>>> scce(y_true, y_pred).numpy()
1.177
```
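Both documentation examples print 1.177. A pure-Python check of where that number comes from, assuming the default 'sum_over_batch_size' reduction means averaging the per-sample losses over the batch (Keras's own implementation also adds details such as clipping probabilities away from 0, which should not matter for these inputs):

```python
import math

y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]
y_true_sparse = [1, 2]                   # integer labels (scce)
y_true_one_hot = [[0, 1, 0], [0, 0, 1]]  # one-hot labels (cce)

# Per-sample loss: minus the log of the probability assigned to the true class.
sparse_losses = [-math.log(row[i]) for i, row in zip(y_true_sparse, y_pred)]
one_hot_losses = [
    -sum(t * math.log(p) for t, p in zip(t_row, p_row) if t > 0)
    for t_row, p_row in zip(y_true_one_hot, y_pred)
]

# Average over the batch ('sum_over_batch_size').
scce_value = sum(sparse_losses) / len(sparse_losses)
cce_value = sum(one_hot_losses) / len(one_hot_losses)
print(round(scce_value, 3), round(cce_value, 3))  # 1.177 1.177
```

The value is the mean of -ln(0.95) and -ln(0.1); the second sample dominates because the model gives its true class only 0.1.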

A good example for sparse_categorical_crossentropy is the Fashion-MNIST dataset, whose labels are stored as integer class indices:

```python
import tensorflow as tf
from tensorflow import keras

# Load Fashion-MNIST (10 clothing classes).
fashion_mnist = keras.datasets.fashion_mnist
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()

# Labels are a 1-D array of integer class indices, not one-hot vectors.
print(y_train_full.shape)  # (60000,)
print(y_train_full.dtype)  # uint8

print(repr(y_train_full[:10]))
# array([9, 0, 0, 3, 0, 2, 7, 2, 5, 5], dtype=uint8)
```
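Because these labels are already integer class indices, they can be used as-is with sparse_categorical_crossentropy (in Keras, for instance by passing loss="sparse_categorical_crossentropy" to model.compile); categorical_crossentropy would first require one-hot encoding them, for example with keras.utils.to_categorical. As a pure-Python sanity check, assuming a hypothetical untrained model that predicts a uniform distribution over the 10 classes, the loss on the ten labels above is ln(10) regardless of which class each sample belongs to:

```python
import math

# First ten Fashion-MNIST training labels, as printed above.
labels = [9, 0, 0, 3, 0, 2, 7, 2, 5, 5]
num_classes = 10

# Hypothetical untrained model: the same uniform prediction for every sample.
uniform = [1.0 / num_classes] * num_classes

# Integer labels index the true class directly, as a sparse loss does.
losses = [-math.log(uniform[y]) for y in labels]
mean_loss = sum(losses) / len(losses)
print(mean_loss, math.log(num_classes))  # both about 2.303
```

A loss near 2.303 at the start of training on Fashion-MNIST is therefore roughly what chance-level predictions would give.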