diff --git a/src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs b/src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs index 432402677f..241730c6b2 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Transform/Av1SinusConstants.cs @@ -1,8 +1,6 @@ // Copyright (c) Six Labors. // Licensed under the Six Labors Split License. -using System; - namespace SixLabors.ImageSharp.Formats.Heif.Av1.Transform; internal static class Av1SinusConstants diff --git a/src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs b/src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs index 534edd3f1d..4ebed44c6b 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Transform/Av1Transform2dFlipConfiguration.cs @@ -142,8 +142,8 @@ public Av1Transform2dFlipConfiguration(Av1TransformType transformType, Av1Transf this.TransformFunctionTypeRow = TransformFunctionTypeMap[txw_idx][(int)tx_type_1d_row]; this.StageNumberColumn = StageNumberList[(int)this.TransformFunctionTypeColumn]; this.StageNumberRow = StageNumberList[(int)this.TransformFunctionTypeRow]; - this.StageRangeColumn = new int[12]; - this.StageRangeRow = new int[12]; + this.StageRangeColumn = new byte[12]; + this.StageRangeRow = new byte[12]; this.NonScaleRange(); } @@ -169,9 +169,9 @@ public Av1Transform2dFlipConfiguration(Av1TransformType transformType, Av1Transf public Span Shift => this.shift; - public int[] StageRangeColumn { get; } + public byte[] StageRangeColumn { get; } - public int[] StageRangeRow { get; } + public byte[] StageRangeRow { get; } /// /// SVT: svt_av1_gen_fwd_stage_range @@ -184,13 +184,13 @@ public void GenerateStageRange(int bitDepth) // i < MAX_TXFM_STAGE_NUM will mute above array bounds warning for (int i = 0; i < this.StageNumberColumn && i < MaxStageNumber; ++i) { - this.StageRangeColumn[i] = this.StageRangeColumn[i] + shift[0] + bitDepth + 1; + this.StageRangeColumn[i] = (byte)(this.StageRangeColumn[i] + shift[0] + bitDepth + 1); } // i < MAX_TXFM_STAGE_NUM will mute above array bounds warning for (int i = 0; i < this.StageNumberRow && i < MaxStageNumber; ++i) { - this.StageRangeRow[i] = this.StageRangeRow[i] + shift[0] + shift[1] + bitDepth + 1; + this.StageRangeRow[i] = (byte)(this.StageRangeRow[i] + shift[0] + shift[1] + bitDepth + 1); } } @@ -296,7 +296,7 @@ private void NonScaleRange() int stage_num_col = this.StageNumberColumn; for (int i = 0; i < stage_num_col; ++i) { - this.StageRangeColumn[i] = (range_mult2_col[i] + 1) >> 1; + this.StageRangeColumn[i] = (byte)((range_mult2_col[i] + 1) >> 1); } } @@ -306,7 +306,7 @@ private void NonScaleRange() Span range_mult2_row = RangeMulti2List[(int)this.TransformFunctionTypeRow]; for (int i = 0; i < stage_num_row; ++i) { - this.StageRangeRow[i] = (range_mult2_col[this.StageNumberColumn - 1] + range_mult2_row[i] + 1) >> 1; + this.StageRangeRow[i] = (byte)((range_mult2_col[this.StageNumberColumn - 1] + range_mult2_row[i] + 1) >> 1); } } } diff --git a/src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs b/src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs index db0134c36f..0c46e8f480 100644 --- a/src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs +++ b/src/ImageSharp/Formats/Heif/Av1/Transform/Forward/Av1Dct4ForwardTransformer.cs @@ -10,7 +10,39 @@ namespace SixLabors.ImageSharp.Formats.Heif.Av1.Transform.Forward; internal class Av1Dct4ForwardTransformer : IAv1ForwardTransformer { public void Transform(ref int input, ref int output, int cosBit, Span stageRange) - => throw new NotImplementedException(); + { + Span cospi = Av1SinusConstants.CosinusPi(cosBit); + ref int bf0 = ref output; + ref int bf1 = ref output; + Span stepSpan = new int[4]; + ref int step0 = ref stepSpan[0]; + ref int step1 = ref Unsafe.Add(ref step0, 1); + ref int step2 = ref Unsafe.Add(ref step0, 2); + ref int step3 = ref Unsafe.Add(ref step0, 3); + ref int output1 = ref Unsafe.Add(ref output, 1); + ref int output2 = ref Unsafe.Add(ref output, 2); + ref int output3 = ref Unsafe.Add(ref output, 3); + + // stage 0; + + // stage 1; + output = input + Unsafe.Add(ref input, 3); + output1 = Unsafe.Add(ref input, 1) + Unsafe.Add(ref input, 2); + output2 = -Unsafe.Add(ref input, 2) + Unsafe.Add(ref input, 1); + output3 = -Unsafe.Add(ref input, 3) + Unsafe.Add(ref input, 0); + + // stage 2 + step0 = HalfBtf(cospi[32], output, cospi[32], output1, cosBit); + step1 = HalfBtf(-cospi[32], output1, cospi[32], output, cosBit); + step2 = HalfBtf(cospi[48], output2, cospi[16], output3, cosBit); + step3 = HalfBtf(cospi[48], output3, -cospi[16], output2, cosBit); + + // stage 3 + output = step0; + output1 = step2; + output2 = step1; + output3 = step3; + } public void TransformAvx2(ref Vector256 input, ref Vector256 output, int cosBit, int columnNumber) => throw new NotImplementedException("Too small block for Vector implementation, use TransformSse() method instead."); @@ -20,7 +52,8 @@ public void TransformAvx2(ref Vector256 input, ref Vector256 output, i /// public static void TransformSse(ref Vector128 input, ref Vector128 output, byte cosBit, int columnNumber) { - /* +#pragma warning disable CA1857 // A constant is expected for the parameter + // We only use stage-2 bit; // shift[0] is used in load_buffer_4x4() // shift[1] is used in txfm_func_col() @@ -35,51 +68,71 @@ public static void TransformSse(ref Vector128 input, ref Vector128 out Vector128 v0, v1, v2, v3; int endidx = 3 * columnNumber; - s0 = Sse41.Add(input, Unsafe.Add(ref input, endidx)); - s3 = Sse41.Subtract(input, Unsafe.Add(ref input, endidx)); + s0 = Sse2.Add(input, Unsafe.Add(ref input, endidx)); + s3 = Sse2.Subtract(input, Unsafe.Add(ref input, endidx)); endidx -= columnNumber; - s1 = Sse41.Add(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx)); - s2 = Sse41.Subtract(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx)); + s1 = Sse2.Add(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx)); + s2 = Sse2.Subtract(Unsafe.Add(ref input, columnNumber), Unsafe.Add(ref input, endidx)); // btf_32_sse4_1_type0(cospi32, cospi32, s[01], u[02], bit); u0 = Sse41.MultiplyLow(s0, cospi32); u1 = Sse41.MultiplyLow(s1, cospi32); - u2 = Sse41.Add(u0, u1); - v0 = Sse41.Subtract(u0, u1); + u2 = Sse2.Add(u0, u1); + v0 = Sse2.Subtract(u0, u1); - u3 = Sse41.Add(u2, rnding); - v1 = Sse41.Add(v0, rnding); + u3 = Sse2.Add(u2, rnding); + v1 = Sse2.Add(v0, rnding); - u0 = Sse41.ShiftRightArithmetic(u3, cosBit); - u2 = Sse41.ShiftRightArithmetic(v1, cosBit); + u0 = Sse2.ShiftRightArithmetic(u3, cosBit); + u2 = Sse2.ShiftRightArithmetic(v1, cosBit); // btf_32_sse4_1_type1(cospi48, cospi16, s[23], u[13], bit); v0 = Sse41.MultiplyLow(s2, cospi48); v1 = Sse41.MultiplyLow(s3, cospi16); - v2 = Sse41.Add(v0, v1); + v2 = Sse2.Add(v0, v1); - v3 = Sse41.Add(v2, rnding); - u1 = Sse41.ShiftRightArithmetic(v3, cosBit); + v3 = Sse2.Add(v2, rnding); + u1 = Sse2.ShiftRightArithmetic(v3, cosBit); v0 = Sse41.MultiplyLow(s2, cospi16); v1 = Sse41.MultiplyLow(s3, cospi48); - v2 = Sse41.Subtract(v1, v0); + v2 = Sse2.Subtract(v1, v0); - v3 = Sse41.Add(v2, rnding); - u3 = Sse41.ShiftRightArithmetic(v3, cosBit); + v3 = Sse2.Add(v2, rnding); + u3 = Sse2.ShiftRightArithmetic(v3, cosBit); // Note: shift[1] and shift[2] are zeros // Transpose 4x4 32-bit - v0 = Sse41.UnpackLow(u0, u1); - v1 = Sse41.UnpackHigh(u0, u1); - v2 = Sse41.UnpackLow(u2, u3); - v3 = Sse41.UnpackHigh(u2, u3); - - output = Sse41.UnpackLow(v0.AsInt64(), v2.AsInt64()).AsInt32(); - Unsafe.Add(ref output, 1) = Sse41.UnpackHigh(v0.AsInt64(), v2.AsInt64()).AsInt32(); - Unsafe.Add(ref output, 2) = Sse41.UnpackLow(v1.AsInt64(), v3.AsInt64()).AsInt32(); - Unsafe.Add(ref output, 3) = Sse41.UnpackHigh(v1.AsInt64(), v3.AsInt64()).AsInt32(); - */ + v0 = Sse2.UnpackLow(u0, u1); + v1 = Sse2.UnpackHigh(u0, u1); + v2 = Sse2.UnpackLow(u2, u3); + v3 = Sse2.UnpackHigh(u2, u3); + + output = Sse2.UnpackLow(v0.AsInt64(), v2.AsInt64()).AsInt32(); + Unsafe.Add(ref output, 1) = Sse2.UnpackHigh(v0.AsInt64(), v2.AsInt64()).AsInt32(); + Unsafe.Add(ref output, 2) = Sse2.UnpackLow(v1.AsInt64(), v3.AsInt64()).AsInt32(); + Unsafe.Add(ref output, 3) = Sse2.UnpackHigh(v1.AsInt64(), v3.AsInt64()).AsInt32(); +#pragma warning restore CA1857 // A constant is expected for the parameter + } + + private static int HalfBtf(int w0, int in0, int w1, int in1, int bit) + { + long result64 = (long)(w0 * in0) + (w1 * in1); + long intermediate = result64 + (1L << (bit - 1)); + + // NOTE(david.barker): The value 'result_64' may not necessarily fit + // into 32 bits. However, the result of this function is nominally + // ROUND_POWER_OF_TWO_64(result_64, bit) + // and that is required to fit into stage_range[stage] many bits + // (checked by range_check_buf()). + // + // Here we've unpacked that rounding operation, and it can be shown + // that the value of 'intermediate' here *does* fit into 32 bits + // for any conformant bitstream. + // The upshot is that, if you do all this calculation using + // wrapping 32-bit arithmetic instead of (non-wrapping) 64-bit arithmetic, + // then you'll still get the correct result. + return (int)(intermediate >> bit); } } diff --git a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs index fb286cf70e..b92599c86e 100644 --- a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs +++ b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ForwardTransformTests.cs @@ -3,6 +3,7 @@ using System.Runtime.CompilerServices; using SixLabors.ImageSharp.Formats.Heif.Av1.Transform; +using SixLabors.ImageSharp.Formats.Heif.Av1.Transform.Forward; namespace SixLabors.ImageSharp.Tests.Formats.Heif.Av1; @@ -35,22 +36,49 @@ public class Av1ForwardTransformTests 36, // 64x16 transform ]; - private readonly short[] inputOfTest; - private readonly int[] outputOfTest; - private readonly double[] inputReference; - private readonly double[] outputReference; - - public Av1ForwardTransformTests() + [Theory] + [MemberData(nameof(GetSizes))] + public void AccuracyDct1dTest(int txSize) { - this.inputOfTest = new short[64 * 64]; - this.outputOfTest = new int[64 * 64]; - this.inputReference = new double[64 * 64]; - this.outputReference = new double[64 * 64]; + Random rnd = new(0); + const int testBlockCount = 1; // Originally set to: 1000 + Av1TransformSize transformSize = (Av1TransformSize)txSize; + Av1Transform2dFlipConfiguration config = new(Av1TransformType.DctDct, transformSize); + int width = config.TransformSize.GetWidth(); + + int[] inputOfTest = new int[width]; + double[] inputReference = new double[width]; + int[] outputOfTest = new int[width]; + double[] outputReference = new double[width]; + for (int ti = 0; ti < testBlockCount; ++ti) + { + // prepare random test data + for (int ni = 0; ni < width; ++ni) + { + inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1); + inputReference[ni] = inputOfTest[ni]; + outputReference[ni] = 0; + outputOfTest[ni] = 255; + } + + // calculate in forward transform functions + new Av1Dct4ForwardTransformer().Transform( + ref inputOfTest[0], + ref outputOfTest[0], + config.CosBitColumn, + config.StageRangeColumn); + + // calculate in reference forward transform functions + Av1ReferenceTransform.ReferenceDct1d(inputReference, outputReference, width); + + // Assert + Assert.True(CompareWithError(outputReference, outputOfTest, 1)); + } } - // [Theory] - // [MemberData(nameof(GetCombinations))] - public void Accuracy2dTest(int txSize, int txType, int maxAllowedError) + [Theory] + [MemberData(nameof(GetCombinations))] + public void Accuracy2dTest(int txSize, int txType, int maxAllowedError = 0) { const int bitDepth = 8; Random rnd = new(0); @@ -63,53 +91,49 @@ public void Accuracy2dTest(int txSize, int txType, int maxAllowedError) int blockSize = width * height; double scaleFactor = Av1ReferenceTransform.GetScaleFactor(config, width, height); + short[] inputOfTest = new short[blockSize]; + double[] inputReference = new double[blockSize]; + int[] outputOfTest = new int[blockSize]; + double[] outputReference = new double[blockSize]; for (int ti = 0; ti < testBlockCount; ++ti) { // prepare random test data for (int ni = 0; ni < blockSize; ++ni) { - this.inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1); - this.inputReference[ni] = this.inputOfTest[ni]; - this.outputReference[ni] = 0; - this.outputOfTest[ni] = 255; + inputOfTest[ni] = (short)rnd.Next((1 << 10) - 1); + inputReference[ni] = inputOfTest[ni]; + outputReference[ni] = 0; + outputOfTest[ni] = 255; } // calculate in forward transform functions Av1ForwardTransformer.Transform2d( - this.inputOfTest, - this.outputOfTest, + inputOfTest, + outputOfTest, (uint)transformSize.GetWidth(), transformType, transformSize, bitDepth); // calculate in reference forward transform functions - Av1ReferenceTransform.ReferenceTransformFunction2d(this.inputReference, this.outputReference, transformType, transformSize, scaleFactor); + Av1ReferenceTransform.ReferenceTransformFunction2d(inputReference, outputReference, transformType, transformSize, scaleFactor); // repack the coefficents for some tx_size - this.RepackCoefficients(width, height); + RepackCoefficients(outputOfTest, outputReference, width, height); - // compare for the result is in accuracy - double maximumErrorInTest = 0; - for (int ni = 0; ni < blockSize; ++ni) - { - maximumErrorInTest = Math.Max(maximumErrorInTest, Math.Abs(this.outputOfTest[ni] - Math.Round(this.outputReference[ni]))); - } - - maximumErrorInTest /= scaleFactor; - Assert.True(maxAllowedError >= maximumErrorInTest, $"Forward transform 2d test with transform type: {transformType}, transform size: {transformSize} and loop: {ti}"); + Assert.True(CompareWithError(outputReference, outputOfTest, maxAllowedError * scaleFactor), $"Forward transform 2d test with transform type: {transformType}, transform size: {transformSize} and loop: {ti}"); } } // The max txb_width or txb_height is 32, as specified in spec 7.12.3. // Clear the high frequency coefficents and repack it in linear layout. - private void RepackCoefficients(int tx_width, int tx_height) + private static void RepackCoefficients(Span outputOfTest, Span outputReference, int tx_width, int tx_height) { for (int i = 0; i < 2; ++i) { uint e_size = i == 0 ? (uint)sizeof(int) : sizeof(double); - ref byte output = ref (i == 0) ? ref Unsafe.As(ref this.outputOfTest[0]) - : ref Unsafe.As(ref this.outputReference[0]); + ref byte output = ref (i == 0) ? ref Unsafe.As(ref outputOfTest[0]) + : ref Unsafe.As(ref outputReference[0]); if (tx_width == 64 && tx_height == 64) { @@ -188,13 +212,34 @@ ref Unsafe.Add(ref output, row * 64 * e_size), } } + private static bool CompareWithError(Span expected, Span actual, double allowedError) + { + // compare for the result is witghin accuracy + double maximumErrorInTest = 0; + for (int ni = 0; ni < expected.Length; ++ni) + { + maximumErrorInTest = Math.Max(maximumErrorInTest, Math.Abs(actual[ni] - Math.Round(expected[ni]))); + } + + return maximumErrorInTest <= allowedError; + } + + public static TheoryData GetSizes() + { + TheoryData sizes = []; + + // For now test only 4x4. + sizes.Add(0); + return sizes; + } + public static TheoryData GetCombinations() { TheoryData combinations = []; for (int s = 0; s < (int)Av1TransformSize.AllSizes; s++) { double maxError = MaximumAllowedError[s]; - for (int t = 0; t < (int)Av1TransformType.AllTransformTypes; ++t) + for (int t = 0; t < (int)Av1TransformType.AllTransformTypes; t++) { Av1TransformType transformType = (Av1TransformType)t; Av1TransformSize transformSize = (Av1TransformSize)s; @@ -203,7 +248,13 @@ public static TheoryData GetCombinations() { combinations.Add(s, t, (int)maxError); } + + // For now only DCT. + break; } + + // For now only 4x4. + break; } return combinations; diff --git a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs index f490ead2e6..5cb91ca44d 100644 --- a/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs +++ b/tests/ImageSharp.Tests/Formats/Heif/Av1/Av1ReferenceTransform.cs @@ -174,7 +174,7 @@ private static void ReferenceIdentity1d(Span input, Span output, } } - private static void ReferenceDct1d(Span input, Span output, int size) + internal static void ReferenceDct1d(Span input, Span output, int size) { const double kInvSqrt2 = 0.707106781186547524400844362104f; for (int k = 0; k < size; ++k)