TensorFlow has two separate implementations of mean squared error (MSE): a loss, tf.keras.losses.MeanSquaredError, and a metric, tf.keras.metrics.MeanSquaredError.

Without weights, both return the same value, but with sample weights (weighted MSE, or WMSE) they do not.

Below is how weighted MSE differs between the loss function and the metrics function in TensorFlow.

In both formulas, n is the number of samples in the batch, d is the number of outputs per sample, and the inner average over j is the ordinary per-sample MSE.

Loss WMSE

\[\text{WMSE}_{\text{loss}} = \frac{1}{n} \sum_{i=1}^{n} \text{weights}_{i} \cdot \frac{1}{d} \sum_{j=1}^{d} \left(\text{predicted}_{ij} - \text{actual}_{ij}\right)^{2}\]

Metrics WMSE

\[\text{WMSE}_{\text{metrics}} = \frac{\sum_{i=1}^{n} \text{weights}_{i} \cdot \frac{1}{d} \sum_{j=1}^{d} \left(\text{predicted}_{ij} - \text{actual}_{ij}\right)^{2}}{\sum_{i=1}^{n} \text{weights}_{i}}\]

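Of the two, the metrics version is a true weighted average, because it divides by the sum of the weights rather than by the number of samples. It is probably the one you want when applying weights, but choose whichever fits your use case.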
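The two reductions can be sketched in plain Python, without TensorFlow, to make the difference concrete. The function names here (per_sample_mse, loss_style_wmse, metric_style_wmse) are made up for illustration and are not part of any library; the numbers are the same ones used in the verification further down.

```python
def per_sample_mse(y_true, y_pred):
    """Mean squared error of each sample, averaged over its outputs."""
    return [
        sum((p - t) ** 2 for t, p in zip(t_row, p_row)) / len(t_row)
        for t_row, p_row in zip(y_true, y_pred)
    ]


def loss_style_wmse(y_true, y_pred, weights):
    """Weighted per-sample errors divided by the number of samples."""
    errors = per_sample_mse(y_true, y_pred)
    return sum(w * e for w, e in zip(weights, errors)) / len(errors)


def metric_style_wmse(y_true, y_pred, weights):
    """Weighted per-sample errors divided by the sum of the weights."""
    errors = per_sample_mse(y_true, y_pred)
    return sum(w * e for w, e in zip(weights, errors)) / sum(weights)


y_true = [[1.0, 9.0], [2.0, 5.0]]
y_pred = [[4.0, 8.0], [12.0, 3.0]]
weights = [1.2, 0.5]

print(per_sample_mse(y_true, y_pred))              # [5.0, 52.0]
print(loss_style_wmse(y_true, y_pred, weights))    # 16.0
print(metric_style_wmse(y_true, y_pred, weights))  # about 18.8235
```

The only difference between the two functions is the denominator: the number of samples versus the sum of the weights.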

You can verify this as follows.

Metrics Weighted MSE

The MeanSquaredError metric, tf.keras.metrics.MeanSquaredError(), accepts sample_weight in update_state(), which gives you a weighted MSE.

import tensorflow as tf

y_true = tf.constant([[1., 9.], [2., 5.]])
y_pred = tf.constant([[4., 8.], [12., 3.]])
sample_weight = tf.constant([1.2, 0.5])

mse = tf.keras.metrics.MeanSquaredError()
mse.update_state(y_true, y_pred)
print("mse w/o weights", mse.result().numpy())

mse = tf.keras.metrics.MeanSquaredError()
mse.update_state(y_true, y_pred, sample_weight=sample_weight)
print("mse w weights", mse.result().numpy())

Output:

mse w/o weights 28.5
mse w weights 18.823528

Let's verify the output above. First, the per-sample MSE and its unweighted mean:

\[\text{per-sample MSE} = \left[\frac{(4 - 1)^2 + (8 - 9)^2}{2}, \frac{(12 - 2)^2 + (3 - 5)^2}{2}\right] = [5, 52], \qquad \text{MSE} = \frac{5+52}{2} = 28.5\]

Now with weights

\[WMSE = \frac{(5 * 1.2 + 52 * 0.5)}{(1.2+0.5)} = 18.8235\]

Loss Weighted MSE

The MeanSquaredError loss, tf.keras.losses.MeanSquaredError(), also accepts sample_weight when you call it, which gives you a weighted MSE.

y_true = tf.constant([[1., 9.], [2., 5.]])
y_pred = tf.constant([[4., 8.], [12., 3.]])
sample_weight = tf.constant([1.2, 0.5])

mse = tf.keras.losses.MeanSquaredError()

print("mse w/o weights", mse(y_true, y_pred).numpy())
print("mse w weights", mse(y_true, y_pred, sample_weight=sample_weight).numpy())

Output:

mse w/o weights 28.5
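mse w weights 16.0

Let's verify the output above. The unweighted value is the same as before:

\[\text{per-sample MSE} = \left[\frac{(4 - 1)^2 + (8 - 9)^2}{2}, \frac{(12 - 2)^2 + (3 - 5)^2}{2}\right] = [5, 52], \qquad \text{MSE} = \frac{5+52}{2} = 28.5\]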

Now with weights. Note that the denominator is the number of samples (2), not the sum of the weights:

\[WMSE = \frac{(5 * 1.2 + 52 * 0.5)}{2} = 16\]

As you can see, the metrics WMSE score (18.8235) \(\neq\) the loss WMSE score (16).

Of course, you can use a custom loss function like the one below to get the metrics WMSE result from a loss.

class WeightedMeanSquaredError(tf.keras.losses.Loss):
  def __init__(self,
               reduction=tf.keras.losses.Reduction.AUTO,
               name='WeightedMeanSquaredError'):
    super().__init__(reduction=reduction, name=name)

  def __call__(self, y_true, y_pred, sample_weight):
    # Built-in reduction: 1/n * sum(weights * mse(y_p - y_t))
    actual_loss = super().__call__(y_true, y_pred, sample_weight)
    sample_weight = tf.cast(sample_weight, actual_loss.dtype)

    # n = number of sample weights in this batch
    # (tf.size works in graph mode, so no tf.py_function is needed)
    num_of_el = tf.cast(tf.size(sample_weight), actual_loss.dtype)

    # Rescale by n / sum(weights) so the result becomes
    # sum(weights * mse(y_p - y_t)) / sum(weights)
    wmse_loss = actual_loss * (num_of_el / tf.math.reduce_sum(sample_weight))
    return wmse_loss

  def call(self, y_true, y_pred):
    # Per-sample, unweighted MSE; the weighting happens in __call__
    return tf.keras.losses.mean_squared_error(y_true, y_pred)

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.06), 
              loss=WeightedMeanSquaredError(), 
              weighted_metrics=[tf.keras.metrics.MeanSquaredError()])
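As a quick sanity check of the idea behind the custom loss, the snippet below applies the same n / sum(weights) correction directly to the built-in loss value and compares it with the metric. It is only a sketch: it assumes the built-in loss and metric behave as in the outputs shown earlier (16.0 and 18.823528), which may differ across TensorFlow/Keras versions, and the variable names are illustrative.

```python
import tensorflow as tf

y_true = tf.constant([[1., 9.], [2., 5.]])
y_pred = tf.constant([[4., 8.], [12., 3.]])
sample_weight = tf.constant([1.2, 0.5])

# Built-in loss: sum(weights * per-sample MSE) / n
loss_value = tf.keras.losses.MeanSquaredError()(
    y_true, y_pred, sample_weight=sample_weight)

# Rescale by n / sum(weights), the same correction the custom loss applies
n = tf.cast(tf.size(sample_weight), loss_value.dtype)
rescaled = loss_value * n / tf.reduce_sum(sample_weight)

# Built-in metric: sum(weights * per-sample MSE) / sum(weights)
metric = tf.keras.metrics.MeanSquaredError()
metric.update_state(y_true, y_pred, sample_weight=sample_weight)

print("rescaled loss", rescaled.numpy())
print("metric", metric.result().numpy())
```

If the two printed values agree (up to floating-point rounding), the rescaling reproduces the metrics WMSE for this batch.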

Hope this helps