diff --git a/README.md b/README.md index 84347d0..f54fa77 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ An asynchronous HTTP client library for ESP32 microcontrollers, built on top of - ✅ **Multiple simultaneous requests** - Handle multiple requests concurrently - ✅ **Chunked transfer decoding** - Validates framing and exposes parsed trailers - ✅ **Optional redirect following** - Follow 301/302/303 (converted to GET) and 307/308 (method preserved) -- ✅ **Header & body guards** - Limit buffered response headers/body to avoid runaway responses +- ✅ **Header & body guards** - Limits buffered headers (~2.8 KiB) and body (8 KiB) by default to avoid runaway responses - ✅ **Zero-copy streaming** - Combine `req->setNoStoreBody(true)` with `client.onBodyChunk(...)` to stream large payloads without heap spikes > ⚠ Limitations: provide trust material for HTTPS (CA, fingerprint or insecure flag) and remember the full body is buffered in memory unless you opt into zero-copy streaming via `setNoStoreBody(true)`. @@ -150,10 +150,10 @@ void setDefaultConnectTimeout(uint32_t ms); // Follow HTTP redirects (max hops clamps to >=1). Disabled by default. void setFollowRedirects(bool enable, uint8_t maxHops = 3); -// Abort if response headers exceed this many bytes (0 = unlimited) +// Abort if response headers exceed this many bytes (default ~2.8 KiB, 0 = unlimited) void setMaxHeaderBytes(size_t maxBytes); -// Soft limit for buffered response bodies (bytes, 0 = unlimited) +// Soft limit for buffered response bodies (default 8192 bytes, 0 = unlimited) void setMaxBodySize(size_t maxBytes); // Limit simultaneous active requests (0 = unlimited, others queued) @@ -162,12 +162,20 @@ void setMaxParallel(uint16_t maxParallel); // Set User-Agent string void setUserAgent(const char* userAgent); +// Keep-alive connection pooling (idle timeout in ms, clamped to >= 1000) +void setKeepAlive(bool enable, uint16_t idleMs = 5000); + // Cookie jar helpers void clearCookies(); void setCookie(const char* name, const char* value, const char* path = "/", const char* domain = nullptr, bool secure = false); ``` +Cookies are captured automatically from `Set-Cookie` responses and replayed on matching hosts/paths; call +`clearCookies()` to wipe the jar or `setCookie()` to pre-seed entries manually. Keep-alive pooling is off by default; +enable it with `setKeepAlive(true, idleMs)` to reuse TCP/TLS connections for the same host/port (respecting server +`Connection: close` requests). + #### Callback Types ```cpp @@ -212,7 +220,7 @@ client.get("http://example.com/chunked", [](AsyncHttpResponse* response) { ```cpp // Create custom request -AsyncHttpRequest request(HTTP_POST, "http://example.com/api"); +AsyncHttpRequest request(HTTP_METHOD_POST, "http://example.com/api"); // Set headers request.setHeader("Content-Type", "application/json"); @@ -229,6 +237,11 @@ request.setTimeout(10000); client.request(&request, onSuccess, onError); ``` +HTTP method enums are now prefixed (`HTTP_METHOD_GET`, `HTTP_METHOD_POST`, etc.) to avoid collisions with +`ESPAsyncWebServer`'s `HTTP_GET`/`HTTP_POST` values. Legacy aliases can be re-enabled by defining +`ASYNC_HTTP_ENABLE_LEGACY_METHOD_ALIASES` before including `ESPAsyncWebClient.h` (only do this if you are not also +including `ESPAsyncWebServer.h` in the same translation unit). + ## Examples ### Simple GET Request @@ -274,7 +287,7 @@ client.setHeader("X-API-Key", "your-api-key"); client.setUserAgent("MyDevice/1.0"); // Or set per-request headers -AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_GET, "http://example.com"); +AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com"); request->setHeader("Authorization", "Bearer token"); client.request(request, onSuccess); ``` @@ -336,7 +349,7 @@ client.setHeader("Accept", "application/json"); ### Per-Request Settings ```cpp -AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_POST, url); +AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_METHOD_POST, url); request->setTimeout(30000); // 30 second timeout for this request request->setHeader("Content-Type", "application/xml"); request->setBody(xmlData); @@ -369,7 +382,7 @@ Parameters: Notes: - Invoked for every segment (chunk or contiguous data block) -- The full body is still accumulated internally (future option may allow disabling accumulation) +- Unless `req->setNoStoreBody(true)` is enabled, the full body is still accumulated internally - `final` is invoked just before the success callback - Keep it lightweight (avoid blocking operations) @@ -378,7 +391,7 @@ Notes: If `Content-Length` is present, the response is considered complete once that many bytes have been received. Extra bytes (if a misbehaving server sends more) are ignored. Without `Content-Length`, completion is determined by connection close. -Configure `client.setMaxBodySize(maxBytes)` to abort early when the announced `Content-Length` or accumulated chunk data would exceed `maxBytes`, yielding `MAX_BODY_SIZE_EXCEEDED`. Pass `0` (default) to disable the guard. +Configure `client.setMaxBodySize(maxBytes)` to abort early when the announced `Content-Length` or accumulated chunk data would exceed `maxBytes`, yielding `MAX_BODY_SIZE_EXCEEDED`. Pass `0` to disable the guard (this applies only when buffering the response body in memory). Likewise, guard against oversized or malicious header blocks via `client.setMaxHeaderBytes(limit)`. When the cumulative response headers exceed `limit` bytes before completion of `\r\n\r\n`, the request aborts with `HEADERS_TOO_LARGE`. diff --git a/examples/arduino/NoStoreToSD/NoStoreToSD.ino b/examples/arduino/NoStoreToSD/NoStoreToSD.ino index b77440c..3644f3e 100644 --- a/examples/arduino/NoStoreToSD/NoStoreToSD.ino +++ b/examples/arduino/NoStoreToSD/NoStoreToSD.ino @@ -31,7 +31,7 @@ static bool beginDownload(const char* url, const char* destinationPath) { currentPath = destinationPath; - AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_GET, url); + AsyncHttpRequest* request = new AsyncHttpRequest(HTTP_METHOD_GET, url); request->setNoStoreBody(true); // only stream via onBodyChunk uint32_t id = client.request( diff --git a/examples/arduino/StreamingUpload/StreamingUpload.ino b/examples/arduino/StreamingUpload/StreamingUpload.ino index ab1f86f..b32f50a 100644 --- a/examples/arduino/StreamingUpload/StreamingUpload.ino +++ b/examples/arduino/StreamingUpload/StreamingUpload.ino @@ -47,7 +47,7 @@ void setup() { pattern.total = 10 * 1024; pattern.sent = 0; - AsyncHttpRequest* req = new AsyncHttpRequest(HTTP_POST, "http://httpbin.org/post"); + AsyncHttpRequest* req = new AsyncHttpRequest(HTTP_METHOD_POST, "http://httpbin.org/post"); req->addQueryParam("mode", "stream"); req->addQueryParam("unit", "bytes"); req->finalizeQueryParams(); diff --git a/examples/platformio/CustomHeaders/src/main.cpp b/examples/platformio/CustomHeaders/src/main.cpp index 97fdea7..69f9693 100644 --- a/examples/platformio/CustomHeaders/src/main.cpp +++ b/examples/platformio/CustomHeaders/src/main.cpp @@ -45,7 +45,7 @@ void setup() { delay(5000); // Wait 5 seconds // Make another request with additional headers using the advanced API - AsyncHttpRequest* customRequest = new AsyncHttpRequest(HTTP_POST, "http://httpbin.org/post"); + AsyncHttpRequest* customRequest = new AsyncHttpRequest(HTTP_METHOD_POST, "http://httpbin.org/post"); customRequest->setHeader("Content-Type", "application/json"); customRequest->setHeader("X-Custom-Header", "CustomValue123"); customRequest->setHeader("Accept", "application/json"); diff --git a/examples/platformio/StreamingUpload/src/main.cpp b/examples/platformio/StreamingUpload/src/main.cpp index 95376e1..bf1ff55 100644 --- a/examples/platformio/StreamingUpload/src/main.cpp +++ b/examples/platformio/StreamingUpload/src/main.cpp @@ -60,7 +60,7 @@ void setup() { pattern.total = 10 * 1024; // 10 KB synthetic pattern.sent = 0; - AsyncHttpRequest* req = new AsyncHttpRequest(HTTP_POST, "http://httpbin.org/post"); + AsyncHttpRequest* req = new AsyncHttpRequest(HTTP_METHOD_POST, "http://httpbin.org/post"); req->addQueryParam("mode", "stream"); req->addQueryParam("unit", "bytes"); req->finalizeQueryParams(); diff --git a/src/AsyncHttpClient.cpp b/src/AsyncHttpClient.cpp index fb22b9a..4fecfaf 100644 --- a/src/AsyncHttpClient.cpp +++ b/src/AsyncHttpClient.cpp @@ -3,14 +3,43 @@ #include "AsyncHttpClient.h" #include #include +#include +#include #include #include #include +#include #include "UrlParser.h" static constexpr size_t kMaxChunkSizeLineLen = 64; static constexpr size_t kMaxChunkTrailerLineLen = 256; static constexpr size_t kMaxChunkTrailerLines = 32; +static constexpr size_t kDefaultMaxHeaderBytes = 2800; // ~2.8 KiB +static constexpr size_t kDefaultMaxBodyBytes = 8192; // 8 KiB +static constexpr size_t kMaxCookieCount = 16; +static constexpr size_t kMaxCookieBytes = 4096; +static const char* kPublicSuffixes[] = {"com", + "net", + "org", + "gov", + "edu", + "mil", + "int", + "co.uk", + "ac.uk", + "gov.uk", + "uk", + "io", + "co", + "app", + "dev", + "github.io", + "web.app", + "pages.dev", + "vercel.app", + "firebaseapp.com", + "cloudfront.net"}; + static bool equalsIgnoreCase(const String& a, const char* b) { size_t lenA = a.length(); size_t lenB = strlen(b); @@ -24,9 +53,91 @@ static bool equalsIgnoreCase(const String& a, const char* b) { return true; } +static int64_t currentTimeSeconds() { + time_t now = time(nullptr); + if (now > 0) + return static_cast(now); + // Fallback to millis-based monotonic clock when wall time is not set + return static_cast(millis() / 1000); +} + +static int monthFromAbbrev(const char* mon) { + static const char* kMonths[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; + if (!mon || strlen(mon) < 3) + return -1; + for (int i = 0; i < 12; ++i) { + if (tolower((unsigned char)mon[0]) == tolower((unsigned char)kMonths[i][0]) && + tolower((unsigned char)mon[1]) == tolower((unsigned char)kMonths[i][1]) && + tolower((unsigned char)mon[2]) == tolower((unsigned char)kMonths[i][2])) { + return i; + } + } + return -1; +} + +static int64_t daysFromCivil(int y, unsigned m, unsigned d) { + // Howard Hinnant's days_from_civil, offset so 1970-01-01 yields 0 + y -= m <= 2 ? 1 : 0; + const int era = (y >= 0 ? y : y - 399) / 400; + const unsigned yoe = static_cast(y - era * 400); // [0, 399] + const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365] + const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // [0, 146096] + return era * 146097 + static_cast(doe) - 719468; +} + +static bool makeUtcTimestamp(int year, int month, int day, int hour, int minute, int second, int64_t* outEpoch) { + if (!outEpoch) + return false; + if (month < 1 || month > 12 || day < 1 || hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || + second > 60) + return false; + static const uint8_t kMonthDays[] = {31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; + bool leap = (year % 4 == 0 && (year % 100 != 0 || year % 400 == 0)); + uint8_t maxDay = kMonthDays[month - 1] + ((leap && month == 2) ? 1 : 0); + if (static_cast(day) > maxDay) + return false; + int64_t days = daysFromCivil(year, static_cast(month), static_cast(day)); + int64_t seconds = days * 86400 + hour * 3600 + minute * 60 + second; + *outEpoch = seconds; + return true; +} + +static bool parseHttpDate(const String& value, int64_t* outEpoch) { + if (!outEpoch) + return false; + String date = value; + date.trim(); + if (date.length() < 20) // Shorter than "01 Jan 1970 00:00:00 GMT" + return false; + int comma = date.indexOf(','); + if (comma != -1) + date = date.substring(comma + 1); + date.trim(); + + int day = 0, year = 0, hour = 0, minute = 0, second = 0; + char monthBuf[4] = {0}; + char tzBuf[4] = {0}; + int matched = sscanf(date.c_str(), "%d %3s %d %d:%d:%d %3s", &day, monthBuf, &year, &hour, &minute, &second, tzBuf); + if (matched < 6) + return false; + if (matched == 6) + strncpy(tzBuf, "GMT", sizeof(tzBuf)); + if (!(equalsIgnoreCase(String(tzBuf), "GMT") || equalsIgnoreCase(String(tzBuf), "UTC"))) + return false; + int month = monthFromAbbrev(monthBuf); + if (month < 0) + return false; + int64_t epoch = 0; + if (!makeUtcTimestamp(year, month + 1, day, hour, minute, second, &epoch)) + return false; + *outEpoch = epoch; + return true; +} + AsyncHttpClient::AsyncHttpClient() : _defaultTimeout(10000), _defaultUserAgent(String("ESPAsyncWebClient/") + ESP_ASYNC_WEB_CLIENT_VERSION), - _bodyChunkCallback(nullptr), _followRedirects(false), _maxRedirectHops(3), _maxHeaderBytes(0) { + _bodyChunkCallback(nullptr), _maxBodySize(kDefaultMaxBodyBytes), _followRedirects(false), _maxRedirectHops(3), + _maxHeaderBytes(kDefaultMaxHeaderBytes) { #if defined(ARDUINO_ARCH_ESP32) && defined(ASYNC_HTTP_ENABLE_AUTOLOOP) // Create recursive mutex for shared containers when auto-loop may run in background _reqMutex = xSemaphoreCreateRecursiveMutex(); @@ -95,22 +206,22 @@ void AsyncHttpClient::_autoLoopTaskThunk(void* param) { #endif uint32_t AsyncHttpClient::get(const char* url, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_GET, url, nullptr, onSuccess, onError); + return makeRequest(HTTP_METHOD_GET, url, nullptr, onSuccess, onError); } uint32_t AsyncHttpClient::post(const char* url, const char* data, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_POST, url, data, onSuccess, onError); + return makeRequest(HTTP_METHOD_POST, url, data, onSuccess, onError); } uint32_t AsyncHttpClient::put(const char* url, const char* data, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_PUT, url, data, onSuccess, onError); + return makeRequest(HTTP_METHOD_PUT, url, data, onSuccess, onError); } uint32_t AsyncHttpClient::del(const char* url, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_DELETE, url, nullptr, onSuccess, onError); + return makeRequest(HTTP_METHOD_DELETE, url, nullptr, onSuccess, onError); } uint32_t AsyncHttpClient::head(const char* url, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_HEAD, url, nullptr, onSuccess, onError); + return makeRequest(HTTP_METHOD_HEAD, url, nullptr, onSuccess, onError); } uint32_t AsyncHttpClient::patch(const char* url, const char* data, SuccessCallback onSuccess, ErrorCallback onError) { - return makeRequest(HTTP_PATCH, url, data, onSuccess, onError); + return makeRequest(HTTP_METHOD_PATCH, url, data, onSuccess, onError); } void AsyncHttpClient::setHeader(const char* name, const char* value) { @@ -251,18 +362,22 @@ void AsyncHttpClient::setCookie(const char* name, const char* value, const char* bool secure) { if (!name || strlen(name) == 0) return; + int64_t now = currentTimeSeconds(); StoredCookie cookie; cookie.name = String(name); cookie.value = value ? String(value) : String(); cookie.path = (path && strlen(path) > 0) ? String(path) : String("/"); cookie.domain = domain ? String(domain) : String(); cookie.secure = secure; + cookie.createdAt = now; + cookie.lastAccessAt = now; if (!cookie.path.startsWith("/")) cookie.path = "/" + cookie.path; if (cookie.domain.startsWith(".")) cookie.domain.remove(0, 1); lock(); + purgeExpiredCookies(now); for (auto it = _cookies.begin(); it != _cookies.end();) { if (it->name.equalsIgnoreCase(cookie.name) && it->domain.equalsIgnoreCase(cookie.domain) && it->path.equals(cookie.path)) { @@ -271,8 +386,11 @@ void AsyncHttpClient::setCookie(const char* name, const char* value, const char* ++it; } } - if (!cookie.value.isEmpty()) + if (!cookie.value.isEmpty()) { + if (_cookies.size() >= kMaxCookieCount) + evictOneCookieLocked(); _cookies.push_back(cookie); + } unlock(); } @@ -455,7 +573,10 @@ void AsyncHttpClient::handleConnect(RequestContext* context) { } void AsyncHttpClient::handleData(RequestContext* context, char* data, size_t len) { - context->responseBuffer.concat(data, len); + bool storeBody = context && context->request && !context->request->getNoStoreBody(); + bool bufferThisChunk = context && (!context->headersComplete || context->chunked); + if (bufferThisChunk) + context->responseBuffer.concat(data, len); bool enforceLimit = shouldEnforceBodyLimit(context); auto wouldExceedLimit = [&](size_t incoming) -> bool { if (!enforceLimit) @@ -494,7 +615,7 @@ void AsyncHttpClient::handleData(RequestContext* context, char* data, size_t len triggerError(context, MAX_BODY_SIZE_EXCEEDED, "Body exceeds configured maximum"); return; } - if (!context->request->getNoStoreBody()) { + if (storeBody) { context->response->appendBody(context->responseBuffer.c_str(), incomingLen); } context->receivedContentLength += incomingLen; @@ -513,7 +634,7 @@ void AsyncHttpClient::handleData(RequestContext* context, char* data, size_t len triggerError(context, MAX_BODY_SIZE_EXCEEDED, "Body exceeds configured maximum"); return; } - if (!context->request->getNoStoreBody()) { + if (storeBody) { context->response->appendBody(data, len); } context->receivedContentLength += len; @@ -619,7 +740,7 @@ void AsyncHttpClient::handleData(RequestContext* context, char* data, size_t len return; } const char* chunkPtr = context->responseBuffer.c_str(); - if (!context->request->getNoStoreBody()) { + if (storeBody) { context->response->appendBody(chunkPtr, chunkLen); } context->receivedContentLength += chunkLen; @@ -733,7 +854,7 @@ bool AsyncHttpClient::parseResponseHeaders(RequestContext* context, const String parsed = 0; context->expectedContentLength = (size_t)parsed; context->response->setContentLength(context->expectedContentLength); - bool storeBody = !(context->request->getNoStoreBody() && _bodyChunkCallback); + bool storeBody = !context->request->getNoStoreBody(); if (storeBody) context->response->reserveBody(context->expectedContentLength); } else if (name.equalsIgnoreCase("Transfer-Encoding") && value.equalsIgnoreCase("chunked")) { @@ -902,7 +1023,7 @@ bool AsyncHttpClient::buildRedirectRequest(RequestContext* context, AsyncHttpReq HttpMethod newMethod = context->request->getMethod(); bool dropBody = false; if (status == 301 || status == 302 || status == 303) { - newMethod = HTTP_GET; + newMethod = HTTP_METHOD_GET; dropBody = true; } @@ -911,14 +1032,20 @@ bool AsyncHttpClient::buildRedirectRequest(RequestContext* context, AsyncHttpReq newRequest->setNoStoreBody(context->request->getNoStoreBody()); bool sameOrigin = isSameOrigin(context->request, newRequest); + auto isCrossOriginSensitiveHeader = [](const String& name) { + String lower = name; + lower.toLowerCase(); + return lower.equals("authorization") || lower.equals("proxy-authorization") || lower.equals("cookie") || + lower.equals("cookie2") || lower.startsWith("x-api-key") || lower.startsWith("x-auth-token") || + lower.startsWith("x-access-token"); + }; const auto& headers = context->request->getHeaders(); for (const auto& hdr : headers) { if (hdr.name.equalsIgnoreCase("Content-Length")) continue; if (dropBody && hdr.name.equalsIgnoreCase("Content-Type")) continue; - if (!sameOrigin && - (hdr.name.equalsIgnoreCase("Authorization") || hdr.name.equalsIgnoreCase("Proxy-Authorization"))) + if (!sameOrigin && isCrossOriginSensitiveHeader(hdr.name)) continue; newRequest->setHeader(hdr.name, hdr.value); } @@ -1082,7 +1209,7 @@ bool AsyncHttpClient::shouldEnforceBodyLimit(RequestContext* context) { return false; if (!context || !context->request) return true; - if (context->request->getNoStoreBody() && _bodyChunkCallback) + if (context->request->getNoStoreBody()) return false; return true; } @@ -1292,6 +1419,18 @@ bool AsyncHttpClient::isIpLiteral(const String& host) const { return hasColon || hasDot; } +static bool isPublicSuffix(const String& domain) { + if (domain.length() == 0) + return false; + String lower = domain; + lower.toLowerCase(); + for (auto suffix : kPublicSuffixes) { + if (lower.equals(suffix)) + return true; + } + return false; +} + bool AsyncHttpClient::normalizeCookieDomain(String& domain, const String& host, bool domainAttributeProvided) const { String cleaned = domain; cleaned.trim(); @@ -1316,6 +1455,8 @@ bool AsyncHttpClient::normalizeCookieDomain(String& domain, const String& host, return false; if (cleaned.indexOf('.') == -1) return false; + if (isPublicSuffix(cleaned)) + return false; domain = cleaned; return true; @@ -1353,9 +1494,12 @@ bool AsyncHttpClient::pathMatches(const String& cookiePath, const String& reques return req.length() > cpath.length() && req.charAt(cpath.length()) == '/'; } -bool AsyncHttpClient::cookieMatchesRequest(const StoredCookie& cookie, const AsyncHttpRequest* request) const { +bool AsyncHttpClient::cookieMatchesRequest(const StoredCookie& cookie, const AsyncHttpRequest* request, + int64_t nowSeconds) const { if (!request) return false; + if (isCookieExpired(cookie, nowSeconds)) + return false; if (cookie.secure && !request->isSecure()) return false; if (!domainMatches(cookie.domain, request->getHost())) @@ -1365,23 +1509,98 @@ bool AsyncHttpClient::cookieMatchesRequest(const StoredCookie& cookie, const Asy return !cookie.value.isEmpty(); } +bool AsyncHttpClient::isCookieExpired(const StoredCookie& cookie, int64_t nowSeconds) const { + return cookie.expiresAt != -1 && nowSeconds >= cookie.expiresAt; +} + +void AsyncHttpClient::purgeExpiredCookies(int64_t nowSeconds) { + for (auto it = _cookies.begin(); it != _cookies.end();) { + if (isCookieExpired(*it, nowSeconds)) { + it = _cookies.erase(it); + } else { + ++it; + } + } +} + +static uint8_t countDomainDots(const String& domain) { + uint8_t dots = 0; + for (size_t i = 0; i < domain.length(); ++i) { + if (domain.charAt(i) == '.') + ++dots; + } + return dots; +} + +void AsyncHttpClient::evictOneCookieLocked() { + if (_cookies.empty()) + return; + size_t bestIndex = 0; + for (size_t i = 1; i < _cookies.size(); ++i) { + const StoredCookie& best = _cookies[bestIndex]; + const StoredCookie& candidate = _cookies[i]; + + if (candidate.lastAccessAt != best.lastAccessAt) { + if (candidate.lastAccessAt < best.lastAccessAt) + bestIndex = i; + continue; + } + + bool candidateSession = candidate.expiresAt == -1; + bool bestSession = best.expiresAt == -1; + if (candidateSession != bestSession) { + if (candidateSession) + bestIndex = i; + continue; + } + + uint8_t candidateDots = countDomainDots(candidate.domain); + uint8_t bestDots = countDomainDots(best.domain); + if (candidateDots != bestDots) { + if (candidateDots < bestDots) + bestIndex = i; + continue; + } + + if (candidate.domain.length() != best.domain.length()) { + if (candidate.domain.length() < best.domain.length()) + bestIndex = i; + continue; + } + + if (candidate.path.length() != best.path.length()) { + if (candidate.path.length() < best.path.length()) + bestIndex = i; + continue; + } + + if (candidate.createdAt != best.createdAt) { + if (candidate.createdAt < best.createdAt) + bestIndex = i; + continue; + } + } + _cookies.erase(_cookies.begin() + static_cast::difference_type>(bestIndex)); +} + void AsyncHttpClient::applyCookies(AsyncHttpRequest* request) { if (!request) return; + int64_t now = currentTimeSeconds(); String cookieHeader; - std::vector cookiesCopy; lock(); - cookiesCopy = _cookies; - unlock(); + purgeExpiredCookies(now); size_t estimatedLen = 0; - if (!cookiesCopy.empty()) { - estimatedLen += (cookiesCopy.size() - 1) * 2; // separators - for (const auto& cookie : cookiesCopy) - estimatedLen += cookie.name.length() + 1 + cookie.value.length(); - cookieHeader.reserve(estimatedLen); - } - for (const auto& cookie : cookiesCopy) { - if (cookieMatchesRequest(cookie, request)) { + for (const auto& cookie : _cookies) { + if (cookieMatchesRequest(cookie, request, now)) + estimatedLen += cookie.name.length() + 1 + cookie.value.length() + 2; + } + if (estimatedLen >= 2) + estimatedLen -= 2; + cookieHeader.reserve(estimatedLen); + for (auto& cookie : _cookies) { + if (cookieMatchesRequest(cookie, request, now)) { + cookie.lastAccessAt = now; if (!cookieHeader.isEmpty()) cookieHeader += "; "; cookieHeader += cookie.name; @@ -1389,6 +1608,7 @@ void AsyncHttpClient::applyCookies(AsyncHttpRequest* request) { cookieHeader += cookie.value; } } + unlock(); if (cookieHeader.isEmpty()) return; String existing = request->getHeader("Cookie"); @@ -1409,6 +1629,9 @@ void AsyncHttpClient::storeResponseCookie(const AsyncHttpRequest* request, const String raw = setCookieValue; if (raw.length() == 0) return; + if (raw.length() > kMaxCookieBytes) + return; + int64_t now = currentTimeSeconds(); int semi = raw.indexOf(';'); String pair = semi == -1 ? raw : raw.substring(0, semi); pair.trim(); @@ -1424,6 +1647,8 @@ void AsyncHttpClient::storeResponseCookie(const AsyncHttpRequest* request, const cookie.path = "/"; bool domainAttributeProvided = false; bool remove = cookie.value.isEmpty(); + bool maxAgeAttributeProvided = false; + int64_t expiresAt = -1; int pos = semi; while (pos != -1) { @@ -1444,9 +1669,18 @@ void AsyncHttpClient::storeResponseCookie(const AsyncHttpRequest* request, const } else if (key.equalsIgnoreCase("Secure")) { cookie.secure = true; } else if (key.equalsIgnoreCase("Max-Age")) { + maxAgeAttributeProvided = true; long age = val.toInt(); - if (age <= 0) + if (age <= 0) { remove = true; + expiresAt = now; + } else { + expiresAt = now + static_cast(age); + } + } else if (key.equalsIgnoreCase("Expires") && !maxAgeAttributeProvided) { + int64_t parsedExpiry = -1; + if (parseHttpDate(val, &parsedExpiry)) + expiresAt = parsedExpiry; } } pos = next; @@ -1456,8 +1690,17 @@ void AsyncHttpClient::storeResponseCookie(const AsyncHttpRequest* request, const return; if (!cookie.path.startsWith("/")) cookie.path = "/" + cookie.path; + size_t payloadSize = cookie.name.length() + cookie.value.length() + cookie.domain.length() + cookie.path.length(); + if (payloadSize > kMaxCookieBytes) + return; + cookie.expiresAt = expiresAt; + cookie.createdAt = now; + cookie.lastAccessAt = now; + if (isCookieExpired(cookie, now)) + remove = true; lock(); + purgeExpiredCookies(now); for (auto it = _cookies.begin(); it != _cookies.end();) { if (it->name.equalsIgnoreCase(cookie.name) && it->domain.equalsIgnoreCase(cookie.domain) && it->path.equals(cookie.path)) { @@ -1466,7 +1709,10 @@ void AsyncHttpClient::storeResponseCookie(const AsyncHttpRequest* request, const ++it; } } - if (!remove) + if (!remove) { + if (_cookies.size() >= kMaxCookieCount) + evictOneCookieLocked(); _cookies.push_back(cookie); + } unlock(); } diff --git a/src/AsyncHttpClient.h b/src/AsyncHttpClient.h index 863ca11..84a73e0 100644 --- a/src/AsyncHttpClient.h +++ b/src/AsyncHttpClient.h @@ -232,11 +232,17 @@ class AsyncHttpClient { String domain; String path; bool secure = false; + int64_t expiresAt = -1; // -1 means no expiration set + int64_t createdAt = 0; + int64_t lastAccessAt = 0; }; std::vector _cookies; void applyCookies(AsyncHttpRequest* request); void storeResponseCookie(const AsyncHttpRequest* request, const String& setCookieValue); - bool cookieMatchesRequest(const StoredCookie& cookie, const AsyncHttpRequest* request) const; + bool cookieMatchesRequest(const StoredCookie& cookie, const AsyncHttpRequest* request, int64_t nowSeconds) const; + bool isCookieExpired(const StoredCookie& cookie, int64_t nowSeconds) const; + void purgeExpiredCookies(int64_t nowSeconds); + void evictOneCookieLocked(); bool domainMatches(const String& cookieDomain, const String& host) const; bool pathMatches(const String& cookiePath, const String& requestPath) const; bool normalizeCookieDomain(String& domain, const String& host, bool domainAttributeProvided) const; diff --git a/src/AsyncTransport.cpp b/src/AsyncTransport.cpp index 8beae1c..bc09949 100644 --- a/src/AsyncTransport.cpp +++ b/src/AsyncTransport.cpp @@ -274,6 +274,7 @@ class AsyncTlsTransport : public AsyncTransport { std::vector _encryptedBuffer; size_t _encryptedOffset = 0; std::vector _fingerprintBytes; + bool _fingerprintInvalid = false; mbedtls_ssl_context _ssl; mbedtls_ssl_config _sslConfig; @@ -295,7 +296,8 @@ static int hexValue(char c) { return -1; } -static std::vector parseFingerprintString(const String& text) { +static std::vector parseFingerprintString(const String& text, bool* outValid) { + bool valid = true; std::vector bytes; int accum = -1; for (size_t i = 0; i < text.length(); ++i) { @@ -306,6 +308,7 @@ static std::vector parseFingerprintString(const String& text) { int v = hexValue(ch); if (v < 0) { bytes.clear(); + valid = false; break; } if (accum == -1) { @@ -318,7 +321,10 @@ static std::vector parseFingerprintString(const String& text) { if (accum != -1) { // Odd number of nibbles -> invalid bytes.clear(); + valid = false; } + if (outValid) + *outValid = valid || text.length() == 0; return bytes; } @@ -334,7 +340,9 @@ AsyncTlsTransport::AsyncTlsTransport(const AsyncHttpTLSConfig& config) : _client mbedtls_pk_init(&_clientKey); mbedtls_entropy_init(&_entropy); mbedtls_ctr_drbg_init(&_ctrDrbg); - _fingerprintBytes = parseFingerprintString(_config.fingerprint); + bool fpValid = true; + _fingerprintBytes = parseFingerprintString(_config.fingerprint, &fpValid); + _fingerprintInvalid = (_config.fingerprint.length() > 0 && !fpValid); } AsyncTlsTransport::~AsyncTlsTransport() { @@ -364,6 +372,10 @@ void AsyncTlsTransport::shutdownClient() { bool AsyncTlsTransport::connect(const char* host, uint16_t port) { if (!_client) return false; + if (_fingerprintInvalid) { + fail(TLS_FINGERPRINT_MISMATCH, "Invalid TLS fingerprint format"); + return false; + } _host = host; _port = port; _handshakeStartMs = millis(); @@ -456,16 +468,17 @@ void AsyncTlsTransport::continueHandshake() { } bool AsyncTlsTransport::verifyPeerCertificate() { - if (_config.insecure) - return true; uint32_t res = mbedtls_ssl_get_verify_result(&_ssl); - if (res != 0) { + bool requireCaValidation = !_config.insecure; + if (requireCaValidation && res != 0) { fail(TLS_CERT_INVALID, "TLS certificate validation failed"); return false; } - if (!_fingerprintBytes.empty() && !verifyFingerprint()) { - fail(TLS_FINGERPRINT_MISMATCH, "TLS fingerprint mismatch"); - return false; + if (!_fingerprintBytes.empty()) { + if (!verifyFingerprint()) { + fail(TLS_FINGERPRINT_MISMATCH, "TLS fingerprint mismatch"); + return false; + } } return true; } diff --git a/src/HttpRequest.cpp b/src/HttpRequest.cpp index 94ea1cd..5f5a7e7 100644 --- a/src/HttpRequest.cpp +++ b/src/HttpRequest.cpp @@ -101,17 +101,17 @@ bool AsyncHttpRequest::parseUrl(const String& url) { String AsyncHttpRequest::methodToString() const { switch (_method) { - case HTTP_GET: + case HTTP_METHOD_GET: return "GET"; - case HTTP_POST: + case HTTP_METHOD_POST: return "POST"; - case HTTP_PUT: + case HTTP_METHOD_PUT: return "PUT"; - case HTTP_DELETE: + case HTTP_METHOD_DELETE: return "DELETE"; - case HTTP_HEAD: + case HTTP_METHOD_HEAD: return "HEAD"; - case HTTP_PATCH: + case HTTP_METHOD_PATCH: return "PATCH"; default: return "GET"; diff --git a/src/HttpRequest.h b/src/HttpRequest.h index 142523d..fbe2ffb 100644 --- a/src/HttpRequest.h +++ b/src/HttpRequest.h @@ -7,7 +7,35 @@ #include #include "HttpCommon.h" -enum HttpMethod { HTTP_GET, HTTP_POST, HTTP_PUT, HTTP_DELETE, HTTP_HEAD, HTTP_PATCH }; +enum HttpMethod { + HTTP_METHOD_GET, + HTTP_METHOD_POST, + HTTP_METHOD_PUT, + HTTP_METHOD_DELETE, + HTTP_METHOD_HEAD, + HTTP_METHOD_PATCH +}; + +#ifdef ASYNC_HTTP_ENABLE_LEGACY_METHOD_ALIASES +#ifndef HTTP_GET +#define HTTP_GET HTTP_METHOD_GET +#endif +#ifndef HTTP_POST +#define HTTP_POST HTTP_METHOD_POST +#endif +#ifndef HTTP_PUT +#define HTTP_PUT HTTP_METHOD_PUT +#endif +#ifndef HTTP_DELETE +#define HTTP_DELETE HTTP_METHOD_DELETE +#endif +#ifndef HTTP_HEAD +#define HTTP_HEAD HTTP_METHOD_HEAD +#endif +#ifndef HTTP_PATCH +#define HTTP_PATCH HTTP_METHOD_PATCH +#endif +#endif // ASYNC_HTTP_ENABLE_LEGACY_METHOD_ALIASES struct AsyncHttpTLSConfig; @@ -100,8 +128,7 @@ class AsyncHttpRequest { // Accept-Encoding convenience (gzip) void enableGzipAcceptEncoding(bool enable = true); - // Avoid storing body in memory (use only streaming callbacks). Effective only if a response chunk callback - // (per-request or global) is present. + // Avoid storing response body in memory (use with global client.onBodyChunk(...) to consume the data). void setNoStoreBody(bool enable = true) { _noStoreBody = enable; } diff --git a/src/UrlParser.cpp b/src/UrlParser.cpp index 2e7c247..5deabcd 100644 --- a/src/UrlParser.cpp +++ b/src/UrlParser.cpp @@ -1,7 +1,14 @@ #include "UrlParser.h" +#include +#include +#include namespace UrlParser { +static constexpr size_t kMaxUrlLength = 2048; +static constexpr size_t kMaxHostLength = 255; +static constexpr size_t kMaxPathLength = 1900; + static bool startsWith(const std::string& s, const char* prefix) { size_t n = 0; while (prefix[n] != '\0') @@ -9,10 +16,57 @@ static bool startsWith(const std::string& s, const char* prefix) { return s.size() >= n && s.compare(0, n, prefix) == 0; } +static bool hasInvalidUrlChar(const std::string& url) { + for (char c : url) { + unsigned char uc = static_cast(c); + if (uc <= 0x1F || uc == 0x7F || c == '\r' || c == '\n' || c == ' ' || c == '\t') + return true; + } + return false; +} + +static bool isValidHostChar(char c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '-' || c == '.'; +} + +static bool isValidHost(const std::string& host) { + if (host.empty() || host.size() > kMaxHostLength) + return false; + if (host.front() == '.' || host.back() == '.') + return false; + for (char c : host) { + if (!isValidHostChar(c)) + return false; + } + return true; +} + +static bool parsePort(const std::string& portStr, uint16_t* out) { + if (!out || portStr.empty()) + return false; + for (char c : portStr) { + if (c < '0' || c > '9') + return false; + } + errno = 0; + char* end = nullptr; + unsigned long val = std::strtoul(portStr.c_str(), &end, 10); + if (end == portStr.c_str() || *end != '\0') + return false; + if (errno == ERANGE || val > 65535) + return false; + *out = static_cast(val); + return true; +} + bool parse(const std::string& originalUrl, ParsedUrl& out) { + if (originalUrl.size() > kMaxUrlLength || hasInvalidUrlChar(originalUrl)) + return false; + std::string url = originalUrl; // working copy out.secure = false; out.port = 80; + out.schemeImplicit = false; if (startsWith(url, "https://")) { out.secure = true; @@ -23,9 +77,10 @@ bool parse(const std::string& originalUrl, ParsedUrl& out) { out.port = 80; url.erase(0, 7); } else { - // No scheme provided -> default http - out.secure = false; - out.port = 80; + // No scheme provided -> default to HTTPS and signal implicit scheme + out.secure = true; + out.port = 443; + out.schemeImplicit = true; } // Find first '/' and first '?' @@ -55,10 +110,18 @@ bool parse(const std::string& originalUrl, ParsedUrl& out) { std::string portStr = out.host.substr(colon + 1); out.host = out.host.substr(0, colon); if (!portStr.empty()) { - out.port = static_cast(std::stoi(portStr)); + uint16_t parsedPort = 0; + if (!parsePort(portStr, &parsedPort)) + return false; + out.port = parsedPort; } } + if (!isValidHost(out.host)) + return false; + if (out.path.size() > kMaxPathLength) + return false; + return !out.host.empty(); } diff --git a/src/UrlParser.h b/src/UrlParser.h index 4545740..4b89723 100644 --- a/src/UrlParser.h +++ b/src/UrlParser.h @@ -2,10 +2,10 @@ * Lightweight URL parsing utility extracted from AsyncHttpRequest to allow * host (native) unit testing without requiring Arduino framework headers. * - * Supported forms (mirrors original behaviour): + * Supported forms (mirrors original behaviour, with secure default when scheme is omitted): * - http://host * - https://host - * - host (defaults to http) + * - host (defaults to https and marks schemeImplicit=true) * - host:port/path?query * - host?query (query before first '/') * - http(s)://host?query (same as above) @@ -22,6 +22,7 @@ struct ParsedUrl { std::string path; // always begins with '/' uint16_t port = 80; bool secure = false; + bool schemeImplicit = false; // true when no scheme was provided }; // Parse URL into components. Returns false if host empty after parsing. diff --git a/test/test_chunk_parse/test_main.cpp b/test/test_chunk_parse/test_main.cpp index bfc27f8..0a97d3f 100644 --- a/test/test_chunk_parse/test_main.cpp +++ b/test/test_chunk_parse/test_main.cpp @@ -12,6 +12,8 @@ static bool gErrorCalled = false; static HttpClientError gLastError = CONNECTION_FAILED; static String gLastBody; static std::vector gLastTrailers; +static String gStreamedBody; +static bool gStreamFinalCalled = false; static void resetState() { gSuccessCalled = false; @@ -19,6 +21,8 @@ static void resetState() { gLastError = CONNECTION_FAILED; gLastBody = ""; gLastTrailers.clear(); + gStreamedBody = ""; + gStreamFinalCalled = false; } static String trailerValue(const char* name) { @@ -32,7 +36,7 @@ static String trailerValue(const char* name) { static AsyncHttpClient::RequestContext* makeContext(AsyncHttpClient& client) { auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/res"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/res"); ctx->response = new AsyncHttpResponse(); ctx->transport = nullptr; ctx->onSuccess = [](AsyncHttpResponse* resp) { @@ -67,8 +71,6 @@ static void test_chunk_trailers_are_parsed() { feed("X-Meta: done\r\n"); feed("\r\n"); - client.handleDisconnect(ctx); - TEST_ASSERT_TRUE(gSuccessCalled); TEST_ASSERT_FALSE(gErrorCalled); TEST_ASSERT_EQUAL_STRING("Wikipedia", gLastBody.c_str()); @@ -90,8 +92,6 @@ static void test_chunk_missing_crlf_is_error() { feed("4\r\n"); feed("Wiki\n"); // missing CR before LF terminator - client.handleDisconnect(ctx); - TEST_ASSERT_TRUE(gErrorCalled); TEST_ASSERT_FALSE(gSuccessCalled); TEST_ASSERT_EQUAL_INT(CHUNKED_DECODE_FAILED, gLastError); @@ -116,12 +116,39 @@ static void test_chunk_body_limit_enforced() { TEST_ASSERT_EQUAL_INT(MAX_BODY_SIZE_EXCEEDED, gLastError); } +static void test_chunk_body_limit_ignored_for_no_store_streaming() { + resetState(); + AsyncHttpClient client; + client.setMaxBodySize(5); + client.onBodyChunk([](const char* data, size_t len, bool final) { + if (final) { + gStreamFinalCalled = true; + return; + } + if (data && len > 0) + gStreamedBody.concat(data, len); + }); + auto ctx = makeContext(client); + ctx->headersComplete = true; + ctx->chunked = true; + ctx->request->setNoStoreBody(true); + + auto feed = [&](const char* data) { client.handleData(ctx, const_cast(data), strlen(data)); }; + feed("6\r\nTooBig\r\n0\r\n\r\n"); + + TEST_ASSERT_TRUE(gSuccessCalled); + TEST_ASSERT_FALSE(gErrorCalled); + TEST_ASSERT_EQUAL_STRING("TooBig", gStreamedBody.c_str()); + TEST_ASSERT_TRUE(gStreamFinalCalled); +} + void setup() { delay(2000); UNITY_BEGIN(); RUN_TEST(test_chunk_trailers_are_parsed); RUN_TEST(test_chunk_missing_crlf_is_error); RUN_TEST(test_chunk_body_limit_enforced); + RUN_TEST(test_chunk_body_limit_ignored_for_no_store_streaming); UNITY_END(); } diff --git a/test/test_cookies/test_main.cpp b/test/test_cookies/test_main.cpp index 98d3ff0..db6cef8 100644 --- a/test/test_cookies/test_main.cpp +++ b/test/test_cookies/test_main.cpp @@ -15,7 +15,7 @@ static void test_domain_matching_subdomains() { static void test_multiple_cookies_and_deduplication() { AsyncHttpClient client; - AsyncHttpRequest req(HTTP_GET, "http://example.com/path"); + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/path"); client.storeResponseCookie(&req, "a=1; Path=/"); client.storeResponseCookie(&req, "b=2; Path=/"); @@ -37,7 +37,7 @@ static void test_multiple_cookies_and_deduplication() { static void test_max_age_removes_cookie() { AsyncHttpClient client; - AsyncHttpRequest req(HTTP_GET, "http://example.com/"); + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/"); client.storeResponseCookie(&req, "temp=1; Path=/"); client.storeResponseCookie(&req, "temp=0; Max-Age=0; Path=/"); @@ -49,21 +49,21 @@ static void test_max_age_removes_cookie() { static void test_clear_and_public_set_cookie_api() { AsyncHttpClient client; - AsyncHttpRequest req(HTTP_GET, "http://example.com/"); + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/"); client.setCookie("manual", "123", "/", "example.com", false); client.applyCookies(&req); TEST_ASSERT_EQUAL_STRING("manual=123", req.getHeader("Cookie").c_str()); client.clearCookies(); - AsyncHttpRequest req2(HTTP_GET, "http://example.com/"); + AsyncHttpRequest req2(HTTP_METHOD_GET, "http://example.com/"); client.applyCookies(&req2); TEST_ASSERT_TRUE(req2.getHeader("Cookie").isEmpty()); } static void test_rejects_mismatched_domain_attribute() { AsyncHttpClient client; - AsyncHttpRequest req(HTTP_GET, "http://example.com/"); + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/"); client.storeResponseCookie(&req, "evil=1; Domain=evil.com; Path=/"); client.applyCookies(&req); @@ -73,17 +73,80 @@ static void test_rejects_mismatched_domain_attribute() { static void test_cookie_path_matching_rfc6265_rule() { AsyncHttpClient client; - AsyncHttpRequest req(HTTP_GET, "http://example.com/administrator"); + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/administrator"); client.storeResponseCookie(&req, "adminonly=1; Path=/admin"); client.applyCookies(&req); TEST_ASSERT_TRUE(req.getHeader("Cookie").isEmpty()); - AsyncHttpRequest req2(HTTP_GET, "http://example.com/admin/settings"); + AsyncHttpRequest req2(HTTP_METHOD_GET, "http://example.com/admin/settings"); client.applyCookies(&req2); TEST_ASSERT_EQUAL_STRING("adminonly=1", req2.getHeader("Cookie").c_str()); } +static void test_expires_and_max_age_enforcement() { + AsyncHttpClient client; + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/"); + + client.storeResponseCookie(&req, "persist=1; Expires=Fri, 01 Jan 2100 00:00:00 GMT; Path=/"); + client.applyCookies(&req); + TEST_ASSERT_EQUAL_STRING("persist=1", req.getHeader("Cookie").c_str()); + TEST_ASSERT_EQUAL(1, (int)client._cookies.size()); + + // Force the stored cookie to be expired and ensure it is not sent + client._cookies[0].expiresAt = 0; // Epoch start; always treated as expired + AsyncHttpRequest req2(HTTP_METHOD_GET, "http://example.com/"); + client.applyCookies(&req2); + TEST_ASSERT_TRUE(req2.getHeader("Cookie").isEmpty()); + TEST_ASSERT_EQUAL(0, (int)client._cookies.size()); + + // Past Expires attribute should remove the cookie immediately + client.storeResponseCookie(&req, "persist=1; Expires=Thu, 01 Jan 1970 00:00:00 GMT; Path=/"); + AsyncHttpRequest req3(HTTP_METHOD_GET, "http://example.com/"); + client.applyCookies(&req3); + TEST_ASSERT_TRUE(req3.getHeader("Cookie").isEmpty()); +} + +static bool hasCookieNamed(const AsyncHttpClient& client, const char* name) { + for (const auto& cookie : client._cookies) { + if (cookie.name.equalsIgnoreCase(name)) + return true; + } + return false; +} + +static void test_cookie_jar_eviction_is_lru_session_then_scope() { + AsyncHttpClient client; + AsyncHttpRequest req(HTTP_METHOD_GET, "http://example.com/"); + + for (int i = 0; i < 16; ++i) { + client.storeResponseCookie(&req, String("c") + String(i) + "=1; Path=/"); + } + TEST_ASSERT_EQUAL(16, (int)client._cookies.size()); + + for (size_t i = 0; i < client._cookies.size(); ++i) { + client._cookies[i].createdAt = (int64_t)i; + client._cookies[i].lastAccessAt = 1000; + } + + client._cookies[1].expiresAt = 2000000000; // persistent + client._cookies[1].lastAccessAt = 1; + client._cookies[2].expiresAt = -1; // session + client._cookies[2].path = "/"; + client._cookies[2].lastAccessAt = 1; + client._cookies[3].expiresAt = -1; // session, more specific scope + client._cookies[3].path = "/admin"; + client._cookies[3].lastAccessAt = 1; + + client.storeResponseCookie(&req, "new=1; Path=/"); + TEST_ASSERT_EQUAL(16, (int)client._cookies.size()); + + TEST_ASSERT_TRUE(hasCookieNamed(client, "c1")); + TEST_ASSERT_FALSE(hasCookieNamed(client, "c2")); + TEST_ASSERT_TRUE(hasCookieNamed(client, "c3")); + TEST_ASSERT_TRUE(hasCookieNamed(client, "new")); +} + int runUnityTests() { UNITY_BEGIN(); RUN_TEST(test_domain_matching_subdomains); @@ -92,6 +155,8 @@ int runUnityTests() { RUN_TEST(test_clear_and_public_set_cookie_api); RUN_TEST(test_rejects_mismatched_domain_attribute); RUN_TEST(test_cookie_path_matching_rfc6265_rule); + RUN_TEST(test_expires_and_max_age_enforcement); + RUN_TEST(test_cookie_jar_eviction_is_lru_session_then_scope); return UNITY_END(); } diff --git a/test/test_keep_alive/test_main.cpp b/test/test_keep_alive/test_main.cpp index 71651e4..5310eb4 100644 --- a/test/test_keep_alive/test_main.cpp +++ b/test/test_keep_alive/test_main.cpp @@ -77,7 +77,7 @@ static void test_pools_connection_on_complete_body() { client.setKeepAlive(true, 3000); auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); ctx->requestKeepAlive = true; ctx->resolvedTlsConfig = client.getDefaultTlsConfig(); ctx->transport = new MockTransport(false); @@ -103,7 +103,7 @@ static void test_does_not_pool_on_truncated_body() { client.setKeepAlive(true, 3000); auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); ctx->requestKeepAlive = true; ctx->resolvedTlsConfig = client.getDefaultTlsConfig(); ctx->transport = new MockTransport(false); @@ -131,7 +131,7 @@ static void test_reuses_pooled_connection() { // Seed pool with one connection auto poolCtx = new AsyncHttpClient::RequestContext(); - poolCtx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + poolCtx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); poolCtx->requestKeepAlive = true; poolCtx->resolvedTlsConfig = client.getDefaultTlsConfig(); poolCtx->transport = new MockTransport(false); @@ -145,7 +145,7 @@ static void test_reuses_pooled_connection() { MockTransport* pooled = static_cast(client._idleConnections[0].transport); auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); ctx->request->setHeader("Connection", "keep-alive"); ctx->response = new AsyncHttpResponse(); ctx->onSuccess = [](AsyncHttpResponse* resp) { TEST_ASSERT_EQUAL(200, resp->getStatusCode()); }; diff --git a/test/test_redirects/test_main.cpp b/test/test_redirects/test_main.cpp index a43adfd..c41fcee 100644 --- a/test/test_redirects/test_main.cpp +++ b/test/test_redirects/test_main.cpp @@ -32,7 +32,7 @@ static void cleanupContext(AsyncHttpClient::RequestContext* ctx) { static void test_redirect_same_host_get() { AsyncHttpClient client; client.setFollowRedirects(true, 3); - auto ctx = makeRedirectContext(HTTP_POST, "http://example.com/path"); + auto ctx = makeRedirectContext(HTTP_METHOD_POST, "http://example.com/path"); ctx->request->setHeader("Authorization", "Bearer token"); ctx->request->setHeader("Content-Type", "text/plain"); ctx->request->setBody("payload"); @@ -47,7 +47,7 @@ static void test_redirect_same_host_get() { TEST_ASSERT_TRUE(decision); TEST_ASSERT_NOT_NULL(newReq); - TEST_ASSERT_EQUAL(HTTP_GET, newReq->getMethod()); + TEST_ASSERT_EQUAL(HTTP_METHOD_GET, newReq->getMethod()); TEST_ASSERT_TRUE(newReq->getBody().isEmpty()); TEST_ASSERT_EQUAL_STRING("Bearer token", newReq->getHeader("Authorization").c_str()); TEST_ASSERT_TRUE(newReq->getHeader("Content-Type").isEmpty()); @@ -59,7 +59,7 @@ static void test_redirect_same_host_get() { static void test_redirect_cross_host_preserve_method_strip_auth() { AsyncHttpClient client; client.setFollowRedirects(true, 3); - auto ctx = makeRedirectContext(HTTP_POST, "http://example.com/login"); + auto ctx = makeRedirectContext(HTTP_METHOD_POST, "http://example.com/login"); ctx->request->setHeader("Authorization", "Bearer token"); ctx->request->setHeader("Proxy-Authorization", "Basic abc"); ctx->request->setHeader("Content-Type", "application/json"); @@ -75,7 +75,7 @@ static void test_redirect_cross_host_preserve_method_strip_auth() { TEST_ASSERT_TRUE(decision); TEST_ASSERT_NOT_NULL(newReq); - TEST_ASSERT_EQUAL(HTTP_POST, newReq->getMethod()); + TEST_ASSERT_EQUAL(HTTP_METHOD_POST, newReq->getMethod()); TEST_ASSERT_EQUAL_STRING("{\"name\":\"demo\"}", newReq->getBody().c_str()); TEST_ASSERT_TRUE(newReq->getHeader("Authorization").isEmpty()); TEST_ASSERT_TRUE(newReq->getHeader("Proxy-Authorization").isEmpty()); @@ -88,7 +88,7 @@ static void test_redirect_cross_host_preserve_method_strip_auth() { static void test_redirect_too_many_hops() { AsyncHttpClient client; client.setFollowRedirects(true, 2); - auto ctx = makeRedirectContext(HTTP_GET, "http://example.com/a"); + auto ctx = makeRedirectContext(HTTP_METHOD_GET, "http://example.com/a"); ctx->redirectCount = 2; ctx->response->setStatusCode(302); ctx->response->setHeader("Location", "/b"); @@ -109,7 +109,7 @@ static void test_redirect_too_many_hops() { static void test_redirect_to_https_supported() { AsyncHttpClient client; client.setFollowRedirects(true, 3); - auto ctx = makeRedirectContext(HTTP_GET, "http://example.com/path"); + auto ctx = makeRedirectContext(HTTP_METHOD_GET, "http://example.com/path"); ctx->response->setStatusCode(301); ctx->response->setHeader("Location", "https://secure.example.com/next"); @@ -142,7 +142,7 @@ static void test_header_limit_triggers_error() { AsyncHttpClient client; client.setMaxHeaderBytes(32); auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); ctx->response = new AsyncHttpResponse(); ctx->onError = [](HttpClientError error, const char* message) { (void)message; @@ -166,7 +166,7 @@ static void test_header_limit_allows_body_bytes_after_headers() { AsyncHttpClient client; client.setMaxHeaderBytes(48); auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/"); ctx->response = new AsyncHttpResponse(); ctx->onError = [](HttpClientError error, const char* message) { (void)message; @@ -189,13 +189,13 @@ static void test_header_limit_allows_body_bytes_after_headers() { static void test_cookie_roundtrip_basic() { AsyncHttpClient client; auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/login"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/login"); ctx->response = new AsyncHttpResponse(); String frame = "HTTP/1.1 200 OK\r\nSet-Cookie: session=abc123; Path=/\r\nContent-Length: 0\r\n\r\n"; TEST_ASSERT_TRUE(client.parseResponseHeaders(ctx, frame)); - AsyncHttpRequest follow(HTTP_GET, "http://example.com/home"); + AsyncHttpRequest follow(HTTP_METHOD_GET, "http://example.com/home"); client.applyCookies(&follow); TEST_ASSERT_EQUAL_STRING("session=abc123", follow.getHeader("Cookie").c_str()); @@ -205,21 +205,21 @@ static void test_cookie_roundtrip_basic() { static void test_cookie_path_and_secure_rules() { AsyncHttpClient client; auto ctx = new AsyncHttpClient::RequestContext(); - ctx->request = new AsyncHttpRequest(HTTP_GET, "http://example.com/login"); + ctx->request = new AsyncHttpRequest(HTTP_METHOD_GET, "http://example.com/login"); ctx->response = new AsyncHttpResponse(); String frame = "HTTP/1.1 200 OK\r\nSet-Cookie: admin=1; Path=/admin; Secure\r\nContent-Length: 0\r\n\r\n"; TEST_ASSERT_TRUE(client.parseResponseHeaders(ctx, frame)); - AsyncHttpRequest wrongPath(HTTP_GET, "http://example.com/public"); + AsyncHttpRequest wrongPath(HTTP_METHOD_GET, "http://example.com/public"); client.applyCookies(&wrongPath); TEST_ASSERT_TRUE(wrongPath.getHeader("Cookie").isEmpty()); - AsyncHttpRequest insecureTarget(HTTP_GET, "http://example.com/admin/dashboard"); + AsyncHttpRequest insecureTarget(HTTP_METHOD_GET, "http://example.com/admin/dashboard"); client.applyCookies(&insecureTarget); TEST_ASSERT_TRUE(insecureTarget.getHeader("Cookie").isEmpty()); - AsyncHttpRequest secureTarget(HTTP_GET, "https://example.com/admin/dashboard"); + AsyncHttpRequest secureTarget(HTTP_METHOD_GET, "https://example.com/admin/dashboard"); client.applyCookies(&secureTarget); TEST_ASSERT_EQUAL_STRING("admin=1", secureTarget.getHeader("Cookie").c_str()); diff --git a/test/test_urlparser_native/test_main.cpp b/test/test_urlparser_native/test_main.cpp new file mode 100644 index 0000000..424f364 --- /dev/null +++ b/test/test_urlparser_native/test_main.cpp @@ -0,0 +1,117 @@ +#include + +#include +#include + +#include "UrlParser.h" + +#include "../../url_test_cases.h" + +static void test_parse_url_shared_cases() { +#define X(url, expHost, expPath, expPort, expSecure, expImplicit) \ + do { \ + UrlParser::ParsedUrl parsed; \ + TEST_ASSERT_TRUE_MESSAGE(UrlParser::parse(url, parsed), url); \ + TEST_ASSERT_EQUAL_STRING(expHost, parsed.host.c_str()); \ + TEST_ASSERT_EQUAL_STRING(expPath, parsed.path.c_str()); \ + TEST_ASSERT_EQUAL_UINT16(expPort, parsed.port); \ + TEST_ASSERT_EQUAL(expSecure, parsed.secure); \ + TEST_ASSERT_EQUAL(expImplicit, parsed.schemeImplicit); \ + } while (0); + URL_TEST_CASES +#undef X +} + +static void test_rejects_urls_with_control_chars_and_whitespace() { + struct Case { + const char* name; + std::string url; + }; + + const std::vector cases = { + {"space", "http://example.com/pa th"}, + {"tab", std::string("http://example.com/") + std::string(1, '\t')}, + {"newline", std::string("http://example.com/") + std::string(1, '\n')}, + {"carriage_return", std::string("http://example.com/") + std::string(1, '\r')}, + {"vertical_tab", std::string("http://example.com/") + std::string(1, '\v')}, + {"form_feed", std::string("http://example.com/") + std::string(1, '\f')}, + {"esc", std::string("http://example.com/") + std::string(1, static_cast(0x1B))}, + {"del", std::string("http://example.com/") + std::string(1, static_cast(0x7F))}, + {"nul", std::string("http://exam") + std::string(1, '\0') + "ple.com/"}, + }; + + for (const auto& tc : cases) { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_FALSE_MESSAGE(UrlParser::parse(tc.url, parsed), tc.name); + } +} + +static void test_rejects_hosts_with_invalid_characters() { + const std::vector urls = { + "http://exa_mple.com/", + "http://example!.com/", + "http://examp[le.com/", + "http://example.com@evil.com/", + "http://.example.com/", + "http://example.com./", + "http:///", + "http://:80/", + }; + + for (const auto& url : urls) { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_FALSE_MESSAGE(UrlParser::parse(url, parsed), url.c_str()); + } + + const std::string longHost(256, 'a'); + { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_FALSE(UrlParser::parse(std::string("http://") + longHost + "/", parsed)); + } +} + +static void test_port_boundaries_and_invalid_ports() { + { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_TRUE(UrlParser::parse("http://example.com:0/", parsed)); + TEST_ASSERT_EQUAL_UINT16(0, parsed.port); + } + { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_TRUE(UrlParser::parse("http://example.com:65535/", parsed)); + TEST_ASSERT_EQUAL_UINT16(65535, parsed.port); + } + { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_FALSE(UrlParser::parse("http://example.com:65536/", parsed)); + } + + const std::vector invalidPorts = { + "http://example.com:/", + "http://example.com:", + "http://example.com:?q=1", + "http://example.com:-1/", + "http://example.com:+1/", + "http://example.com:12a3/", + "http://example.com:18446744073709551616/", + "http://example.com:9999999999999999999999999999999999999999/", + "http://example.com::80/", + }; + + for (const auto& url : invalidPorts) { + UrlParser::ParsedUrl parsed; + TEST_ASSERT_FALSE_MESSAGE(UrlParser::parse(url, parsed), url.c_str()); + } +} + +int main(int argc, char** argv) { + (void)argc; + (void)argv; + + UNITY_BEGIN(); + RUN_TEST(test_parse_url_shared_cases); + RUN_TEST(test_rejects_urls_with_control_chars_and_whitespace); + RUN_TEST(test_rejects_hosts_with_invalid_characters); + RUN_TEST(test_port_boundaries_and_invalid_ports); + return UNITY_END(); +} diff --git a/test_parse_url.py b/test_parse_url.py index 854b736..e887699 100644 --- a/test_parse_url.py +++ b/test_parse_url.py @@ -7,6 +7,7 @@ def parse_url(url: str): url_copy = url secure = False port = 80 + scheme_implicit = False if url_copy.startswith("https://"): secure = True port = 443 @@ -16,8 +17,9 @@ def parse_url(url: str): port = 80 url_copy = url_copy[7:] else: - secure = False - port = 80 + secure = True + port = 443 + scheme_implicit = True path_index = url_copy.find('/') query_index = url_copy.find('?') @@ -41,7 +43,7 @@ def parse_url(url: str): port = int(host[port_index + 1:]) host = host[:port_index] - return host, path, port, secure + return host, path, port, secure, scheme_implicit def load_cases(): @@ -49,15 +51,15 @@ def load_cases(): if not header.exists(): return [] content = header.read_text() - pattern = re.compile(r'X\("([^"\\]+)","([^"\\]+)","([^"\\]+)",(\d+),(true|false)\)') + pattern = re.compile(r'X\("([^"\\]+)","([^"\\]+)","([^"\\]+)",(\d+),(true|false),(true|false)\)') cases = [] for m in pattern.finditer(content): - url, host, path, port, secure = m.groups() - cases.append((url, host, path, int(port), secure == 'true')) + url, host, path, port, secure, implicit = m.groups() + cases.append((url, host, path, int(port), secure == 'true', implicit == 'true')) return cases -@pytest.mark.parametrize("url,exp_host,exp_path,exp_port,exp_secure", load_cases()) -def test_parse_url_table(url, exp_host, exp_path, exp_port, exp_secure): - host, path, port, secure = parse_url(url) - assert (host, path, port, secure) == (exp_host, exp_path, exp_port, exp_secure) +@pytest.mark.parametrize("url,exp_host,exp_path,exp_port,exp_secure,exp_implicit", load_cases()) +def test_parse_url_table(url, exp_host, exp_path, exp_port, exp_secure, exp_implicit): + host, path, port, secure, implicit = parse_url(url) + assert (host, path, port, secure, implicit) == (exp_host, exp_path, exp_port, exp_secure, exp_implicit) diff --git a/url_test_cases.h b/url_test_cases.h index 1d9d6e9..9102cfa 100644 --- a/url_test_cases.h +++ b/url_test_cases.h @@ -1,13 +1,13 @@ #ifndef URL_TEST_CASES_H #define URL_TEST_CASES_H // Shared URL test cases between Python and C++ tests. -// Format macro: X(url, host, path, port, secureBool) +// Format macro: X(url, host, path, port, secureBool, schemeImplicit) #define URL_TEST_CASES \ - X("http://example.com?foo=bar", "example.com", "/?foo=bar", 80, false) \ - X("https://example.com/path?foo=bar", "example.com", "/path?foo=bar", 443, true) \ - X("http://example.com", "example.com", "/", 80, false) \ - X("http://example.com:8080/api", "example.com", "/api", 8080, false) \ - X("https://example.com:4443/", "example.com", "/", 4443, true) \ - X("example.com", "example.com", "/", 80, false) \ - X("example.com?x=1", "example.com", "/?x=1", 80, false) + X("http://example.com?foo=bar", "example.com", "/?foo=bar", 80, false, false) \ + X("https://example.com/path?foo=bar", "example.com", "/path?foo=bar", 443, true, false) \ + X("http://example.com", "example.com", "/", 80, false, false) \ + X("http://example.com:8080/api", "example.com", "/api", 8080, false, false) \ + X("https://example.com:4443/", "example.com", "/", 4443, true, false) \ + X("example.com", "example.com", "/", 443, true, true) \ + X("example.com?x=1", "example.com", "/?x=1", 443, true, true) #endif // URL_TEST_CASES_H