diff --git a/unit_tests/smtp.c b/unit_tests/smtp.c index 456c284..9b43f83 100644 --- a/unit_tests/smtp.c +++ b/unit_tests/smtp.c @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include "../src/piler.h" @@ -99,11 +101,56 @@ } else { send_smtp_command(net, "EHLO aaaa.fu\r\n", recvbuf, sizeof(recvbuf)-1); - assert(strncmp(recvbuf, "250-", 4) == 0 && "HELO"); + assert(strncmp(recvbuf, "250-", 4) == 0 && "EHLO"); + if(net->use_ssl == 0) assert(strstr(recvbuf, "250-STARTTLS") && "STARTTLS"); + else assert(strstr(recvbuf, "250-STARTTLS") == NULL && "STARTTLS"); } } +int init_ssl(struct data *data){ + int n; + char *str; + X509* server_cert; + + SSL_library_init(); + SSL_load_error_strings(); + + #if OPENSSL_VERSION_NUMBER < 0x10100000L + data->net->ctx = SSL_CTX_new(TLSv1_client_method()); + #else + data->net->ctx = SSL_CTX_new(TLS_client_method()); + #endif + CHK_NULL(data->net->ctx, "internal SSL error"); + + data->net->ssl = SSL_new(data->net->ctx); + CHK_NULL(data->net->ssl, "internal ssl error"); + + SSL_set_fd(data->net->ssl, data->net->socket); + n = SSL_connect(data->net->ssl); + CHK_SSL(n, "internal ssl error"); + + printf("Cipher: %s\n", SSL_get_cipher(data->net->ssl)); + + server_cert = SSL_get_peer_certificate(data->net->ssl); + CHK_NULL(server_cert, "server cert error"); + + str = X509_NAME_oneline(X509_get_subject_name(server_cert), 0, 0); + CHK_NULL(str, "error in server cert"); + OPENSSL_free(str); + + str = X509_NAME_oneline(X509_get_issuer_name(server_cert), 0, 0); + CHK_NULL(str, "error in server cert"); + OPENSSL_free(str); + + X509_free(server_cert); + + data->net->use_ssl = 1; + + return OK; +} + + static void test_smtp_commands_one_at_a_time(char *server, int port, struct data *data){ char recvbuf[MAXBUFSIZE], sendbuf[MAXBUFSIZE]; @@ -230,7 +277,35 @@ static void test_smtp_commands_starttls(char *server, int port, struct data *data){ char recvbuf[MAXBUFSIZE], sendbuf[MAXBUFSIZE]; - // TODO: implement starttls logic + connect_to_smtp_server(server, port, data); + + send_helo_command(data->net); + + send_smtp_command(data->net, "STARTTLS\r\n", recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "220 ", 4) == 0 && "STARTTLS"); + + init_ssl(data); + + send_helo_command(data->net); + + send_smtp_command(data->net, "MAIL FROM: \r\n", recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "250 ", 4) == 0 && "MAIL"); + + send_smtp_command(data->net, "RCPT TO: \r\n", recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "250 ", 4) == 0 && "RCPT"); + + send_smtp_command(data->net, "DATA\r\n", recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "354 ", 4) == 0 && "DATA"); + + snprintf(sendbuf, sizeof(sendbuf)-1, "%s\r\n.\r\n", testmessage); + + send_smtp_command(data->net, sendbuf, recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "250 ", 4) == 0 && "PERIOD"); + + send_smtp_command(data->net, "QUIT\r\n", recvbuf, sizeof(recvbuf)-1); + assert(strncmp(recvbuf, "221 ", 4) == 0 && "QUIT"); + + close(data->net->socket); } @@ -305,6 +380,8 @@ test_smtp_commands_with_reset_command(server, port, &data); test_smtp_commands_partial_command(server, port, &data); test_smtp_commands_partial_command_pipelining(server, port, &data); + + helo = 1; // we must use EHLO to get the STARTTLS in the response test_smtp_commands_starttls(server, port, &data);