diff --git a/Makefile b/Makefile index 752336f..7430cc0 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # for debug add -g -O0 to line below -CFLAGS+=-pthread -O2 -Wall -Wextra -Wpedantic -Wstrict-overflow -fno-strict-aliasing -std=gnu11 -g -O0 +CFLAGS+=-pthread -O2 -Wall -Wextra -Wpedantic -Wstrict-overflow -fno-strict-aliasing -std=gnu11 -g -O0 -lcrypto -lssl prefix=/usr/local/bin all: diff --git a/fiche.c b/fiche.c index 99d140d..b389e02 100644 --- a/fiche.c +++ b/fiche.c @@ -48,12 +48,26 @@ Use netcat to push text - example: #include #include +#include +#include +#include +#include +#include +#include +#include /****************************************************************************** * Various declarations */ const char *Fiche_Symbols = "abcdefghijklmnopqrstuvwxyz0123456789"; +EVP_PKEY *g_key = NULL; +STACK_OF(X509) *g_cert_chain = NULL; +SSL_METHOD *g_method = NULL; +X509_STORE *g_store = NULL; +SSL_CTX *g_ctx = NULL; + +bool debug = 0; /****************************************************************************** * Inner structs @@ -62,7 +76,7 @@ const char *Fiche_Symbols = "abcdefghijklmnopqrstuvwxyz0123456789"; struct fiche_connection { int socket; struct sockaddr_in address; - + SSL *ssl; Fiche_Settings *settings; }; @@ -214,7 +228,11 @@ void fiche_init(Fiche_Settings *settings) { // path to banlist NULL, // path to whitelist - NULL + NULL, + // cert + NULL, + // key + NULL }; // Copy default settings to provided instance @@ -313,6 +331,18 @@ static void print_status(const char *format, ...) { va_end(args); } +static void print_debug(const char *format, ...) { + va_list args; + + va_start(args, format); + + printf("[Fiche][DEBUG] "); + vprintf(format, args); + printf("\n"); + + va_end(args); +} + static void print_separator() { printf("============================================================\n"); @@ -425,11 +455,379 @@ static int perform_user_change(const Fiche_Settings *settings) { return 0; } +bool is_ssl(Fiche_Settings *settings) { + if (settings->cert && settings->key) + return true; + else + return false; +} + +EVP_PKEY *read_key(char *key) { + FILE *fp; + unsigned long my_err; + EVP_PKEY *new_key; + + if (!key || !*key) { + return NULL; + } + + fp = fopen(key, "r"); + if (!fp) { + print_error("SSL Private Key fopen() Error: %s", strerror(errno)); + return NULL; + } + new_key = PEM_read_PrivateKey(fp, NULL, NULL, NULL); + fclose(fp); + + if (!new_key) { + while ((my_err = ERR_get_error())) { + print_error("SSL Private Key Loading Error: %s", ERR_error_string(my_err, NULL)); + } + } + + return new_key; +} + +STACK_OF(X509) *read_cert_chain(char *cert) { + FILE *fp = NULL; + unsigned long my_err = 0; + X509 *new_cert = NULL; + STACK_OF(X509) *new_cert_chain = NULL; + + if (!cert || !*cert) { + return NULL; + } + + fp = fopen(cert, "r"); + if (!fp) { + print_error("SSL Certificate fopen() Error: %s", strerror(errno)); + return NULL; + } + + new_cert_chain = sk_X509_new_null(); + if (!new_cert_chain) { + print_error("SSL Certificate sk_X509_new_null() Error"); + fclose(fp); + return NULL; + } + + while((new_cert = PEM_read_X509(fp, NULL, NULL, NULL))) { + sk_X509_push(new_cert_chain, new_cert); + } + + fclose(fp); + + if (!new_cert_chain || sk_X509_num(new_cert_chain) <= 0) { + while ((my_err = ERR_get_error())) { + print_error("SSL Certificate Loading Error: %s", ERR_error_string(my_err, NULL)); + } + if (new_cert_chain) { + sk_X509_free(new_cert_chain); + } + return NULL; + } + + return new_cert_chain; +} + +X509_STORE *make_cert_store(void) { + X509_STORE *store = NULL; + + store = X509_STORE_new(); + + if (!store) { + print_error("SSL Certificate Error: X509_STORE_new() Failed"); + return NULL; + } + + X509_STORE_set_default_paths(store); + + return store; +} + +void info_callback(SSL *s, int where, int ret) { + where = where; + ret = ret; + if (debug) print_debug("SSL info: %s", SSL_state_string_long(s)); + return; +} + +SSL_CTX *make_ctx(STACK_OF(X509) *cert_chain, EVP_PKEY *key) { + + SSL_CTX *new_ctx=NULL; + unsigned long my_err=0; + EC_KEY *ecdh=NULL; + int i=0; + + if (!cert_chain || sk_X509_num(cert_chain) <= 0 || !key) + return NULL; + + if (!g_method) + return NULL; + + new_ctx = SSL_CTX_new(g_method); + if (!new_ctx) { + while ((my_err = ERR_get_error())) + print_error("SSL Context Structure Error: %s", ERR_error_string(my_err, NULL)); + return NULL; + } + + if (!SSL_CTX_use_certificate(new_ctx, sk_X509_value(cert_chain, 0))) { + while ((my_err = ERR_get_error())) + print_error("SSL Certificate Error: %s", ERR_error_string(my_err, NULL)); + SSL_CTX_free(new_ctx); + new_ctx = NULL; + return NULL; + } + + for (i=1; i < sk_X509_num(cert_chain); i++) { + if (!SSL_CTX_add_extra_chain_cert(new_ctx, sk_X509_value(cert_chain, i))) { + while ((my_err = ERR_get_error())) + print_error("SSL Certificate Error: %s", ERR_error_string(my_err, NULL)); + SSL_CTX_free(new_ctx); + new_ctx = NULL; + return NULL; + } + } + + if (!SSL_CTX_use_PrivateKey(new_ctx, key)) { + while ((my_err = ERR_get_error())) + print_error("SSL Private Key Error: %s", ERR_error_string(my_err, NULL)); + SSL_CTX_free(new_ctx); + new_ctx = NULL; + return NULL; + } + + if (!SSL_CTX_check_private_key(new_ctx)) { + print_error("SSL Error: Private key does not match the certificate public key"); + SSL_CTX_free(new_ctx); + new_ctx = NULL; + return NULL; + } + + SSL_CTX_set_options(new_ctx, SSL_OP_ALL); + + SSL_CTX_set_info_callback(new_ctx, (void (*)())info_callback); + + SSL_CTX_set_mode(new_ctx, SSL_MODE_AUTO_RETRY); + + ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + if (ecdh) { + SSL_CTX_set_tmp_ecdh(new_ctx, ecdh); + EC_KEY_free(ecdh); + } else { + print_error("SSL Error: Elliptic curve Diffie-Hellman failure"); + } + + return new_ctx; + +} + +int fiche_ssl_init(Fiche_Settings *settings) { + + SSL_library_init(); + SSL_load_error_strings(); + OpenSSL_add_ssl_algorithms(); + +#ifdef HAVE_SHA256 + EVP_add_digest(EVP_sha256()); +#endif + + + if (!g_method) { + // this is a static structure. it will never be NULL. nothing to free. + g_method = (SSL_METHOD *)SSLv23_method(); + } + + g_key = read_key(settings->key); + if (!g_key) { + g_method = NULL; + return -1; + } + + g_cert_chain = read_cert_chain(settings->cert); + if (!g_cert_chain) { + EVP_PKEY_free(g_key); + g_key = NULL; + g_method = NULL; + return -1; + } + + if (sk_X509_num(g_cert_chain) <= 0) { + sk_X509_pop_free(g_cert_chain, X509_free); + g_cert_chain = NULL; + EVP_PKEY_free(g_key); + g_key = NULL; + g_method = NULL; + return -1; + } + + g_store = make_cert_store(); + if (!g_store) { + sk_X509_pop_free(g_cert_chain, X509_free); + g_cert_chain = NULL; + EVP_PKEY_free(g_key); + g_key = NULL; + g_method = NULL; + return -1; + } + + g_ctx = make_ctx(g_cert_chain, g_key); + if (!g_ctx) { + X509_STORE_free(g_store); + g_store = NULL; + sk_X509_pop_free(g_cert_chain, X509_free); + g_cert_chain = NULL; + EVP_PKEY_free(g_key); + g_key = NULL; + g_method = NULL; + return -1; + } + + return 0; +} + +int accept_ssl(SSL *ssl) { + int need_more; + int ssl_err; + unsigned long my_err; + + do { + + need_more = 0; + ssl_err = SSL_accept(ssl); + switch(SSL_get_error(ssl, ssl_err)){ + case SSL_ERROR_NONE: + break; + case SSL_ERROR_SSL: + while ((my_err = ERR_get_error())) + print_error("SSL_accept: SSL_ERROR_SSL: %s", ERR_error_string(my_err, NULL)); + return -1; + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + while ((my_err = ERR_get_error())) + print_error("SSL_accept: SSL_ERROR_SYSCALL: %s", ERR_error_string(my_err, NULL)); + } else { + if (ssl_err) + print_error("SSL_accept: SSL_ERROR_SYSCALL errno: %s", strerror(errno)); + else + print_error("SSL_accept: SSL_ERROR_SYSCALL EOF"); + } + return -1; + break; + case SSL_ERROR_ZERO_RETURN: + print_error("SSL_accept: SSL_ERROR_ZERO_RETURN"); + return -1; + break; + case SSL_ERROR_WANT_READ: + need_more = 1; + if (debug) print_debug("SSL_accept: SSL_ERROR_WANT_READ"); + break; + case SSL_ERROR_WANT_WRITE: + need_more = 1; + if (debug) print_debug("SSL_accept: SSL_ERROR_WANT_WRITE"); + break; + case SSL_ERROR_WANT_ACCEPT: + need_more = 1; + if (debug) print_debug("SSL_accept: SSL_ERROR_WANT_ACCEPT"); + break; + default: + print_error("SSL_accept: SSL accept problem"); + return -1; + break; + } + + } while(need_more); + + return 0; +} + +int read_ssl(SSL *ssl, void *buf, int count) { + + int need_more; + unsigned long my_err; + int inbound_offset = 0; + int ret = 0; + + if (!ssl || !buf || count < 1) + return -1; + + do { + need_more = 0; + ret = SSL_read(ssl, buf, count); + if (ret >= 0) + inbound_offset += ret; + switch(SSL_get_error(ssl, ret)) { + case SSL_ERROR_NONE: + break; + case SSL_ERROR_SSL: + while ((my_err = ERR_get_error())) + print_error("SSL_read: SSL_ERROR_SSL: %s", ERR_error_string(my_err, NULL)); + break; + case SSL_ERROR_SYSCALL: + if (ERR_peek_error()) { + while ((my_err = ERR_get_error())) + print_error("SSL_read: SSL_ERROR_SYSCALL: %s", ERR_error_string(my_err, NULL)); + } else { + if (ret) + print_error("SSL_read: SSL_ERROR_SYSCALL errno: %s", strerror(errno)); + } + break; + case SSL_ERROR_ZERO_RETURN: + print_error("SSL_read: SSL_ERROR_ZERO_RETURN"); + break; + case SSL_ERROR_WANT_READ: + need_more = 1; + if (debug) print_debug("SSL_read: SSL_ERROR_WANT_READ"); + break; + case SSL_ERROR_WANT_WRITE: + need_more = 1; + if (debug) print_debug("SSL_read: SSL_ERROR_WANT_WRITE"); + break; + default: + print_error("SSL_read: SSL read problem"); + break; + } + } while (need_more); + + return inbound_offset; +} + +int read_ssl_waitall(SSL *ssl, void *buf, int count) { + + ssize_t inbound_offset=0; + ssize_t ret=0; + + if (!ssl || !buf || count < 1) + return -1; + + while (inbound_offset < count) { + ret = read_ssl(ssl, (char *)buf + inbound_offset, count - inbound_offset); + + if (ret <= 0) + break; + + inbound_offset += ret; + + } + + return inbound_offset; +} static int start_server(Fiche_Settings *settings) { + int s; + + if (is_ssl(settings)) { + if (fiche_ssl_init(settings) < 0) { + print_error("Couldn't initialize SSL!"); + return -1; + } + } + // Perform socket creation - int s = socket(AF_INET, SOCK_STREAM, 0); + s = socket(AF_INET, SOCK_STREAM, 0); if (s < 0) { print_error("Couldn't create a socket!"); return -1; @@ -492,7 +890,7 @@ static void dispatch_connection(int socket, Fiche_Settings *settings) { // Accept a connection and get a new socket id const int s = accept(socket, (struct sockaddr *) &address, &addlen); - if (s < 0 ) { + if (s < 0) { print_error("Error on accepting connection!"); return; } @@ -516,6 +914,7 @@ static void dispatch_connection(int socket, Fiche_Settings *settings) { } c->socket = s; c->address = address; + c->ssl = NULL; c->settings = settings; // Spawn a new thread to handle this connection @@ -560,15 +959,55 @@ static void *handle_connection(void *args) { print_status("Incoming connection from: %s (%s).", ip, hostname); } + if(is_ssl(c->settings)) { + c->ssl = SSL_new(g_ctx); + if (!c->ssl) { + close(c->socket); + free(c); + pthread_exit(NULL); + return 0; + } + + SSL_set_accept_state(c->ssl); + SSL_set_fd(c->ssl, c->socket); + SSL_set_mode(c->ssl, SSL_MODE_AUTO_RETRY); + + if (accept_ssl(c->ssl) < 0) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + close(c->socket); + free(c); + pthread_exit(NULL); + return 0; + } + } + // Create a buffer uint8_t buffer[c->settings->buffer_len]; memset(buffer, 0, c->settings->buffer_len); - const int r = recv(c->socket, buffer, sizeof(buffer), MSG_WAITALL); + int r; + + if (is_ssl(c->settings)) + r = read_ssl_waitall(c->ssl, buffer, sizeof(buffer)); + else + r = recv(c->socket, buffer, sizeof(buffer), MSG_WAITALL); + if (r <= 0) { print_error("No data received from the client!"); print_separator(); + if (is_ssl(c->settings)) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + } + // Close the socket close(c->socket); @@ -613,10 +1052,18 @@ static void *handle_connection(void *args) { print_error("Couldn't generate a valid slug!"); print_separator(); + if (is_ssl(c->settings)) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + } + // Cleanup - free(c); free(slug); close(c->socket); + free(c); pthread_exit(NULL); return NULL; } @@ -630,6 +1077,14 @@ static void *handle_connection(void *args) { print_error("Couldn't generate a slug!"); print_separator(); + if (is_ssl(c->settings)) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + } + close(c->socket); // Cleanup @@ -644,6 +1099,14 @@ static void *handle_connection(void *args) { print_error("Couldn't save a file!"); print_separator(); + if (is_ssl(c->settings)) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + } + close(c->socket); // Cleanup @@ -662,7 +1125,10 @@ static void *handle_connection(void *args) { snprintf(url, len, "%s%s%s%s", c->settings->domain, "/", slug, "\n"); // Send the response - write(c->socket, url, len); + if(is_ssl(c->settings)) + SSL_write(c->ssl, url, len); + else + write(c->socket, url, len); } print_status("Received %d bytes, saved to: %s.", r, slug); @@ -672,6 +1138,14 @@ static void *handle_connection(void *args) { // TODO: log unsuccessful and rejected connections log_entry(c->settings, ip, hostname, slug); + if (is_ssl(c->settings)) { + if (c->ssl) { + SSL_shutdown(c->ssl); + SSL_free(c->ssl); + } + c->ssl = NULL; + } + // Close the connection close(c->socket); diff --git a/fiche.h b/fiche.h index b1c97a2..55b7eee 100644 --- a/fiche.h +++ b/fiche.h @@ -90,7 +90,15 @@ typedef struct Fiche_Settings { */ char *whitelist_path; + /** + * @brief Cert used in SSL + */ + char *cert; + /** + * @brief Key used in SSL + */ + char *key; } Fiche_Settings; @@ -115,5 +123,9 @@ int fiche_run(Fiche_Settings settings); */ extern const char *Fiche_Symbols; +/** + * @brief debug on/off + */ +extern bool debug; #endif diff --git a/main.c b/main.c index da503c0..5ee62b2 100644 --- a/main.c +++ b/main.c @@ -44,7 +44,7 @@ int main(int argc, char **argv) { // Parse input arguments int c; - while ((c = getopt(argc, argv, "D6eSL:p:b:s:d:o:l:B:u:w:")) != -1) { + while ((c = getopt(argc, argv, "D6eSL:p:b:s:d:o:l:B:u:w:c:k:")) != -1) { switch (c) { // domain @@ -124,13 +124,35 @@ int main(int argc, char **argv) { } break; + // cert + case 'c': + { + fs.cert = optarg; + } + break; + + // key + case 'k': + { + fs.key = optarg; + } + break; + + // debug + case 'D': + { + debug = true; + } + break; + // Display help in case of any unsupported argument default: { - printf("usage: fiche [-dLpsSoBulbw].\n"); + printf("usage: fiche [-dLpsoBulbwSckD].\n"); printf(" [-d domain] [-L listen_addr] [-p port] [-s slug size]\n"); printf(" [-o output directory] [-B buffer size] [-u user name]\n"); printf(" [-l log file] [-b banlist] [-w whitelist] [-S]\n"); + printf(" [-c cert file] [-k key file] [-D]\n"); return 0; } break;