Fix TLS codepath for pending data. Reuse c->rtls

This commit is contained in:
Sergey Lyubka 2023-12-22 16:16:07 +00:00
parent 97127a57da
commit 06f8238107
3 changed files with 52 additions and 102 deletions

View File

@ -6962,6 +6962,8 @@ static void read_conn(struct mg_connection *c) {
if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len);
} else if (n == MG_IO_WAIT) {
n = mg_tls_recv(c, buf, len);
}
} else {
n = recv_raw(c, buf, len);
@ -8762,8 +8764,7 @@ enum mg_tls_hs_state {
struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send;
struct mg_iobuf recv;
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -9055,56 +9056,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
memmove(hash, secret, hashsz);
}
// receive as much data as we can, but at least one full TLS record
static int mg_tls_recv_msg(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
uint16_t record_len;
// Pull data from TCP
for (;;) {
long n;
mg_iobuf_resize(rio, rio->len + 1);
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
if (!c->is_closing) {
mg_error(c, "read IO err");
}
return MG_IO_ERR;
}
}
// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
record_len = MG_LOAD_BE16(rio->buf + 3);
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
return 0;
// Did we receive a full TLS message in the c->rtls buffer?
static bool mg_tls_got_msg(struct mg_connection *c) {
return c->rtls.len >= TLS_HDR_SIZE &&
c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
}
// Remove a single TLS record from the recv buffer
static void mg_tls_drop_packet(struct mg_iobuf *rio) {
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
mg_iobuf_del(rio, 0, n);
// memmove(rio->buf, rio->buf + n, rio->len - n);
// rio->len = rio->len - n;
}
// read and parse ClientHello record
static int mg_tls_client_hello(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *rio = &c->rtls;
uint8_t session_id_len;
uint16_t j;
uint16_t cipher_suites_len;
uint16_t ext_len;
uint8_t *ext;
int r = mg_tls_recv_msg(c);
if (r < 0) {
return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
mg_error(c, "not a hello packet");
@ -9280,14 +9255,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
size_t bufsz) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *rio = &c->rtls;
// struct mg_iobuf *rio = &tls->recv;
uint16_t msgsz;
uint8_t *msg;
uint8_t nonce[12];
int r;
for (;;) {
r = mg_tls_recv_msg(c);
if (r < 0) return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] == 0x17) {
break;
} else if (rio->buf[0] == 0x15) {
@ -9439,11 +9416,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
}
static int mg_tls_client_change_cipher(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
// struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &c->rtls;
for (;;) {
int r = mg_tls_recv_msg(c);
if (r < 0) return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] == 0x14) { // got a ChangeCipher record
break;
} else if (rio->buf[0] == 0x15) { // skip Alert records
@ -9451,7 +9429,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
mg_tls_drop_packet(rio);
} else {
mg_error(c, "unexpected packet");
return -1;
return MG_IO_ERR;
}
}
// consume ChangeCipher packet
@ -9532,9 +9510,7 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_tls_generate_application_keys(c);
tls->state = MG_TLS_HS_DONE;
// fallthrough
case MG_TLS_HS_DONE:
c->is_tls_hs = 0;
return;
case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
}
}
@ -9607,7 +9583,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
return;
}
tls->send.align = tls->recv.align = MG_IO_SIZE;
// tls->send.align = tls->recv.align = MG_IO_SIZE;
tls->send.align = MG_IO_SIZE;
c->tls = tls;
c->is_tls = c->is_tls_hs = 1;
mg_sha256_init(&tls->sha256);
@ -9617,7 +9594,6 @@ void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls;
if (tls != NULL) {
mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
free((void *) tls->server_cert_der.ptr);
}
free(c->tls);
@ -9627,7 +9603,7 @@ void mg_tls_free(struct mg_connection *c) {
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
struct tls_data *tls = c->tls;
long n = MG_IO_WAIT;
if (len > 2048) len = 2048;
if (len > MG_IO_SIZE) len = MG_IO_SIZE;
mg_tls_encrypt(c, buf, len, 0x17);
while (tls->send.len > 0 &&
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
@ -9642,8 +9618,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
}
size_t mg_tls_pending(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
return tls == NULL ? 0 : tls->recv.len;
return mg_tls_got_msg(c) ? 1 : 0;
}
void mg_tls_ctx_init(struct mg_mgr *mgr) {

View File

@ -289,6 +289,8 @@ static void read_conn(struct mg_connection *c) {
if (c->is_tls_hs) mg_tls_handshake(c);
if (c->is_tls_hs) return;
n = mg_tls_recv(c, buf, len);
} else if (n == MG_IO_WAIT) {
n = mg_tls_recv(c, buf, len);
}
} else {
n = recv_raw(c, buf, len);

View File

@ -15,8 +15,7 @@ enum mg_tls_hs_state {
struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send;
struct mg_iobuf recv;
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -308,56 +307,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
memmove(hash, secret, hashsz);
}
// receive as much data as we can, but at least one full TLS record
static int mg_tls_recv_msg(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
uint16_t record_len;
// Pull data from TCP
for (;;) {
long n;
mg_iobuf_resize(rio, rio->len + 1);
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
if (n > 0) {
rio->len += (size_t) n;
} else if (n == MG_IO_WAIT) {
break;
} else {
if (!c->is_closing) {
mg_error(c, "read IO err");
}
return MG_IO_ERR;
}
}
// Look if we've pulled everything
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
record_len = MG_LOAD_BE16(rio->buf + 3);
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
return 0;
// Did we receive a full TLS message in the c->rtls buffer?
static bool mg_tls_got_msg(struct mg_connection *c) {
return c->rtls.len >= TLS_HDR_SIZE &&
c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
}
// Remove a single TLS record from the recv buffer
static void mg_tls_drop_packet(struct mg_iobuf *rio) {
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
mg_iobuf_del(rio, 0, n);
// memmove(rio->buf, rio->buf + n, rio->len - n);
// rio->len = rio->len - n;
}
// read and parse ClientHello record
static int mg_tls_client_hello(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *rio = &c->rtls;
uint8_t session_id_len;
uint16_t j;
uint16_t cipher_suites_len;
uint16_t ext_len;
uint8_t *ext;
int r = mg_tls_recv_msg(c);
if (r < 0) {
return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
mg_error(c, "not a hello packet");
@ -533,14 +506,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
size_t bufsz) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
struct mg_iobuf *rio = &c->rtls;
// struct mg_iobuf *rio = &tls->recv;
uint16_t msgsz;
uint8_t *msg;
uint8_t nonce[12];
int r;
for (;;) {
r = mg_tls_recv_msg(c);
if (r < 0) return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] == 0x17) {
break;
} else if (rio->buf[0] == 0x15) {
@ -692,11 +667,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
}
static int mg_tls_client_change_cipher(struct mg_connection *c) {
struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &tls->recv;
// struct tls_data *tls = c->tls;
struct mg_iobuf *rio = &c->rtls;
for (;;) {
int r = mg_tls_recv_msg(c);
if (r < 0) return r;
if (!mg_tls_got_msg(c)) {
return MG_IO_WAIT;
}
if (rio->buf[0] == 0x14) { // got a ChangeCipher record
break;
} else if (rio->buf[0] == 0x15) { // skip Alert records
@ -704,7 +680,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
mg_tls_drop_packet(rio);
} else {
mg_error(c, "unexpected packet");
return -1;
return MG_IO_ERR;
}
}
// consume ChangeCipher packet
@ -785,9 +761,7 @@ void mg_tls_handshake(struct mg_connection *c) {
mg_tls_generate_application_keys(c);
tls->state = MG_TLS_HS_DONE;
// fallthrough
case MG_TLS_HS_DONE:
c->is_tls_hs = 0;
return;
case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
}
}
@ -860,7 +834,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
return;
}
tls->send.align = tls->recv.align = MG_IO_SIZE;
// tls->send.align = tls->recv.align = MG_IO_SIZE;
tls->send.align = MG_IO_SIZE;
c->tls = tls;
c->is_tls = c->is_tls_hs = 1;
mg_sha256_init(&tls->sha256);
@ -870,7 +845,6 @@ void mg_tls_free(struct mg_connection *c) {
struct tls_data *tls = c->tls;
if (tls != NULL) {
mg_iobuf_free(&tls->send);
mg_iobuf_free(&tls->recv);
free((void *) tls->server_cert_der.ptr);
}
free(c->tls);
@ -880,7 +854,7 @@ void mg_tls_free(struct mg_connection *c) {
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
struct tls_data *tls = c->tls;
long n = MG_IO_WAIT;
if (len > 2048) len = 2048;
if (len > MG_IO_SIZE) len = MG_IO_SIZE;
mg_tls_encrypt(c, buf, len, 0x17);
while (tls->send.len > 0 &&
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
@ -895,8 +869,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
}
size_t mg_tls_pending(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
return tls == NULL ? 0 : tls->recv.len;
return mg_tls_got_msg(c) ? 1 : 0;
}
void mg_tls_ctx_init(struct mg_mgr *mgr) {