mirror of
https://github.com/cesanta/mongoose.git
synced 2024-12-28 23:49:44 +08:00
Fix TLS codepath for pending data. Reuse c->rtls
This commit is contained in:
parent
97127a57da
commit
06f8238107
77
mongoose.c
77
mongoose.c
@ -6962,6 +6962,8 @@ static void read_conn(struct mg_connection *c) {
|
||||
if (c->is_tls_hs) mg_tls_handshake(c);
|
||||
if (c->is_tls_hs) return;
|
||||
n = mg_tls_recv(c, buf, len);
|
||||
} else if (n == MG_IO_WAIT) {
|
||||
n = mg_tls_recv(c, buf, len);
|
||||
}
|
||||
} else {
|
||||
n = recv_raw(c, buf, len);
|
||||
@ -8762,8 +8764,7 @@ enum mg_tls_hs_state {
|
||||
struct tls_data {
|
||||
enum mg_tls_hs_state state; // keep track of connection handshake progress
|
||||
|
||||
struct mg_iobuf send;
|
||||
struct mg_iobuf recv;
|
||||
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
|
||||
|
||||
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
|
||||
|
||||
@ -9055,56 +9056,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
|
||||
memmove(hash, secret, hashsz);
|
||||
}
|
||||
|
||||
// receive as much data as we can, but at least one full TLS record
|
||||
static int mg_tls_recv_msg(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
uint16_t record_len;
|
||||
// Pull data from TCP
|
||||
for (;;) {
|
||||
long n;
|
||||
mg_iobuf_resize(rio, rio->len + 1);
|
||||
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
|
||||
if (n > 0) {
|
||||
rio->len += (size_t) n;
|
||||
} else if (n == MG_IO_WAIT) {
|
||||
break;
|
||||
} else {
|
||||
if (!c->is_closing) {
|
||||
mg_error(c, "read IO err");
|
||||
}
|
||||
return MG_IO_ERR;
|
||||
}
|
||||
}
|
||||
// Look if we've pulled everything
|
||||
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
|
||||
|
||||
record_len = MG_LOAD_BE16(rio->buf + 3);
|
||||
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
|
||||
return 0;
|
||||
// Did we receive a full TLS message in the c->rtls buffer?
|
||||
static bool mg_tls_got_msg(struct mg_connection *c) {
|
||||
return c->rtls.len >= TLS_HDR_SIZE &&
|
||||
c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
|
||||
}
|
||||
|
||||
// Remove a single TLS record from the recv buffer
|
||||
static void mg_tls_drop_packet(struct mg_iobuf *rio) {
|
||||
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
|
||||
mg_iobuf_del(rio, 0, n);
|
||||
// memmove(rio->buf, rio->buf + n, rio->len - n);
|
||||
// rio->len = rio->len - n;
|
||||
}
|
||||
|
||||
// read and parse ClientHello record
|
||||
static int mg_tls_client_hello(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
uint8_t session_id_len;
|
||||
uint16_t j;
|
||||
uint16_t cipher_suites_len;
|
||||
uint16_t ext_len;
|
||||
uint8_t *ext;
|
||||
|
||||
int r = mg_tls_recv_msg(c);
|
||||
if (r < 0) {
|
||||
return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
|
||||
mg_error(c, "not a hello packet");
|
||||
@ -9280,14 +9255,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
|
||||
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
|
||||
size_t bufsz) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
// struct mg_iobuf *rio = &tls->recv;
|
||||
uint16_t msgsz;
|
||||
uint8_t *msg;
|
||||
uint8_t nonce[12];
|
||||
int r;
|
||||
for (;;) {
|
||||
r = mg_tls_recv_msg(c);
|
||||
if (r < 0) return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] == 0x17) {
|
||||
break;
|
||||
} else if (rio->buf[0] == 0x15) {
|
||||
@ -9439,11 +9416,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
|
||||
}
|
||||
|
||||
static int mg_tls_client_change_cipher(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
// struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
for (;;) {
|
||||
int r = mg_tls_recv_msg(c);
|
||||
if (r < 0) return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] == 0x14) { // got a ChangeCipher record
|
||||
break;
|
||||
} else if (rio->buf[0] == 0x15) { // skip Alert records
|
||||
@ -9451,7 +9429,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
|
||||
mg_tls_drop_packet(rio);
|
||||
} else {
|
||||
mg_error(c, "unexpected packet");
|
||||
return -1;
|
||||
return MG_IO_ERR;
|
||||
}
|
||||
}
|
||||
// consume ChangeCipher packet
|
||||
@ -9532,9 +9510,7 @@ void mg_tls_handshake(struct mg_connection *c) {
|
||||
mg_tls_generate_application_keys(c);
|
||||
tls->state = MG_TLS_HS_DONE;
|
||||
// fallthrough
|
||||
case MG_TLS_HS_DONE:
|
||||
c->is_tls_hs = 0;
|
||||
return;
|
||||
case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
|
||||
}
|
||||
}
|
||||
|
||||
@ -9607,7 +9583,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
|
||||
return;
|
||||
}
|
||||
|
||||
tls->send.align = tls->recv.align = MG_IO_SIZE;
|
||||
// tls->send.align = tls->recv.align = MG_IO_SIZE;
|
||||
tls->send.align = MG_IO_SIZE;
|
||||
c->tls = tls;
|
||||
c->is_tls = c->is_tls_hs = 1;
|
||||
mg_sha256_init(&tls->sha256);
|
||||
@ -9617,7 +9594,6 @@ void mg_tls_free(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
if (tls != NULL) {
|
||||
mg_iobuf_free(&tls->send);
|
||||
mg_iobuf_free(&tls->recv);
|
||||
free((void *) tls->server_cert_der.ptr);
|
||||
}
|
||||
free(c->tls);
|
||||
@ -9627,7 +9603,7 @@ void mg_tls_free(struct mg_connection *c) {
|
||||
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
|
||||
struct tls_data *tls = c->tls;
|
||||
long n = MG_IO_WAIT;
|
||||
if (len > 2048) len = 2048;
|
||||
if (len > MG_IO_SIZE) len = MG_IO_SIZE;
|
||||
mg_tls_encrypt(c, buf, len, 0x17);
|
||||
while (tls->send.len > 0 &&
|
||||
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
|
||||
@ -9642,8 +9618,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
|
||||
}
|
||||
|
||||
size_t mg_tls_pending(struct mg_connection *c) {
|
||||
struct tls_data *tls = (struct tls_data *) c->tls;
|
||||
return tls == NULL ? 0 : tls->recv.len;
|
||||
return mg_tls_got_msg(c) ? 1 : 0;
|
||||
}
|
||||
|
||||
void mg_tls_ctx_init(struct mg_mgr *mgr) {
|
||||
|
@ -289,6 +289,8 @@ static void read_conn(struct mg_connection *c) {
|
||||
if (c->is_tls_hs) mg_tls_handshake(c);
|
||||
if (c->is_tls_hs) return;
|
||||
n = mg_tls_recv(c, buf, len);
|
||||
} else if (n == MG_IO_WAIT) {
|
||||
n = mg_tls_recv(c, buf, len);
|
||||
}
|
||||
} else {
|
||||
n = recv_raw(c, buf, len);
|
||||
|
@ -15,8 +15,7 @@ enum mg_tls_hs_state {
|
||||
struct tls_data {
|
||||
enum mg_tls_hs_state state; // keep track of connection handshake progress
|
||||
|
||||
struct mg_iobuf send;
|
||||
struct mg_iobuf recv;
|
||||
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
|
||||
|
||||
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
|
||||
|
||||
@ -308,56 +307,30 @@ static void mg_tls_derive_secret(const char *label, uint8_t *key, size_t keysz,
|
||||
memmove(hash, secret, hashsz);
|
||||
}
|
||||
|
||||
// receive as much data as we can, but at least one full TLS record
|
||||
static int mg_tls_recv_msg(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
uint16_t record_len;
|
||||
// Pull data from TCP
|
||||
for (;;) {
|
||||
long n;
|
||||
mg_iobuf_resize(rio, rio->len + 1);
|
||||
n = mg_io_recv(c, &rio->buf[rio->len], rio->size - rio->len);
|
||||
if (n > 0) {
|
||||
rio->len += (size_t) n;
|
||||
} else if (n == MG_IO_WAIT) {
|
||||
break;
|
||||
} else {
|
||||
if (!c->is_closing) {
|
||||
mg_error(c, "read IO err");
|
||||
}
|
||||
return MG_IO_ERR;
|
||||
}
|
||||
}
|
||||
// Look if we've pulled everything
|
||||
if (rio->len < TLS_HDR_SIZE) return MG_IO_WAIT;
|
||||
|
||||
record_len = MG_LOAD_BE16(rio->buf + 3);
|
||||
if (rio->len < (size_t) TLS_HDR_SIZE + record_len) return MG_IO_WAIT;
|
||||
return 0;
|
||||
// Did we receive a full TLS message in the c->rtls buffer?
|
||||
static bool mg_tls_got_msg(struct mg_connection *c) {
|
||||
return c->rtls.len >= TLS_HDR_SIZE &&
|
||||
c->rtls.len >= (TLS_HDR_SIZE + MG_LOAD_BE16(c->rtls.buf + 3));
|
||||
}
|
||||
|
||||
// Remove a single TLS record from the recv buffer
|
||||
static void mg_tls_drop_packet(struct mg_iobuf *rio) {
|
||||
uint16_t n = MG_LOAD_BE16(rio->buf + 3) + TLS_HDR_SIZE;
|
||||
mg_iobuf_del(rio, 0, n);
|
||||
// memmove(rio->buf, rio->buf + n, rio->len - n);
|
||||
// rio->len = rio->len - n;
|
||||
}
|
||||
|
||||
// read and parse ClientHello record
|
||||
static int mg_tls_client_hello(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
uint8_t session_id_len;
|
||||
uint16_t j;
|
||||
uint16_t cipher_suites_len;
|
||||
uint16_t ext_len;
|
||||
uint8_t *ext;
|
||||
|
||||
int r = mg_tls_recv_msg(c);
|
||||
if (r < 0) {
|
||||
return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] != 0x16 || rio->buf[5] != 0x01) {
|
||||
mg_error(c, "not a hello packet");
|
||||
@ -533,14 +506,16 @@ static void mg_tls_encrypt(struct mg_connection *c, const uint8_t *msg,
|
||||
static int mg_tls_recv_decrypt(struct mg_connection *c, void *buf,
|
||||
size_t bufsz) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
// struct mg_iobuf *rio = &tls->recv;
|
||||
uint16_t msgsz;
|
||||
uint8_t *msg;
|
||||
uint8_t nonce[12];
|
||||
int r;
|
||||
for (;;) {
|
||||
r = mg_tls_recv_msg(c);
|
||||
if (r < 0) return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] == 0x17) {
|
||||
break;
|
||||
} else if (rio->buf[0] == 0x15) {
|
||||
@ -692,11 +667,12 @@ static void mg_tls_server_finish(struct mg_connection *c) {
|
||||
}
|
||||
|
||||
static int mg_tls_client_change_cipher(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &tls->recv;
|
||||
// struct tls_data *tls = c->tls;
|
||||
struct mg_iobuf *rio = &c->rtls;
|
||||
for (;;) {
|
||||
int r = mg_tls_recv_msg(c);
|
||||
if (r < 0) return r;
|
||||
if (!mg_tls_got_msg(c)) {
|
||||
return MG_IO_WAIT;
|
||||
}
|
||||
if (rio->buf[0] == 0x14) { // got a ChangeCipher record
|
||||
break;
|
||||
} else if (rio->buf[0] == 0x15) { // skip Alert records
|
||||
@ -704,7 +680,7 @@ static int mg_tls_client_change_cipher(struct mg_connection *c) {
|
||||
mg_tls_drop_packet(rio);
|
||||
} else {
|
||||
mg_error(c, "unexpected packet");
|
||||
return -1;
|
||||
return MG_IO_ERR;
|
||||
}
|
||||
}
|
||||
// consume ChangeCipher packet
|
||||
@ -785,9 +761,7 @@ void mg_tls_handshake(struct mg_connection *c) {
|
||||
mg_tls_generate_application_keys(c);
|
||||
tls->state = MG_TLS_HS_DONE;
|
||||
// fallthrough
|
||||
case MG_TLS_HS_DONE:
|
||||
c->is_tls_hs = 0;
|
||||
return;
|
||||
case MG_TLS_HS_DONE: c->is_tls_hs = 0; return;
|
||||
}
|
||||
}
|
||||
|
||||
@ -860,7 +834,8 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
|
||||
return;
|
||||
}
|
||||
|
||||
tls->send.align = tls->recv.align = MG_IO_SIZE;
|
||||
// tls->send.align = tls->recv.align = MG_IO_SIZE;
|
||||
tls->send.align = MG_IO_SIZE;
|
||||
c->tls = tls;
|
||||
c->is_tls = c->is_tls_hs = 1;
|
||||
mg_sha256_init(&tls->sha256);
|
||||
@ -870,7 +845,6 @@ void mg_tls_free(struct mg_connection *c) {
|
||||
struct tls_data *tls = c->tls;
|
||||
if (tls != NULL) {
|
||||
mg_iobuf_free(&tls->send);
|
||||
mg_iobuf_free(&tls->recv);
|
||||
free((void *) tls->server_cert_der.ptr);
|
||||
}
|
||||
free(c->tls);
|
||||
@ -880,7 +854,7 @@ void mg_tls_free(struct mg_connection *c) {
|
||||
long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
|
||||
struct tls_data *tls = c->tls;
|
||||
long n = MG_IO_WAIT;
|
||||
if (len > 2048) len = 2048;
|
||||
if (len > MG_IO_SIZE) len = MG_IO_SIZE;
|
||||
mg_tls_encrypt(c, buf, len, 0x17);
|
||||
while (tls->send.len > 0 &&
|
||||
(n = mg_io_send(c, tls->send.buf, tls->send.len)) > 0) {
|
||||
@ -895,8 +869,7 @@ long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
|
||||
}
|
||||
|
||||
size_t mg_tls_pending(struct mg_connection *c) {
|
||||
struct tls_data *tls = (struct tls_data *) c->tls;
|
||||
return tls == NULL ? 0 : tls->recv.len;
|
||||
return mg_tls_got_msg(c) ? 1 : 0;
|
||||
}
|
||||
|
||||
void mg_tls_ctx_init(struct mg_mgr *mgr) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user