Skip to content

Commit

Permalink
naive_mean_absolute_error
Browse files Browse the repository at this point in the history
  • Loading branch information
carrotflakes committed Sep 20, 2023
1 parent 01ad9cb commit 5663d2a
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/losses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ pub fn naive_mean_squared_error(x0: ComputedNDA, x1: ComputedNDA) -> ComputedNDA
/ ComputedNDA::new(scalar(x.shape().iter().product::<usize>() as f32))
}

pub fn naive_mean_absolute_error(x0: ComputedNDA, x1: ComputedNDA) -> ComputedNDA {
let x = (x0 - x1).abs();
x.sum(Vec::from_iter(0..x.ndim()), false)
/ ComputedNDA::new(scalar(x.shape().iter().product::<usize>() as f32))
}

pub fn softmax_cross_entropy(t: Vec<usize>, x: &ComputedNDA) -> ComputedNDA {
let n = x.shape().iter().take(x.ndim() - 1).product();
let log_z = log_sum_exp(&*x);
Expand Down

0 comments on commit 5663d2a

Please sign in to comment.