Fix TLS codepath for pending data. Reuse c->rtls

This commit is contained in:
Sergey Lyubka 2023-12-22 16:16:07 +00:00
parent 97127a57da
commit 06f8238107
3 changed files with 52 additions and 102 deletions

View File

@ -6962,6 +6962,8 @@ static void read_conn(struct mg_connection *c) {
if (c->is_tls_hs) mg_tls_handshake(c); if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return; if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len); n = mg_tls_recv(c, buf, len);
} else if (n == MG_IO_WAIT) {
n = mg_tls_recv(c, buf, len);
} }
} else { } else {
n = recv_raw(c, buf, len); n = recv_raw(c, buf, len);
@ -8762,8 +8764,7 @@ enum mg_tls_hs_state {
struct tls_data { struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send; struct mg_iobuf send; // For the receive path, we're reusing c->rtls
struct mg_iobuf recv;
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -9055,56 +9056,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
memmove(hash, secret, hashsz); memmove(hash, secret, hashsz);
} }
// receive as much data as we can, but at least one full TLS record // Did we receive a full TLS message in the c->rtls buffer?
static int mg_tls_recv_msg(struct mg_connection *c) { static bool mg_tls_got_msg(struct mg_connection *c) {
struct tls_data *tls = c->tls; return c->rtls.len >= TLS_HDR_SIZE &&
struct mg_iobuf *rio = &tls->recv; c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
uint16_t record_len;
// Pull data from TCP
for (;;) {
long n;
mg_iobuf_resize(rio, rio->len + 1);
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
if (!c->is_closing) {
mg_error(c, "read IO err");
}
return MG_IO_ERR;
}
}
// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
record_len = MG_LOAD_BE16(rio->buf + 3);
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
return 0;
} }
// Remove a single TLS record from the recv buffer // Remove a single TLS record from the recv buffer
static void mg_tls_drop_packet(struct mg_iobuf *rio) { static void mg_tls_drop_packet(struct mg_iobuf *rio) {
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE; uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
mg_iobuf_del(rio, 0, n); mg_iobuf_del(rio, 0, n);
// memmove(rio->buf, rio->buf + n, rio->len - n);
// rio->len = rio->len - n;
} }
// read and parse ClientHello record // read and parse ClientHello record
static int mg_tls_client_hello(struct mg_connection *c) { static int mg_tls_client_hello(struct mg_connection *c) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
uint8_t session_id_len; uint8_t session_id_len;
uint16_t j; uint16_t j;
uint16_t cipher_suites_len; uint16_t cipher_suites_len;
uint16_t ext_len; uint16_t ext_len;
uint8_t *ext; uint8_t *ext;
int r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) { return MG_IO_WAIT;
return r;
} }
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) { if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
mg_error(c, "not a hello packet"); mg_error(c, "not a hello packet");
@ -9280,14 +9255,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf, static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
size_t bufsz) { size_t bufsz) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
// struct mg_iobuf *rio = &tls->recv;
uint16_t msgsz; uint16_t msgsz;
uint8_t *msg; uint8_t *msg;
uint8_t nonce[12]; uint8_t nonce[12];
int r; int r;
for (;;) { for (;;) {
r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) return r; return MG_IO_WAIT;
}
if (rio->buf[0] == 0x17) { if (rio->buf[0] == 0x17) {
break; break;
} else if (rio->buf[0] == 0x15) { } else if (rio->buf[0] == 0x15) {
@ -9439,11 +9416,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
} }
static int mg_tls_client_change_cipher(struct mg_connection *c) { static int mg_tls_client_change_cipher(struct mg_connection *c) {
struct tls_data *tls = c->tls; // struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
for (;;) { for (;;) {
int r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) return r; return MG_IO_WAIT;
}
if (rio->buf[0] == 0x14) { // got a ChangeCipher record if (rio->buf[0] == 0x14) { // got a ChangeCipher record
break; break;
} else if (rio->buf[0] == 0x15) { // skip Alert records } else if (rio->buf[0] == 0x15) { // skip Alert records
@ -9451,7 +9429,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
mg_tls_drop_packet(rio); mg_tls_drop_packet(rio);
} else { } else {
mg_error(c, "unexpected packet"); mg_error(c, "unexpected packet");
return -1; return MG_IO_ERR;
} }
} }
// consume ChangeCipher packet // consume ChangeCipher packet
@ -9532,9 +9510,7 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_tls_generate_application_keys(c); mg_tls_generate_application_keys(c);
tls->state = MG_TLS_HS_DONE; tls->state = MG_TLS_HS_DONE;
// fallthrough // fallthrough
case MG_TLS_HS_DONE: case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
c->is_tls_hs = 0;
return;
} }
} }
@ -9607,7 +9583,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
return; return;
} }
tls->send.align = tls->recv.align = MG_IO_SIZE; // tls->send.align = tls->recv.align = MG_IO_SIZE;
tls->send.align = MG_IO_SIZE;
c->tls = tls; c->tls = tls;
c->is_tls = c->is_tls_hs = 1; c->is_tls = c->is_tls_hs = 1;
mg_sha256_init(&tls->sha256); mg_sha256_init(&tls->sha256);
@ -9617,7 +9594,6 @@ void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
if (tls != NULL) { if (tls != NULL) {
mg_iobuf_free(&tls->send); mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
free((void *) tls->server_cert_der.ptr); free((void *) tls->server_cert_der.ptr);
} }
free(c->tls); free(c->tls);
@ -9627,7 +9603,7 @@ void mg_tls_free(struct mg_connection *c) {
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) { long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
long n = MG_IO_WAIT; long n = MG_IO_WAIT;
if (len > 2048) len = 2048; if (len > MG_IO_SIZE) len = MG_IO_SIZE;
mg_tls_encrypt(c, buf, len, 0x17); mg_tls_encrypt(c, buf, len, 0x17);
while (tls->send.len > 0 && while (tls->send.len > 0 &&
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) { (n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
@ -9642,8 +9618,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
} }
size_t mg_tls_pending(struct mg_connection *c) { size_t mg_tls_pending(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls; return mg_tls_got_msg(c) ? 1 : 0;
return tls == NULL ? 0 : tls->recv.len;
} }
void mg_tls_ctx_init(struct mg_mgr *mgr) { void mg_tls_ctx_init(struct mg_mgr *mgr) {

View File

@ -289,6 +289,8 @@ static void read_conn(struct mg_connection *c) {
if (c->is_tls_hs) mg_tls_handshake(c); if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return; if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len); n = mg_tls_recv(c, buf, len);
} else if (n == MG_IO_WAIT) {
n = mg_tls_recv(c, buf, len);
} }
} else { } else {
n = recv_raw(c, buf, len); n = recv_raw(c, buf, len);

View File

@ -15,8 +15,7 @@ enum mg_tls_hs_state {
struct tls_data { struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send; struct mg_iobuf send; // For the receive path, we're reusing c->rtls
struct mg_iobuf recv;
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -308,56 +307,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
memmove(hash, secret, hashsz); memmove(hash, secret, hashsz);
} }
// receive as much data as we can, but at least one full TLS record // Did we receive a full TLS message in the c->rtls buffer?
static int mg_tls_recv_msg(struct mg_connection *c) { static bool mg_tls_got_msg(struct mg_connection *c) {
struct tls_data *tls = c->tls; return c->rtls.len >= TLS_HDR_SIZE &&
struct mg_iobuf *rio = &tls->recv; c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
uint16_t record_len;
// Pull data from TCP
for (;;) {
long n;
mg_iobuf_resize(rio, rio->len + 1);
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
if (!c->is_closing) {
mg_error(c, "read IO err");
}
return MG_IO_ERR;
}
}
// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
record_len = MG_LOAD_BE16(rio->buf + 3);
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
return 0;
} }
// Remove a single TLS record from the recv buffer // Remove a single TLS record from the recv buffer
static void mg_tls_drop_packet(struct mg_iobuf *rio) { static void mg_tls_drop_packet(struct mg_iobuf *rio) {
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE; uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
mg_iobuf_del(rio, 0, n); mg_iobuf_del(rio, 0, n);
// memmove(rio->buf, rio->buf + n, rio->len - n);
// rio->len = rio->len - n;
} }
// read and parse ClientHello record // read and parse ClientHello record
static int mg_tls_client_hello(struct mg_connection *c) { static int mg_tls_client_hello(struct mg_connection *c) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
uint8_t session_id_len; uint8_t session_id_len;
uint16_t j; uint16_t j;
uint16_t cipher_suites_len; uint16_t cipher_suites_len;
uint16_t ext_len; uint16_t ext_len;
uint8_t *ext; uint8_t *ext;
int r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) { return MG_IO_WAIT;
return r;
} }
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) { if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
mg_error(c, "not a hello packet"); mg_error(c, "not a hello packet");
@ -533,14 +506,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf, static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
size_t bufsz) { size_t bufsz) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
// struct mg_iobuf *rio = &tls->recv;
uint16_t msgsz; uint16_t msgsz;
uint8_t *msg; uint8_t *msg;
uint8_t nonce[12]; uint8_t nonce[12];
int r; int r;
for (;;) { for (;;) {
r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) return r; return MG_IO_WAIT;
}
if (rio->buf[0] == 0x17) { if (rio->buf[0] == 0x17) {
break; break;
} else if (rio->buf[0] == 0x15) { } else if (rio->buf[0] == 0x15) {
@ -692,11 +667,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
} }
static int mg_tls_client_change_cipher(struct mg_connection *c) { static int mg_tls_client_change_cipher(struct mg_connection *c) {
struct tls_data *tls = c->tls; // struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv; struct mg_iobuf *rio = &c->rtls;
for (;;) { for (;;) {
int r = mg_tls_recv_msg(c); if (!mg_tls_got_msg(c)) {
if (r < 0) return r; return MG_IO_WAIT;
}
if (rio->buf[0] == 0x14) { // got a ChangeCipher record if (rio->buf[0] == 0x14) { // got a ChangeCipher record
break; break;
} else if (rio->buf[0] == 0x15) { // skip Alert records } else if (rio->buf[0] == 0x15) { // skip Alert records
@ -704,7 +680,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
mg_tls_drop_packet(rio); mg_tls_drop_packet(rio);
} else { } else {
mg_error(c, "unexpected packet"); mg_error(c, "unexpected packet");
return -1; return MG_IO_ERR;
} }
} }
// consume ChangeCipher packet // consume ChangeCipher packet
@ -785,9 +761,7 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_tls_generate_application_keys(c); mg_tls_generate_application_keys(c);
tls->state = MG_TLS_HS_DONE; tls->state = MG_TLS_HS_DONE;
// fallthrough // fallthrough
case MG_TLS_HS_DONE: case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
c->is_tls_hs = 0;
return;
} }
} }
@ -860,7 +834,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
return; return;
} }
tls->send.align = tls->recv.align = MG_IO_SIZE; // tls->send.align = tls->recv.align = MG_IO_SIZE;
tls->send.align = MG_IO_SIZE;
c->tls = tls; c->tls = tls;
c->is_tls = c->is_tls_hs = 1; c->is_tls = c->is_tls_hs = 1;
mg_sha256_init(&tls->sha256); mg_sha256_init(&tls->sha256);
@ -870,7 +845,6 @@ void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
if (tls != NULL) { if (tls != NULL) {
mg_iobuf_free(&tls->send); mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
free((void *) tls->server_cert_der.ptr); free((void *) tls->server_cert_der.ptr);
} }
free(c->tls); free(c->tls);
@ -880,7 +854,7 @@ void mg_tls_free(struct mg_connection *c) {
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) { long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
struct tls_data *tls = c->tls; struct tls_data *tls = c->tls;
long n = MG_IO_WAIT; long n = MG_IO_WAIT;
if (len > 2048) len = 2048; if (len > MG_IO_SIZE) len = MG_IO_SIZE;
mg_tls_encrypt(c, buf, len, 0x17); mg_tls_encrypt(c, buf, len, 0x17);
while (tls->send.len > 0 && while (tls->send.len > 0 &&
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) { (n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
@ -895,8 +869,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
} }
size_t mg_tls_pending(struct mg_connection *c) { size_t mg_tls_pending(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls; return mg_tls_got_msg(c) ? 1 : 0;
return tls == NULL ? 0 : tls->recv.len;
} }
void mg_tls_ctx_init(struct mg_mgr *mgr) { void mg_tls_ctx_init(struct mg_mgr *mgr) {