Compare commits

..

14 Commits

Author SHA1 Message Date
Yuji Hirose
048f31109f Updated README 2019-12-10 13:14:23 -05:00
Yuji Hirose
d064fb7ff2 Fixed warning 2019-12-10 13:08:07 -05:00
Yuji Hirose
3c2736bb2a Fixed regex syntax error 2019-12-10 13:07:49 -05:00
Yuji Hirose
fd4e1b4112 Fix #266 2019-12-10 12:10:14 -05:00
yhirose
f6a2365ca5 Fix #282 2019-12-06 12:21:15 -05:00
yhirose
df1ff7510b Made code more readable 2019-12-06 12:02:08 -05:00
yhirose
379905bd34 Merge branch 'whitespace-and-libcxx-compat' of https://github.com/matvore/cpp-httplib 2019-12-06 09:51:21 -05:00
yhirose
66719ae3d4 Merge pull request #283 from barryam3/noexcept
Remove use of exceptions.
2019-12-05 21:32:06 -05:00
Matthew DeVore
bc9251ea49 Work around incompatibility in <regex> in libc++
libc++ (the implementation of the C++ standard library usually used by
Clang) throws an exception for the regex used by parse_headers before
this patch for certain strings. Work around this by simplifying the
regex and parsing the header lines "by hand" partially. I have repro'd
this problem with Xcode 11.1 which I believe uses libc++ version 8.

This may be a bug in libc++ as I can't see why the regex would result in
asymptotic run-time complexity for any strings. However, it may take a
while for libc++ to be fixed and for everyone to migrate to it, so it
makes sense to work around it in this codebase for now.
2019-12-05 17:14:16 -08:00
Matthew DeVore
a9e942d755 Properly trim whitespace from headers
HTTP Whitespace and regex whitespace are not the same, so we can't use
\s in regexes when parsing HTTP headers. Instead, explicitly specify
what is considered whitespace in the regex.
2019-12-05 17:14:16 -08:00
Barry McNamara
e1785d6723 Remove use of exceptions. 2019-12-05 15:56:55 -08:00
yhirose
b9539b8921 Fixed build errors 2019-12-03 10:30:07 -05:00
yhirose
4c93b973ff Fixed typo in README 2019-12-02 09:50:52 -05:00
yhirose
033bc35723 Improve multipart content reader interface 2019-12-02 07:11:12 -05:00
6 changed files with 437 additions and 99 deletions

View File

