StreamingMultipartPartition.java

package com.renomad.minum.web;


import com.renomad.minum.utils.RingBuffer;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.IntStream;

/**
 * This class represents a single partition in a multipart/form
 * Request body, when read as an InputStream.  This enables the
 * developer to pull data incrementally, rather than reading it
 * all into memory at once.
 */
public class StreamingMultipartPartition extends InputStream {

    private final Headers headers;
    private final InputStream inputStream;
    private final ContentDisposition contentDisposition;
    private final int contentLength;
    /**
     * After we hit the boundary, we will set this flag to true, and all
     * subsequent reads will return -1.
     */
    private boolean isFinished = false;

    /**
     * This buffer follows along with what we are reading, so we can
     * easily compare against our boundary value.  There are four extra
     * bytes included, since multipart splits the content by two
     * dashes, followed by the boundary value, and then two dashes afterwards
     * on the last boundary.
     * <pre>
     * That is,
     * for a typical boundary:
     *
     *   --boundary_value
     *
     * and for the last boundary:
     *
     *   --boundary_value--
     *</pre>
     */
    private final RingBuffer<Byte> recentBytesBuffer;
    private final CountBytesRead countBytesRead;
    private final List<Byte> boundaryValueList;
    private boolean hasFilledBuffer;

    StreamingMultipartPartition(Headers headers,
                                       InputStream inputStream,
                                       ContentDisposition contentDisposition,
                                       String boundaryValue,
                                       CountBytesRead countBytesRead,
                                       int contentLength) {

        this.headers = headers;
        this.inputStream = inputStream;
        this.contentDisposition = contentDisposition;
        this.contentLength = contentLength;
        String boundaryValue1 = "\r\n--" + boundaryValue;
        byte[] bytes = boundaryValue1.getBytes(StandardCharsets.US_ASCII);
        boundaryValueList = IntStream.range(0, bytes.length).mapToObj(i -> bytes[i]).toList();

        /*
         * To explain the numbers here: we add one at the beginning to represent
         * the single character at the far left that is what we will actually return.
         * We have to fill the cache before we start sending anything.  The number
         * at the end represents the extra characters of the boundary - dashes,
         * carriage return, newline.
         */
        recentBytesBuffer = new RingBuffer<>(boundaryValue1.length(), Byte.class);
        this.countBytesRead = countBytesRead;
    }

    public Headers getHeaders() {
        return headers;
    }

    public ContentDisposition getContentDisposition() {
        return contentDisposition;
    }


    /**
     * Reads from the inputstream using a buffer for checking whether we've
     * hit the end of a multpart partition.
     * @return -1 if we're at the end of a partition
     * @throws IOException if the inputstream is closed unexpectedly while reading.
     */
    @Override
    public int read() throws IOException {
        if (isFinished) {
            return -1;
        }
        if (!hasFilledBuffer) {
            fillBuffer();
            boolean atTheEnd = recentBytesBuffer.containsAt(boundaryValueList, 0);
            if (atTheEnd) {
                // don't really do anything with this, it's just to collect the
                // last characters to have a clean finish.
                byte[] unused = inputStream.readNBytes(2);
                isFinished = true;
                return -1;
            }
        } else {
            int result = inputStream.read();
            countBytesRead.increment();
            if (countBytesRead.getCount() >= contentLength) {
                isFinished = true;
                return -1;
            }
            if (result == -1) {
                throw new IOException("Error: The inputstream has closed unexpectedly while reading");
            }
            byte byteValue = (byte) result;
            boolean isAtEndOfPartition = updateRecentBytesBufferAndCheck(byteValue);
            if (isAtEndOfPartition) {
                // don't really do anything with this, it's just to collect the
                // last characters to have a clean finish.
                byte[] unused = inputStream.readNBytes(2);
                isFinished = true;
                return -1;
            }

        }
        return ((int)recentBytesBuffer.atNextIndex()) & 0xff;
    }

    private void fillBuffer() throws IOException {
        for (int i = 0; i < recentBytesBuffer.getLimit(); i++) {
            int result = inputStream.read();
            countBytesRead.increment();
            if (result == -1) {
                throw new IOException("Error: The inputstream has closed unexpectedly while reading");
            }
            byte byteValue = (byte) result;
            updateRecentBytesBufferAndCheck(byteValue);
        }
        hasFilledBuffer = true;
    }

    @Override
    public byte[] readAllBytes()  {
        var baos = new ByteArrayOutputStream();
        while (true) {
            int result = 0;
            try {
                result = read();
            } catch (IOException e) {
                throw new WebServerException(e);
            }
            if (result == -1) {
                return baos.toByteArray();
            }
            baos.write((byte)result);
        }
    }

    /**
     * Updates the buffer with the last characters read, and returns
     * true if we have encountered the end of this partition.
     */
    private boolean updateRecentBytesBufferAndCheck(byte newByte) {
        recentBytesBuffer.add(newByte);
        return recentBytesBuffer.containsAt(boundaryValueList, 0);
    }

    /**
     * By "close", we will read from the {@link InputStream} until we have finished the body,
     * so that our InputStream has been read until the start of the next partition.
     */
    @Override
    public void close() throws IOException {
        while (true) {
            int result = read();
            if (result == -1) {
                return;
            }
        }
    }

}