diff --git a/mongoose.c b/mongoose.c index 64d4906d..6afbbe45 100644 --- a/mongoose.c +++ b/mongoose.c @@ -6962,6 +6962,8 @@ static void read_conn(struct mg_connection *c) { if (c->is_tls_hs) mg_tls_handshake(c); if (c->is_tls_hs) return; n = mg_tls_recv(c, buf, len); + } else if (n == MG_IO_WAIT) { + n = mg_tls_recv(c, buf, len); } } else { n = recv_raw(c, buf, len); @@ -8762,8 +8764,7 @@ enum mg_tls_hs_state { struct tls_data { enum mg_tls_hs_state state; // keep track of connection handshake progress - struct mg_iobuf send; - struct mg_iobuf recv; + struct mg_iobuf send; // For the receive path, we're reusing c->rtls mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake @@ -9055,56 +9056,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz, memmove(hash, secret, hashsz); } -// receive as much data as we can, but at least one full TLS record -static int mg_tls_recv_msg(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; - uint16_t record_len; - // Pull data from TCP - for (;;) { - long n; - mg_iobuf_resize(rio, rio->len + 1); - n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len); - if (n > 0) { - rio->len += (size_t) n; - } else if (n == MG_IO_WAIT) { - break; - } else { - if (!c->is_closing) { - mg_error(c, "read IO err"); - } - return MG_IO_ERR; - } - } - // Look if we've pulled everything - if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT; - - record_len = MG_LOAD_BE16(rio->buf + 3); - if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT; - return 0; +// Did we receive a full TLS message in the c->rtls buffer? +static bool mg_tls_got_msg(struct mg_connection *c) { + return c->rtls.len >= TLS_HDR_SIZE && + c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3)); } // Remove a single TLS record from the recv buffer static void mg_tls_drop_packet(struct mg_iobuf *rio) { uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE; mg_iobuf_del(rio, 0, n); - // memmove(rio->buf, rio->buf + n, rio->len - n); - // rio->len = rio->len - n; } // read and parse ClientHello record static int mg_tls_client_hello(struct mg_connection *c) { struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + struct mg_iobuf *rio = &c->rtls; uint8_t session_id_len; uint16_t j; uint16_t cipher_suites_len; uint16_t ext_len; uint8_t *ext; - int r = mg_tls_recv_msg(c); - if (r < 0) { - return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; } if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) { mg_error(c, "not a hello packet"); @@ -9280,14 +9255,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg, static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf, size_t bufsz) { struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + struct mg_iobuf *rio = &c->rtls; + // struct mg_iobuf *rio = &tls->recv; uint16_t msgsz; uint8_t *msg; uint8_t nonce[12]; int r; for (;;) { - r = mg_tls_recv_msg(c); - if (r < 0) return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; + } if (rio->buf[0] == 0x17) { break; } else if (rio->buf[0] == 0x15) { @@ -9439,11 +9416,12 @@ static void mg_tls_server_finish(struct mg_connection *c) { } static int mg_tls_client_change_cipher(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + // struct tls_data *tls = c->tls; + struct mg_iobuf *rio = &c->rtls; for (;;) { - int r = mg_tls_recv_msg(c); - if (r < 0) return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; + } if (rio->buf[0] == 0x14) { // got a ChangeCipher record break; } else if (rio->buf[0] == 0x15) { // skip Alert records @@ -9451,7 +9429,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) { mg_tls_drop_packet(rio); } else { mg_error(c, "unexpected packet"); - return -1; + return MG_IO_ERR; } } // consume ChangeCipher packet @@ -9532,9 +9510,7 @@ void mg_tls_handshake(struct mg_connection *c) { mg_tls_generate_application_keys(c); tls->state = MG_TLS_HS_DONE; // fallthrough - case MG_TLS_HS_DONE: - c->is_tls_hs = 0; - return; + case MG_TLS_HS_DONE: c->is_tls_hs = 0; return; } } @@ -9607,7 +9583,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) { return; } - tls->send.align = tls->recv.align = MG_IO_SIZE; + // tls->send.align = tls->recv.align = MG_IO_SIZE; + tls->send.align = MG_IO_SIZE; c->tls = tls; c->is_tls = c->is_tls_hs = 1; mg_sha256_init(&tls->sha256); @@ -9617,7 +9594,6 @@ void mg_tls_free(struct mg_connection *c) { struct tls_data *tls = c->tls; if (tls != NULL) { mg_iobuf_free(&tls->send); - mg_iobuf_free(&tls->recv); free((void *) tls->server_cert_der.ptr); } free(c->tls); @@ -9627,7 +9603,7 @@ void mg_tls_free(struct mg_connection *c) { long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) { struct tls_data *tls = c->tls; long n = MG_IO_WAIT; - if (len > 2048) len = 2048; + if (len > MG_IO_SIZE) len = MG_IO_SIZE; mg_tls_encrypt(c, buf, len, 0x17); while (tls->send.len > 0 && (n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) { @@ -9642,8 +9618,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) { } size_t mg_tls_pending(struct mg_connection *c) { - struct tls_data *tls = (struct tls_data *) c->tls; - return tls == NULL ? 0 : tls->recv.len; + return mg_tls_got_msg(c) ? 1 : 0; } void mg_tls_ctx_init(struct mg_mgr *mgr) { diff --git a/src/sock.c b/src/sock.c index 76755d17..4823caec 100644 --- a/src/sock.c +++ b/src/sock.c @@ -289,6 +289,8 @@ static void read_conn(struct mg_connection *c) { if (c->is_tls_hs) mg_tls_handshake(c); if (c->is_tls_hs) return; n = mg_tls_recv(c, buf, len); + } else if (n == MG_IO_WAIT) { + n = mg_tls_recv(c, buf, len); } } else { n = recv_raw(c, buf, len); diff --git a/src/tls_builtin.c b/src/tls_builtin.c index 9dc29754..08198a0d 100644 --- a/src/tls_builtin.c +++ b/src/tls_builtin.c @@ -15,8 +15,7 @@ enum mg_tls_hs_state { struct tls_data { enum mg_tls_hs_state state; // keep track of connection handshake progress - struct mg_iobuf send; - struct mg_iobuf recv; + struct mg_iobuf send; // For the receive path, we're reusing c->rtls mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake @@ -308,56 +307,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz, memmove(hash, secret, hashsz); } -// receive as much data as we can, but at least one full TLS record -static int mg_tls_recv_msg(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; - uint16_t record_len; - // Pull data from TCP - for (;;) { - long n; - mg_iobuf_resize(rio, rio->len + 1); - n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len); - if (n > 0) { - rio->len += (size_t) n; - } else if (n == MG_IO_WAIT) { - break; - } else { - if (!c->is_closing) { - mg_error(c, "read IO err"); - } - return MG_IO_ERR; - } - } - // Look if we've pulled everything - if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT; - - record_len = MG_LOAD_BE16(rio->buf + 3); - if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT; - return 0; +// Did we receive a full TLS message in the c->rtls buffer? +static bool mg_tls_got_msg(struct mg_connection *c) { + return c->rtls.len >= TLS_HDR_SIZE && + c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3)); } // Remove a single TLS record from the recv buffer static void mg_tls_drop_packet(struct mg_iobuf *rio) { uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE; mg_iobuf_del(rio, 0, n); - // memmove(rio->buf, rio->buf + n, rio->len - n); - // rio->len = rio->len - n; } // read and parse ClientHello record static int mg_tls_client_hello(struct mg_connection *c) { struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + struct mg_iobuf *rio = &c->rtls; uint8_t session_id_len; uint16_t j; uint16_t cipher_suites_len; uint16_t ext_len; uint8_t *ext; - int r = mg_tls_recv_msg(c); - if (r < 0) { - return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; } if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) { mg_error(c, "not a hello packet"); @@ -533,14 +506,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg, static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf, size_t bufsz) { struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + struct mg_iobuf *rio = &c->rtls; + // struct mg_iobuf *rio = &tls->recv; uint16_t msgsz; uint8_t *msg; uint8_t nonce[12]; int r; for (;;) { - r = mg_tls_recv_msg(c); - if (r < 0) return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; + } if (rio->buf[0] == 0x17) { break; } else if (rio->buf[0] == 0x15) { @@ -692,11 +667,12 @@ static void mg_tls_server_finish(struct mg_connection *c) { } static int mg_tls_client_change_cipher(struct mg_connection *c) { - struct tls_data *tls = c->tls; - struct mg_iobuf *rio = &tls->recv; + // struct tls_data *tls = c->tls; + struct mg_iobuf *rio = &c->rtls; for (;;) { - int r = mg_tls_recv_msg(c); - if (r < 0) return r; + if (!mg_tls_got_msg(c)) { + return MG_IO_WAIT; + } if (rio->buf[0] == 0x14) { // got a ChangeCipher record break; } else if (rio->buf[0] == 0x15) { // skip Alert records @@ -704,7 +680,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) { mg_tls_drop_packet(rio); } else { mg_error(c, "unexpected packet"); - return -1; + return MG_IO_ERR; } } // consume ChangeCipher packet @@ -785,9 +761,7 @@ void mg_tls_handshake(struct mg_connection *c) { mg_tls_generate_application_keys(c); tls->state = MG_TLS_HS_DONE; // fallthrough - case MG_TLS_HS_DONE: - c->is_tls_hs = 0; - return; + case MG_TLS_HS_DONE: c->is_tls_hs = 0; return; } } @@ -860,7 +834,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) { return; } - tls->send.align = tls->recv.align = MG_IO_SIZE; + // tls->send.align = tls->recv.align = MG_IO_SIZE; + tls->send.align = MG_IO_SIZE; c->tls = tls; c->is_tls = c->is_tls_hs = 1; mg_sha256_init(&tls->sha256); @@ -870,7 +845,6 @@ void mg_tls_free(struct mg_connection *c) { struct tls_data *tls = c->tls; if (tls != NULL) { mg_iobuf_free(&tls->send); - mg_iobuf_free(&tls->recv); free((void *) tls->server_cert_der.ptr); } free(c->tls); @@ -880,7 +854,7 @@ void mg_tls_free(struct mg_connection *c) { long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) { struct tls_data *tls = c->tls; long n = MG_IO_WAIT; - if (len > 2048) len = 2048; + if (len > MG_IO_SIZE) len = MG_IO_SIZE; mg_tls_encrypt(c, buf, len, 0x17); while (tls->send.len > 0 && (n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) { @@ -895,8 +869,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) { } size_t mg_tls_pending(struct mg_connection *c) { - struct tls_data *tls = (struct tls_data *) c->tls; - return tls == NULL ? 0 : tls->recv.len; + return mg_tls_got_msg(c) ? 1 : 0; } void mg_tls_ctx_init(struct mg_mgr *mgr) {