@@ -121,14 +121,14 @@ svr.Post("/content_receiver",
if (req.is_multipart_form_data()) { if (req.is_multipart_form_data()) {
MultipartFiles files; MultipartFiles files;
content_reader( content_reader(
[&](const std::string &name, const MultipartFile &file) {
files.emplace(name, file);
return true;
},
[&](const std::string &name, const char *data, size_t data_length) { [&](const std::string &name, const char *data, size_t data_length) {
auto &file = files.find(name)->second; auto &file = files.find(name)->second;
file.content.append(data, data_length); file.content.append(data, data_length);
return true; return true;
},
[&](const std::string &name, const MultipartFile &file) {
files.emplace(name, file);
return true;
}); });
} else { } else {
std::string body; std::string body;
@@ -156,7 +156,7 @@ svr.Get("/chunked", [&](const Request& req, Response& res) {
}); });
``` ```
### Default thread pool supporet ### Default thread pool support
Set thread count to 8: Set thread count to 8:
@@ -324,16 +324,23 @@ std::shared_ptr<httplib::Response> res =
This feature was contributed by [underscorediscovery](https://github.com/yhirose/cpp-httplib/pull/23). This feature was contributed by [underscorediscovery](https://github.com/yhirose/cpp-httplib/pull/23).
### Basic Authentication ### Authentication
NOTE: OpenSSL is required for Digest Authentication, since cpp-httplib uses message digest functions in OpenSSL.
```cpp ```cpp
httplib::Client cli("httplib.org"); httplib::Client cli("httplib.org");
cli.set_auth("user", "pass");
auto res = cli.Get("/basic-auth/hello/world", { // Basic
httplib::make_basic_authentication_header("hello", "world") auto res = cli.Get("/basic-auth/user/pass");
});
// res->status should be 200 // res->status should be 200
// res->body should be "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n". // res->body should be "{\n \"authenticated\": true, \n \"user\": \"user\"\n}\n".
// Digest
res = cli.Get("/digest-auth/auth/user/pass/SHA-256");
// res->status should be 200
// res->body should be "{\n \"authenticated\": true, \n \"user\": \"user\"\n}\n".
``` ```
### Range ### Range

View File

@@ -33,4 +33,4 @@ pem:
openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem openssl req -new -key key.pem | openssl x509 -days 3650 -req -signkey key.pem > cert.pem
clean: clean:
rm server client hello simplesvr upload redirect *.pem rm server client hello simplesvr upload redirect benchmark *.pem

View File

@@ -46,10 +46,7 @@ string dump_multipart_files(const MultipartFiles &files) {
snprintf(buf, sizeof(buf), "content type: %s\n", file.content_type.c_str()); snprintf(buf, sizeof(buf), "content type: %s\n", file.content_type.c_str());
s += buf; s += buf;
snprintf(buf, sizeof(buf), "text offset: %lu\n", file.offset); snprintf(buf, sizeof(buf), "text length: %lu\n", file.content.size());
s += buf;
snprintf(buf, sizeof(buf), "text length: %lu\n", file.length);
s += buf; s += buf;
s += "----------------\n"; s += "----------------\n";

View File

@@ -37,10 +37,10 @@ int main(void) {
svr.Post("/post", [](const Request & req, Response &res) { svr.Post("/post", [](const Request & req, Response &res) {
auto file = req.get_file_value("file"); auto file = req.get_file_value("file");
cout << "file: " << file.offset << ":" << file.length << ":" << file.filename << endl; cout << "file length: " << file.content.length() << ":" << file.filename << endl;
ofstream ofs(file.filename, ios::binary); ofstream ofs(file.filename, ios::binary);
ofs << req.body.substr(file.offset, file.length); ofs << file.content;
res.set_content("done", "text/plain"); res.set_content("done", "text/plain");
}); });

363
httplib.h
View File

@@ -149,9 +149,13 @@ using socket_t = int;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
#include <openssl/err.h> #include <openssl/err.h>
#include <openssl/md5.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/x509v3.h> #include <openssl/x509v3.h>
#include <iomanip>
#include <sstream>
// #if OPENSSL_VERSION_NUMBER < 0x1010100fL // #if OPENSSL_VERSION_NUMBER < 0x1010100fL
// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported // #error Sorry, OpenSSL versions prior to 1.1.1 are not supported
// #endif // #endif
@@ -225,22 +229,22 @@ using MultipartFormDataItems = std::vector<MultipartFormData>;
using ContentReceiver = using ContentReceiver =
std::function<bool(const char *data, size_t data_length)>; std::function<bool(const char *data, size_t data_length)>;
using MultipartContentReceiver =
std::function<bool(const std::string& name, const char *data, size_t data_length)>;
using MultipartContentHeader = using MultipartContentHeader =
std::function<bool(const std::string &name, const MultipartFile &file)>; std::function<bool(const std::string &name, const MultipartFile &file)>;
using MultipartContentReceiver =
std::function<bool(const std::string& name, const char *data, size_t data_length)>;
class ContentReader { class ContentReader {
public: public:
using Reader = std::function<bool(ContentReceiver receiver)>; using Reader = std::function<bool(ContentReceiver receiver)>;
using MultipartReader = std::function<bool(MultipartContentReceiver receiver, MultipartContentHeader header)>; using MultipartReader = std::function<bool(MultipartContentHeader header, MultipartContentReceiver receiver)>;
ContentReader(Reader reader, MultipartReader muitlpart_reader) ContentReader(Reader reader, MultipartReader muitlpart_reader)
: reader_(reader), muitlpart_reader_(muitlpart_reader) {} : reader_(reader), muitlpart_reader_(muitlpart_reader) {}
bool operator()(MultipartContentReceiver receiver, MultipartContentHeader header) const { bool operator()(MultipartContentHeader header, MultipartContentReceiver receiver) const {
return muitlpart_reader_(receiver, header); return muitlpart_reader_(header, receiver);
} }
bool operator()(ContentReceiver receiver) const { bool operator()(ContentReceiver receiver) const {
@@ -591,13 +595,13 @@ private:
bool read_content_with_content_receiver(Stream &strm, bool last_connection, bool read_content_with_content_receiver(Stream &strm, bool last_connection,
Request &req, Response &res, Request &req, Response &res,
ContentReceiver receiver, ContentReceiver receiver,
MultipartContentReceiver multipart_receiver, MultipartContentHeader multipart_header,
MultipartContentHeader multipart_header); MultipartContentReceiver multipart_receiver);
bool read_content_core(Stream &strm, bool last_connection, bool read_content_core(Stream &strm, bool last_connection,
Request &req, Response &res, Request &req, Response &res,
ContentReceiver receiver, ContentReceiver receiver,
MultipartContentReceiver multipart_receiver, MultipartContentHeader mulitpart_header,
MultipartContentHeader mulitpart_header); MultipartContentReceiver multipart_receiver);
virtual bool process_and_close_socket(socket_t sock); virtual bool process_and_close_socket(socket_t sock);
@@ -756,10 +760,13 @@ public:
std::vector<Response> &responses); std::vector<Response> &responses);
void set_keep_alive_max_count(size_t count); void set_keep_alive_max_count(size_t count);
void set_read_timeout(time_t sec, time_t usec); void set_read_timeout(time_t sec, time_t usec);
void follow_location(bool on); void follow_location(bool on);
void set_auth(const char *username, const char *password);
protected: protected:
bool process_request(Stream &strm, const Request &req, Response &res, bool process_request(Stream &strm, const Request &req, Response &res,
bool last_connection, bool &connection_close); bool last_connection, bool &connection_close);
@@ -772,6 +779,8 @@ protected:
time_t read_timeout_sec_; time_t read_timeout_sec_;
time_t read_timeout_usec_; time_t read_timeout_usec_;
size_t follow_location_; size_t follow_location_;
std::string username_;
std::string password_;
private: private:
socket_t create_client_socket() const; socket_t create_client_socket() const;
@@ -1114,6 +1123,11 @@ public:
} }
} }
bool end_with_crlf() const {
auto end = ptr() + size();
return size() >= 2 && end[-2] == '\r' && end[-1] == '\n';
}
bool getline() { bool getline() {
fixed_buffer_used_size_ = 0; fixed_buffer_used_size_ = 0;
glowable_buffer_.clear(); glowable_buffer_.clear();
@@ -1357,6 +1371,26 @@ inline bool is_connection_error() {
#endif #endif
} }
inline socket_t create_client_socket(
const char *host, int port, time_t timeout_sec) {
return create_socket(
host, port, [=](socket_t sock, struct addrinfo &ai) -> bool {
set_nonblocking(sock, true);
auto ret = ::connect(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen));
if (ret < 0) {
if (is_connection_error() ||
!wait_until_socket_is_ready(sock, timeout_sec, 0)) {
close_socket(sock);
return false;
}
}
set_nonblocking(sock, false);
return true;
});
}
inline std::string get_remote_addr(socket_t sock) { inline std::string get_remote_addr(socket_t sock) {
struct sockaddr_storage addr; struct sockaddr_storage addr;
socklen_t len = sizeof(addr); socklen_t len = sizeof(addr);
@@ -1414,6 +1448,7 @@ inline const char *status_message(int status) {
case 303: return "See Other"; case 303: return "See Other";
case 304: return "Not Modified"; case 304: return "Not Modified";
case 400: return "Bad Request"; case 400: return "Bad Request";
case 401: return "Unauthorized";
case 403: return "Forbidden"; case 403: return "Forbidden";
case 404: return "Not Found"; case 404: return "Not Found";
case 413: return "Payload Too Large"; case 413: return "Payload Too Large";
@@ -1542,18 +1577,35 @@ inline uint64_t get_header_value_uint64(const Headers &headers, const char *key,
} }
inline bool read_headers(Stream &strm, Headers &headers) { inline bool read_headers(Stream &strm, Headers &headers) {
static std::regex re(R"((.+?):\s*(.+?)\s*\r\n)");
const auto bufsiz = 2048; const auto bufsiz = 2048;
char buf[bufsiz]; char buf[bufsiz];
stream_line_reader line_reader(strm, buf, bufsiz); stream_line_reader line_reader(strm, buf, bufsiz);
for (;;) { for (;;) {
if (!line_reader.getline()) { return false; } if (!line_reader.getline()) { return false; }
if (!strcmp(line_reader.ptr(), "\r\n")) { break; }
// Check if the line ends with CRLF.
if (line_reader.end_with_crlf()) {
// Blank line indicates end of headers.
if (line_reader.size() == 2) { break; }
} else {
continue; // Skip invalid line.
}
// Skip trailing spaces and tabs.
auto end = line_reader.ptr() + line_reader.size() - 2;
while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) {
end--;
}
// Horizontal tab and ' ' are considered whitespace and are ignored when on
// the left or right side of the header value:
// - https://stackoverflow.com/questions/50179659/
// - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
static const std::regex re(R"((.+?):[\t ]*(.+))");
std::cmatch m; std::cmatch m;
if (std::regex_match(line_reader.ptr(), m, re)) { if (std::regex_match(line_reader.ptr(), end, m, re)) {
auto key = std::string(m[1]); auto key = std::string(m[1]);
auto val = std::string(m[2]); auto val = std::string(m[2]);
headers.emplace(key, val); headers.emplace(key, val);
@@ -1881,38 +1933,39 @@ inline bool parse_multipart_boundary(const std::string &content_type,
} }
inline bool parse_range_header(const std::string &s, Ranges &ranges) { inline bool parse_range_header(const std::string &s, Ranges &ranges) {
try { static auto re_first_range =
static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))");
std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); std::smatch m;
std::smatch m; if (std::regex_match(s, m, re_first_range)) {
if (std::regex_match(s, m, re_first_range)) { auto pos = m.position(1);
auto pos = m.position(1); auto len = m.length(1);
auto len = m.length(1); bool all_valid_ranges = true;
detail::split( detail::split(
&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { &s[pos], &s[pos + len], ',', [&](const char *b, const char *e) {
static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); if (!all_valid_ranges) return;
std::cmatch m; static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))");
if (std::regex_match(b, e, m, re_another_range)) { std::cmatch m;
ssize_t first = -1; if (std::regex_match(b, e, m, re_another_range)) {
if (!m.str(1).empty()) { ssize_t first = -1;
first = static_cast<ssize_t>(std::stoll(m.str(1))); if (!m.str(1).empty()) {
} first = static_cast<ssize_t>(std::stoll(m.str(1)));
ssize_t last = -1;
if (!m.str(2).empty()) {
last = static_cast<ssize_t>(std::stoll(m.str(2)));
}
if (first != -1 && last != -1 && first > last) {
throw std::runtime_error("invalid range error");
}
ranges.emplace_back(std::make_pair(first, last));
} }
});
return true; ssize_t last = -1;
} if (!m.str(2).empty()) {
return false; last = static_cast<ssize_t>(std::stoll(m.str(2)));
} catch (...) { return false; } }
if (first != -1 && last != -1 && first > last) {
all_valid_ranges = false;
return;
}
ranges.emplace_back(std::make_pair(first, last));
}
});
return all_valid_ranges;
}
return false;
} }
class MultipartFormDataParser { class MultipartFormDataParser {
@@ -2035,17 +2088,15 @@ public:
break; break;
} }
case 4: { // Boundary case 4: { // Boundary
auto pos = buf_.find(crlf_);
if (crlf_.size() > buf_.size()) { return true; } if (crlf_.size() > buf_.size()) { return true; }
if (pos == 0) { if (buf_.find(crlf_) == 0) {
buf_.erase(0, crlf_.size()); buf_.erase(0, crlf_.size());
off_ += crlf_.size(); off_ += crlf_.size();
state_ = 1; state_ = 1;
} else { } else {
auto pattern = dash_ + crlf_; auto pattern = dash_ + crlf_;
if (pattern.size() > buf_.size()) { return true; } if (pattern.size() > buf_.size()) { return true; }
auto pos = buf_.find(pattern); if (buf_.find(pattern) == 0) {
if (pos == 0) {
buf_.erase(0, pattern.size()); buf_.erase(0, pattern.size());
off_ += pattern.size(); off_ += pattern.size();
is_valid_ = true; is_valid_ = true;
@@ -2246,6 +2297,43 @@ inline bool expect_content(const Request &req) {
return false; return false;
} }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
template <typename CTX, typename Init, typename Update, typename Final>
inline std::string message_digest(const std::string &s, Init init,
Update update, Final final,
size_t digest_length) {
using namespace std;
unsigned char md[digest_length];
CTX ctx;
init(&ctx);
update(&ctx, s.data(), s.size());
final(md, &ctx);
stringstream ss;
for (auto c : md) {
ss << setfill('0') << setw(2) << hex << (unsigned int)c;
}
return ss.str();
}
inline std::string MD5(const std::string &s) {
using namespace detail;
return message_digest<MD5_CTX>(s, MD5_Init, MD5_Update, MD5_Final,
MD5_DIGEST_LENGTH);
}
inline std::string SHA_256(const std::string &s) {
return message_digest<SHA256_CTX>(s, SHA256_Init, SHA256_Update, SHA256_Final,
SHA256_DIGEST_LENGTH);
}
inline std::string SHA_512(const std::string &s) {
return message_digest<SHA512_CTX>(s, SHA512_Init, SHA512_Update, SHA512_Final,
SHA512_DIGEST_LENGTH);
}
#endif
#ifdef _WIN32 #ifdef _WIN32
class WSInit { class WSInit {
public: public:
@@ -2283,6 +2371,98 @@ make_basic_authentication_header(const std::string &username,
return std::make_pair("Authorization", field); return std::make_pair("Authorization", field);
} }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
inline std::pair<std::string, std::string> make_digest_authentication_header(
const Request &req,
const std::map<std::string, std::string> &auth,
size_t cnonce_count, const std::string &cnonce,
const std::string &username, const std::string &password) {
using namespace std;
string nc;
{
stringstream ss;
ss << setfill('0') << setw(8) << hex << cnonce_count;
nc = ss.str();
}
auto qop = auth.at("qop");
if (qop.find("auth-int") != std::string::npos) {
qop = "auth-int";
} else {
qop = "auth";
}
string response;
{
auto algo = auth.at("algorithm");
auto H = algo == "SHA-256"
? detail::SHA_256
: algo == "SHA-512" ? detail::SHA_512 : detail::MD5;
auto A1 = username + ":" + auth.at("realm") + ":" + password;
auto A2 = req.method + ":" + req.path;
if (qop == "auth-int") {
A2 += ":" + H(req.body);
}
response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce +
":" + qop + ":" + H(A2));
}
auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") +
"\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path +
"\", algorithm=" + auth.at("algorithm") + ", qop=" + qop + ", nc=\"" +
nc + "\", cnonce=\"" + cnonce + "\", response=\"" + response +
"\"";
return make_pair("Authorization", field);
}
#endif
inline int parse_www_authenticate(const httplib::Response &res,
std::map<std::string, std::string> &digest_auth) {
if (res.has_header("WWW-Authenticate")) {
static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~");
auto s = res.get_header_value("WWW-Authenticate");
auto pos = s.find(' ');
if (pos != std::string::npos) {
auto type = s.substr(0, pos);
if (type == "Basic") {
return 1;
} else if (type == "Digest") {
s = s.substr(pos + 1);
auto beg = std::sregex_iterator(s.begin(), s.end(), re);
for (auto i = beg; i != std::sregex_iterator(); ++i) {
auto m = *i;
auto key = s.substr(m.position(1), m.length(1));
auto val = m.length(2) > 0 ? s.substr(m.position(2), m.length(2))
: s.substr(m.position(3), m.length(3));
digest_auth[key] = val;
}
return 2;
}
}
}
return 0;
}
// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240
inline std::string random_string(size_t length) {
auto randchar = []() -> char {
const char charset[] = "0123456789"
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz";
const size_t max_index = (sizeof(charset) - 1);
return charset[rand() % max_index];
};
std::string str(length, 0);
std::generate_n(str.begin(), length, randchar);
return str;
}
// Request implementation // Request implementation
inline bool Request::has_header(const char *key) const { inline bool Request::has_header(const char *key) const {
return detail::has_header(headers, key); return detail::has_header(headers, key);
@@ -2796,11 +2976,17 @@ Server::write_content_with_provider(Stream &strm, const Request &req,
inline bool Server::read_content(Stream &strm, bool last_connection, inline bool Server::read_content(Stream &strm, bool last_connection,
Request &req, Response &res) { Request &req, Response &res) {
auto ret = read_content_core(strm, last_connection, req, res, auto ret = read_content_core(strm, last_connection, req, res,
// Regular
[&](const char *buf, size_t n) { [&](const char *buf, size_t n) {
if (req.body.size() + n > req.body.max_size()) { return false; } if (req.body.size() + n > req.body.max_size()) { return false; }
req.body.append(buf, n); req.body.append(buf, n);
return true; return true;
}, },
// Multipart
[&](const std::string &name, const MultipartFile &file) {
req.files.emplace(name, file);
return true;
},
[&](const std::string &name, const char *buf, size_t n) { [&](const std::string &name, const char *buf, size_t n) {
// TODO: handle elements with a same key // TODO: handle elements with a same key
auto it = req.files.find(name); auto it = req.files.find(name);
@@ -2808,10 +2994,6 @@ inline bool Server::read_content(Stream &strm, bool last_connection,
if (content.size() + n > content.max_size()) { return false; } if (content.size() + n > content.max_size()) { return false; }
content.append(buf, n); content.append(buf, n);
return true; return true;
},
[&](const std::string &name, const MultipartFile &file) {
req.files.emplace(name, file);
return true;
} }
); );
@@ -2827,18 +3009,18 @@ inline bool
Server::read_content_with_content_receiver(Stream &strm, bool last_connection, Server::read_content_with_content_receiver(Stream &strm, bool last_connection,
Request &req, Response &res, Request &req, Response &res,
ContentReceiver receiver, ContentReceiver receiver,
MultipartContentReceiver multipart_receiver, MultipartContentHeader multipart_header,
MultipartContentHeader multipart_header) { MultipartContentReceiver multipart_receiver) {
return read_content_core(strm, last_connection, req, res, return read_content_core(strm, last_connection, req, res,
receiver, multipart_receiver, multipart_header); receiver, multipart_header, multipart_receiver);
} }
inline bool inline bool
Server::read_content_core(Stream &strm, bool last_connection, Server::read_content_core(Stream &strm, bool last_connection,
Request &req, Response &res, Request &req, Response &res,
ContentReceiver receiver, ContentReceiver receiver,
MultipartContentReceiver multipart_receiver, MultipartContentHeader mulitpart_header,
MultipartContentHeader mulitpart_header) { MultipartContentReceiver multipart_receiver) {
detail::MultipartFormDataParser multipart_form_data_parser; detail::MultipartFormDataParser multipart_form_data_parser;
ContentReceiver out; ContentReceiver out;
@@ -3001,9 +3183,9 @@ inline bool Server::routing(Request &req, Response &res, Stream &strm,
return read_content_with_content_receiver(strm, last_connection, req, res, return read_content_with_content_receiver(strm, last_connection, req, res,
receiver, nullptr, nullptr); receiver, nullptr, nullptr);
}, },
[&](MultipartContentReceiver receiver, MultipartContentHeader header) { [&](MultipartContentHeader header, MultipartContentReceiver receiver) {
return read_content_with_content_receiver(strm, last_connection, req, res, return read_content_with_content_receiver(strm, last_connection, req, res,
nullptr, receiver, header); nullptr, header, receiver);
} }
); );
@@ -3164,22 +3346,7 @@ inline Client::~Client() {}
inline bool Client::is_valid() const { return true; } inline bool Client::is_valid() const { return true; }
inline socket_t Client::create_client_socket() const { inline socket_t Client::create_client_socket() const {
return detail::create_socket( return detail::create_client_socket(host_.c_str(), port_, timeout_sec_);
host_.c_str(), port_, [=](socket_t sock, struct addrinfo &ai) -> bool {
detail::set_nonblocking(sock, true);
auto ret = connect(sock, ai.ai_addr, static_cast<int>(ai.ai_addrlen));
if (ret < 0) {
if (detail::is_connection_error() ||
!detail::wait_until_socket_is_ready(sock, timeout_sec_, 0)) {
detail::close_socket(sock);
return false;
}
}
detail::set_nonblocking(sock, false);
return true;
});
} }
inline bool Client::read_response_line(Stream &strm, Response &res) { inline bool Client::read_response_line(Stream &strm, Response &res) {
@@ -3216,6 +3383,43 @@ inline bool Client::send(const Request &req, Response &res) {
ret = redirect(req, res); ret = redirect(req, res);
} }
if (ret && !username_.empty() && !password_.empty() && res.status == 401) {
int type;
std::map<std::string, std::string> digest_auth;
if ((type = parse_www_authenticate(res, digest_auth)) > 0) {
std::pair<std::string, std::string> header;
if (type == 1) {
header = make_basic_authentication_header(username_, password_);
} else if (type == 2) {
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
size_t cnonce_count = 1;
auto cnonce = random_string(10);
header = make_digest_authentication_header(
req, digest_auth, cnonce_count, cnonce, username_, password_);
#endif
}
Request new_req;
new_req.method = req.method;
new_req.path = req.path;
new_req.headers = req.headers;
new_req.body = req.body;
new_req.response_handler = req.response_handler;
new_req.content_receiver = req.content_receiver;
new_req.progress = req.progress;
new_req.headers.insert(header);
Response new_res;
auto ret = send(new_req, new_res);
if (ret) { res = new_res; }
return ret;
}
}
return ret; return ret;
} }
@@ -3782,6 +3986,11 @@ inline void Client::set_read_timeout(time_t sec, time_t usec) {
inline void Client::follow_location(bool on) { follow_location_ = on; } inline void Client::follow_location(bool on) { follow_location_ = on; }
inline void Client::set_auth(const char *username, const char *password) {
username_ = username;
password_ = password;
}
/* /*
* SSL Implementation * SSL Implementation
*/ */

View File

@@ -469,8 +469,50 @@ TEST(BaseAuthTest, FromHTTPWatch) {
"{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n"); "{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n");
EXPECT_EQ(200, res->status); EXPECT_EQ(200, res->status);
} }
{
cli.set_auth("hello", "world");
auto res = cli.Get("/basic-auth/hello/world");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(res->body,
"{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n");
EXPECT_EQ(200, res->status);
}
} }
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(DigestAuthTest, FromHTTPWatch) {
auto host = "httpbin.org";
auto port = 443;
httplib::SSLClient cli(host, port);
{
auto res = cli.Get("/digest-auth/auth/hello/world");
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(401, res->status);
}
{
std::vector<std::string> paths = {
"/digest-auth/auth/hello/world/MD5",
"/digest-auth/auth/hello/world/SHA-256",
"/digest-auth/auth/hello/world/SHA-512",
"/digest-auth/auth-init/hello/world/MD5",
"/digest-auth/auth-int/hello/world/MD5",
};
cli.set_auth("hello", "world");
for (auto path: paths) {
auto res = cli.Get(path.c_str());
ASSERT_TRUE(res != nullptr);
EXPECT_EQ(res->body,
"{\n \"authenticated\": true, \n \"user\": \"hello\"\n}\n");
EXPECT_EQ(200, res->status);
}
}
}
#endif
TEST(AbsoluteRedirectTest, Redirect) { TEST(AbsoluteRedirectTest, Redirect) {
auto host = "httpbin.org"; auto host = "httpbin.org";
@@ -761,14 +803,14 @@ protected:
if (req.is_multipart_form_data()) { if (req.is_multipart_form_data()) {
MultipartFiles files; MultipartFiles files;
content_reader( content_reader(
[&](const std::string &name, const MultipartFile &file) {
files.emplace(name, file);
return true;
},
[&](const std::string &name, const char *data, size_t data_length) { [&](const std::string &name, const char *data, size_t data_length) {
auto &file = files.find(name)->second; auto &file = files.find(name)->second;
file.content.append(data, data_length); file.content.append(data, data_length);
return true; return true;
},
[&](const std::string &name, const MultipartFile &file) {
files.emplace(name, file);
return true;
}); });
EXPECT_EQ(5u, files.size()); EXPECT_EQ(5u, files.size());
@@ -1357,7 +1399,7 @@ TEST_F(ServerTest, GetStreamedWithRangeMultipart) {
} }
TEST_F(ServerTest, GetStreamedEndless) { TEST_F(ServerTest, GetStreamedEndless) {
size_t offset = 0; uint64_t offset = 0;
auto res = cli_.Get("/streamed-cancel", auto res = cli_.Get("/streamed-cancel",
[&](const char * /*data*/, uint64_t data_length) { [&](const char * /*data*/, uint64_t data_length) {
if (offset < 100) { if (offset < 100) {
@@ -1766,6 +1808,89 @@ TEST_F(ServerTest, MultipartFormDataGzip) {
} }
#endif #endif
// Sends a raw request to a server listening at HOST:PORT.
static bool send_request(time_t read_timeout_sec, const std::string& req) {
auto client_sock =
detail::create_client_socket(HOST, PORT, /*timeout_sec=*/5);
if (client_sock == INVALID_SOCKET) { return false; }
return detail::process_and_close_socket(
true, client_sock, 1, read_timeout_sec, 0,
[&](Stream& strm, bool /*last_connection*/,
bool &/*connection_close*/) -> bool {
if (req.size() !=
static_cast<size_t>(strm.write(req.data(), req.size()))) {
return false;
}
char buf[512];
detail::stream_line_reader line_reader(strm, buf, sizeof(buf));
while (line_reader.getline()) {}
return true;
});
}
TEST(ServerRequestParsingTest, TrimWhitespaceFromHeaderValues) {
Server svr;
std::string header_value;
svr.Get("/validate-ws-in-headers",
[&](const Request &req, Response &res) {
header_value = req.get_header_value("foo");
res.set_content("ok", "text/plain");
});
thread t = thread([&] { svr.listen(HOST, PORT); });
while (!svr.is_running()) {
msleep(1);
}
// Only space and horizontal tab are whitespace. Make sure other whitespace-
// like characters are not treated the same - use vertical tab and escape.
const std::string req =
"GET /validate-ws-in-headers HTTP/1.1\r\n"
"foo: \t \v bar \e\t \r\n"
"Connection: close\r\n"
"\r\n";
ASSERT_TRUE(send_request(5, req));
svr.stop();
t.join();
EXPECT_EQ(header_value, "\v bar \e");
}
TEST(ServerRequestParsingTest, ReadHeadersRegexComplexity) {
Server svr;
svr.Get("/hi",
[&](const Request & /*req*/, Response &res) {
res.set_content("ok", "text/plain");
});
// Server read timeout must be longer than the client read timeout for the
// bug to reproduce, probably to force the server to process a request
// without a trailing blank line.
const time_t client_read_timeout_sec = 1;
svr.set_read_timeout(client_read_timeout_sec + 1, 0);
bool listen_thread_ok = false;
thread t = thread([&] { listen_thread_ok = svr.listen(HOST, PORT); });
while (!svr.is_running()) {
msleep(1);
}
// A certain header line causes an exception if the header property is parsed
// naively with a single regex. This occurs with libc++ but not libstdc++.
const std::string req =
"GET /hi HTTP/1.1\r\n"
" : "
" ";
ASSERT_TRUE(send_request(client_read_timeout_sec, req));
svr.stop();
t.join();
EXPECT_TRUE(listen_thread_ok);
}
class ServerTestWithAI_PASSIVE : public ::testing::Test { class ServerTestWithAI_PASSIVE : public ::testing::Test {
protected: protected:
ServerTestWithAI_PASSIVE() ServerTestWithAI_PASSIVE()