Skip to content

Commit

Permalink
Corrected EncryptingChannel bytes written for partial buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
exceptionfactory committed Jul 28, 2023
1 parent 364507b commit eafef38
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ public int write(final ByteBuffer sourceBuffer) throws IOException {
Objects.requireNonNull(sourceBuffer, "Source Buffer required");

final int sourceBufferLimit = sourceBuffer.limit();
final int sourceBufferStartPosition = sourceBuffer.position();

while (sourceBuffer.hasRemaining()) {
if (inputBuffer.remaining() == 0) {
Expand All @@ -95,7 +96,7 @@ public int write(final ByteBuffer sourceBuffer) throws IOException {
sourceBuffer.limit(sourceBufferLimit);
}

return sourceBufferLimit;
return sourceBuffer.position() - sourceBufferStartPosition;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,15 @@ class EncryptingChannelTest {
@Test
void testIsOpen() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);

final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

assertTrue(encryptingChannel.isOpen());
}

@Test
void testClose() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);

final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
when(payloadKeyWriter.writeFileHeader(any(), any())).thenReturn(PAYLOAD_KEY);
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

assertTrue(encryptingChannel.isOpen());
encryptingChannel.close();
Expand All @@ -113,22 +106,20 @@ void testWritePayloadException() throws GeneralSecurityException, IOException {
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);

final ByteBuffer sourceBuffer = ByteBuffer.wrap(SOURCE);
encryptingChannel.write(sourceBuffer);
final int written = encryptingChannel.write(sourceBuffer);
assertEquals(sourceBuffer.capacity(), written);

assertThrows(PayloadException.class, encryptingChannel::close);
}

@Test
void testWrite() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);

final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
when(payloadKeyWriter.writeFileHeader(any(), any())).thenReturn(PAYLOAD_KEY);
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

final ByteBuffer sourceBuffer = ByteBuffer.wrap(SOURCE);
encryptingChannel.write(sourceBuffer);
final int written = encryptingChannel.write(sourceBuffer);
assertEquals(sourceBuffer.capacity(), written);

encryptingChannel.close();

Expand All @@ -139,15 +130,12 @@ void testWrite() throws GeneralSecurityException, IOException {
@Test
void testWriteSingleChunk() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);

final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
when(payloadKeyWriter.writeFileHeader(any(), any())).thenReturn(PAYLOAD_KEY);
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

final byte[] chunk = new byte[ChunkSize.PLAIN.getSize()];
final ByteBuffer sourceBuffer = ByteBuffer.wrap(chunk);
encryptingChannel.write(sourceBuffer);
final int written = encryptingChannel.write(sourceBuffer);
assertEquals(sourceBuffer.capacity(), written);

encryptingChannel.close();

Expand All @@ -158,19 +146,42 @@ void testWriteSingleChunk() throws GeneralSecurityException, IOException {
@Test
void testWriteMultipleChunks() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);

final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
when(payloadKeyWriter.writeFileHeader(any(), any())).thenReturn(PAYLOAD_KEY);
final EncryptingChannel encryptingChannel = new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

final byte[] chunks = new byte[TWO_CHUNKS_LENGTH];
final ByteBuffer sourceBuffer = ByteBuffer.wrap(chunks);
encryptingChannel.write(sourceBuffer);
final int written = encryptingChannel.write(sourceBuffer);
assertEquals(sourceBuffer.capacity(), written);

encryptingChannel.close();

final byte[] bytes = outputStream.toByteArray();
assertEquals(TWO_CHUNKS_ENCRYPTED_LENGTH, bytes.length);
}

@Test
void testWriteMultipleChunksBufferSingleChunk() throws GeneralSecurityException, IOException {
final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
final EncryptingChannel encryptingChannel = getEncryptingChannel(outputStream);

final byte[] chunks = new byte[TWO_CHUNKS_LENGTH];
final ByteBuffer sourceBuffer = ByteBuffer.wrap(chunks);
sourceBuffer.position(ChunkSize.PLAIN.getSize());
final int sourceBufferRemaining = sourceBuffer.remaining();

final int written = encryptingChannel.write(sourceBuffer);
assertEquals(sourceBufferRemaining, written);

encryptingChannel.close();

final byte[] bytes = outputStream.toByteArray();
assertEquals(ChunkSize.ENCRYPTED.getSize(), bytes.length);
}

private EncryptingChannel getEncryptingChannel(final ByteArrayOutputStream outputStream) throws GeneralSecurityException, IOException {
final WritableByteChannel outputChannel = Channels.newChannel(outputStream);
final List<RecipientStanzaWriter> recipientStanzaWriters = Collections.singletonList(recipientStanzaWriter);
when(payloadKeyWriter.writeFileHeader(any(), any())).thenReturn(PAYLOAD_KEY);
return new EncryptingChannel(outputChannel, recipientStanzaWriters, payloadKeyWriter);
}
}

0 comments on commit eafef38

Please sign in to comment.