Headers.java
package com.renomad.minum.web;
import com.renomad.minum.security.ForbiddenUseException;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static com.renomad.minum.utils.Invariants.mustBeTrue;
/**
* Details extracted from the headers. For example,
* is this a keep-alive connection? what is the content-length,
* and so on.
* Here is some detail from <a href="https://en.wikipedia.org/wiki/List_of_HTTP_header_fields">Wikipedia</a> on the subject:
* <p>
* HTTP header fields are a list of strings sent and received by both
* the client program and server on every HTTP request and response. These
* headers are usually invisible to the end-user and are only processed or
* logged by the server and client applications. They define how information
* sent/received through the connection are encoded (as in Content-Encoding),
* the session verification and identification of the client (as in browser
* cookies, IP address, user-agent) or their anonymity thereof (VPN or
* proxy masking, user-agent spoofing), how the server should handle data
* (as in Do-Not-Track), the age (the time it has resided in a shared cache)
* of the document being downloaded, amongst others.
* </p>
*/
public final class Headers{
public static final Headers EMPTY = new Headers(List.of());
private static final int MAX_HEADERS_COUNT = 70;
/**
* Each line of the headers is read into this data structure
*/
private final List<String> headerStrings;
private final Map<String, List<String>> headersMap;
public Headers(
List<String> headerStrings
) {
this.headerStrings = new ArrayList<>(headerStrings);
this.headersMap = Collections.unmodifiableMap(extractHeadersToMap(headerStrings));
}
public List<String> getHeaderStrings() {
return new ArrayList<>(headerStrings);
}
/**
* Used for extracting the length of the body, in POSTs and
* responses from servers
*/
private static final Pattern contentLengthRegex = Pattern.compile("^[cC]ontent-[lL]ength: (.*)$");
/**
* Obtain any desired header by looking it up in this map. All keys
* are made lowercase.
*/
static Map<String, List<String>> extractHeadersToMap(List<String> headerStrings) {
var result = new HashMap<String, List<String>>();
for (var h : headerStrings) {
var indexOfFirstColon = h.indexOf(":");
// if the header is malformed, just move on
if (indexOfFirstColon <= 0) continue;
String key = h.substring(0, indexOfFirstColon).toLowerCase(Locale.ROOT);
String value = h.substring(indexOfFirstColon+1).trim();
if (result.containsKey(key)) {
var currentValue = result.get(key);
List<String> newList = new ArrayList<>();
newList.add(value);
newList.addAll(currentValue);
result.put(key, newList);
} else {
result.put(key, List.of(value));
}
}
return result;
}
/**
* Gets the one content-type header, or returns an empty string
*/
public String contentType() {
// find the header that starts with content-type
List<String> cts = headerStrings.stream().filter(x -> x.toLowerCase(Locale.ROOT).startsWith("content-type")).toList();
if (cts.size() > 1) {
throw new WebServerException("The number of content-type headers must be exactly zero or one. Received: " + cts);
}
if (!cts.isEmpty()) {
return cts.getFirst();
}
// if we don't find a content-type header, or if we don't find one we can handle, return an empty string.
return "";
}
/**
* Given the list of headers, find the one with the length of the
* body of the POST and return that value as an integer. If
* we do not find a content length, return -1.
*/
public int contentLength() {
List<String> cl = headerStrings.stream().filter(x -> x.toLowerCase(Locale.ROOT).startsWith("content-length")).toList();
if (cl.size() > 1) {
throw new WebServerException("The number of content-length headers must be exactly zero or one. Received: " + cl);
}
int contentLength = -1;
if (!cl.isEmpty()) {
Matcher clMatcher = contentLengthRegex.matcher(cl.getFirst());
mustBeTrue(clMatcher.matches(), "The content length header value must match the contentLengthRegex");
contentLength = Integer.parseInt(clMatcher.group(1));
mustBeTrue(contentLength >= 0, "Content-length cannot be negative");
}
return contentLength;
}
/**
* Indicates whether the headers in this request
* have a Connection: Keep-Alive
*/
public boolean hasKeepAlive() {
List<String> connectionHeader = headersMap.get("connection");
if (connectionHeader == null) return false;
return connectionHeader.stream().anyMatch(x -> x.toLowerCase(Locale.ROOT).contains("keep-alive"));
}
/**
* Indicates whether the headers in this request
* have a Connection: close
*/
public boolean hasConnectionClose() {
List<String> connectionHeader = headersMap.get("connection");
if (connectionHeader == null) return false;
return connectionHeader.stream().anyMatch(x -> x.toLowerCase(Locale.ROOT).contains("close"));
}
/**
* Loop through the lines of header in the HTTP message
*/
static List<String> getAllHeaders(InputStream is, IInputStreamUtils inputStreamUtils) {
List<String> headers = new ArrayList<>();
for (int i = 0;; i++) {
if (i >=MAX_HEADERS_COUNT) {
throw new ForbiddenUseException("User tried sending too many headers. max: " + MAX_HEADERS_COUNT);
}
String value;
try {
value = inputStreamUtils.readLine(is);
} catch (IOException e) {
throw new WebServerException(e);
}
if (value != null && value.isBlank()) {
break;
} else if (value == null) {
return headers;
} else {
headers.add(value);
}
}
return headers;
}
/**
* Allows a user to obtain any header value by its key, case-insensitively
* @return a {@link List} of string values, or null
* if no header was found.
*/
public List<String> valueByKey(String key) {
return headersMap.get(key.toLowerCase(Locale.ROOT));
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Headers headers = (Headers) o;
return Objects.equals(headerStrings, headers.headerStrings) && Objects.equals(headersMap, headers.headersMap);
}
@Override
public int hashCode() {
return Objects.hash(headerStrings, headersMap);
}
@Override
public String toString() {
return "Headers{" +
"headerStrings=" + headerStrings +
'}';
}
}