Handle WS fragmentation

This commit is contained in:
Sergey Lyubka 2021-08-27 10:25:24 +01:00
parent 7ba6fda236
commit 714b7a8145
7 changed files with 155 additions and 38 deletions

View File

@ -1776,14 +1776,19 @@ size_t mg_iobuf_append(struct mg_iobuf *io, const void *buf, size_t len,
return len;
}
size_t mg_iobuf_delete(struct mg_iobuf *io, size_t len) {
if (len > io->len) len = io->len;
memmove(io->buf, io->buf + len, io->len - len);
zeromem(io->buf + io->len - len, len);
size_t mg_iobuf_del(struct mg_iobuf *io, size_t ofs, size_t len) {
if (ofs > io->len) ofs = io->len;
if (ofs + len > io->len) len = io->len - ofs;
memmove(io->buf + ofs, io->buf + ofs + len, io->len - ofs - len);
zeromem(io->buf + ofs + io->len - len, len);
io->len -= len;
return len;
}
size_t mg_iobuf_delete(struct mg_iobuf *io, size_t len) {
return mg_iobuf_del(io, 0, len);
}
void mg_iobuf_free(struct mg_iobuf *io) {
mg_iobuf_resize(io, 0);
}
@ -4526,9 +4531,9 @@ static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) {
size_t i, n = 0, mask_len = 0;
memset(msg, 0, sizeof(*msg));
if (len >= 2) {
n = buf[1] & 0x7f;
mask_len = buf[1] & WEBSOCKET_FLAGS_MASK_FIN ? 4 : 0;
msg->flags = *(unsigned char *) buf;
n = buf[1] & 0x7f; // Frame length
mask_len = buf[1] & 128 ? 4 : 0; // last bit is a mask bit
msg->flags = buf[0];
if (n < 126 && len >= mask_len) {
msg->data_len = n;
msg->header_len = 2 + mask_len;
@ -4552,7 +4557,7 @@ static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) {
static size_t mkhdr(size_t len, int op, bool is_client, uint8_t *buf) {
size_t n = 0;
buf[0] = (uint8_t) (op | WEBSOCKET_FLAGS_MASK_FIN);
buf[0] = (uint8_t) (op | 128);
if (len < 126) {
buf[1] = (unsigned char) len;
n = 2;
@ -4600,7 +4605,9 @@ size_t mg_ws_send(struct mg_connection *c, const char *buf, size_t len,
static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
void *fn_data) {
struct ws_msg msg;
size_t ofs = (size_t) c->pfn_data;
// assert(ofs < c->recv.len);
if (ev == MG_EV_READ) {
if (!c->is_websocket && c->is_client) {
int n = mg_http_get_request_len(c->recv.buf, c->recv.len);
@ -4623,10 +4630,14 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
}
}
while (ws_process(c->recv.buf, c->recv.len, &msg) > 0) {
char *s = (char *) c->recv.buf + msg.header_len;
while (ws_process(c->recv.buf + ofs, c->recv.len - ofs, &msg) > 0) {
char *s = (char *) c->recv.buf + ofs + msg.header_len;
struct mg_ws_message m = {{s, msg.data_len}, msg.flags};
switch (msg.flags & WEBSOCKET_FLAGS_MASK_OP) {
size_t len = msg.header_len + msg.data_len;
uint8_t final = msg.flags & 128, op = msg.flags & 15;
// LOG(LL_VERBOSE_DEBUG, ("fin %d op %d len %d [%.*s]", final, op,
// (int) m.data.len, (int) m.data.len, m.data.ptr));
switch (op) {
case WEBSOCKET_OP_CONTINUE:
mg_call(c, MG_EV_WS_CTL, &m);
break;
@ -4640,7 +4651,7 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
break;
case WEBSOCKET_OP_TEXT:
case WEBSOCKET_OP_BINARY:
mg_call(c, MG_EV_WS_MSG, &m);
if (final) mg_call(c, MG_EV_WS_MSG, &m);
break;
case WEBSOCKET_OP_CLOSE:
LOG(LL_ERROR, ("%lu Got WS CLOSE", c->id));
@ -4649,10 +4660,30 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
break;
default:
// Per RFC6455, close conn when an unknown op is recvd
mg_error(c, "unknown WS op %d", msg.flags & WEBSOCKET_FLAGS_MASK_OP);
mg_error(c, "unknown WS op %d", op);
break;
}
mg_iobuf_delete(&c->recv, msg.header_len + msg.data_len);
// Handle fragmented frames: strip header, keep in c->recv
if (final == 0 || op == 0) {
if (op) ofs++, len--, msg.header_len--; // First frame
mg_iobuf_del(&c->recv, ofs, msg.header_len); // Strip header
len -= msg.header_len;
ofs += len;
c->pfn_data = (void *) ofs;
// LOG(LL_INFO, ("FRAG %d [%.*s]", (int) ofs, (int) ofs, c->recv.buf));
}
// Remove non-fragmented frame
if (final && op) mg_iobuf_del(&c->recv, ofs, len);
// Last chunk of the fragmented frame
if (final && !op) {
m.flags = c->recv.buf[0];
m.data = mg_str_n((char *) &c->recv.buf[1], (size_t) (ofs - 1));
mg_call(c, MG_EV_WS_MSG, &m);
mg_iobuf_del(&c->recv, 0, ofs);
ofs = 0;
c->pfn_data = NULL;
}
}
}
(void) fn_data;
@ -4692,7 +4723,7 @@ struct mg_connection *mg_ws_connect(struct mg_mgr *mgr, const char *url,
if (buf1 != mem1) free(buf1);
if (buf2 != mem2) free(buf2);
c->pfn = mg_ws_cb;
c->fn_data = fn_data;
c->pfn_data = NULL;
}
return c;
}
@ -4701,6 +4732,7 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm,
const char *fmt, ...) {
struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key");
c->pfn = mg_ws_cb;
c->pfn_data = NULL;
if (wskey == NULL) {
mg_http_reply(c, 426, "", "WS upgrade expected\n");
c->is_draining = 1;

View File

@ -636,7 +636,8 @@ int mg_iobuf_init(struct mg_iobuf *, size_t);
int mg_iobuf_resize(struct mg_iobuf *, size_t);
void mg_iobuf_free(struct mg_iobuf *);
size_t mg_iobuf_append(struct mg_iobuf *, const void *, size_t, size_t);
size_t mg_iobuf_delete(struct mg_iobuf *, size_t);
size_t mg_iobuf_delete(struct mg_iobuf *, size_t len);
size_t mg_iobuf_del(struct mg_iobuf *, size_t ofs, size_t len);
int mg_base64_update(unsigned char p, char *to, int len);
int mg_base64_final(char *to, int len);
@ -869,9 +870,6 @@ void mg_tls_handshake(struct mg_connection *);
#define WEBSOCKET_OP_PING 9
#define WEBSOCKET_OP_PONG 10
#define WEBSOCKET_FLAGS_MASK_FIN 128
#define WEBSOCKET_FLAGS_MASK_OP 15
struct mg_ws_message {

View File

@ -56,14 +56,19 @@ size_t mg_iobuf_append(struct mg_iobuf *io, const void *buf, size_t len,
return len;
}
size_t mg_iobuf_delete(struct mg_iobuf *io, size_t len) {
if (len > io->len) len = io->len;
memmove(io->buf, io->buf + len, io->len - len);
zeromem(io->buf + io->len - len, len);
size_t mg_iobuf_del(struct mg_iobuf *io, size_t ofs, size_t len) {
if (ofs > io->len) ofs = io->len;
if (ofs + len > io->len) len = io->len - ofs;
memmove(io->buf + ofs, io->buf + ofs + len, io->len - ofs - len);
zeromem(io->buf + ofs + io->len - len, len);
io->len -= len;
return len;
}
size_t mg_iobuf_delete(struct mg_iobuf *io, size_t len) {
return mg_iobuf_del(io, 0, len);
}
void mg_iobuf_free(struct mg_iobuf *io) {
mg_iobuf_resize(io, 0);
}

View File

@ -11,4 +11,5 @@ int mg_iobuf_init(struct mg_iobuf *, size_t);
int mg_iobuf_resize(struct mg_iobuf *, size_t);
void mg_iobuf_free(struct mg_iobuf *);
size_t mg_iobuf_append(struct mg_iobuf *, const void *, size_t, size_t);
size_t mg_iobuf_delete(struct mg_iobuf *, size_t);
size_t mg_iobuf_delete(struct mg_iobuf *, size_t len);
size_t mg_iobuf_del(struct mg_iobuf *, size_t ofs, size_t len);

View File

@ -48,9 +48,9 @@ static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) {
size_t i, n = 0, mask_len = 0;
memset(msg, 0, sizeof(*msg));
if (len >= 2) {
n = buf[1] & 0x7f;
mask_len = buf[1] & WEBSOCKET_FLAGS_MASK_FIN ? 4 : 0;
msg->flags = *(unsigned char *) buf;
n = buf[1] & 0x7f; // Frame length
mask_len = buf[1] & 128 ? 4 : 0; // last bit is a mask bit
msg->flags = buf[0];
if (n < 126 && len >= mask_len) {
msg->data_len = n;
msg->header_len = 2 + mask_len;
@ -74,7 +74,7 @@ static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) {
static size_t mkhdr(size_t len, int op, bool is_client, uint8_t *buf) {
size_t n = 0;
buf[0] = (uint8_t) (op | WEBSOCKET_FLAGS_MASK_FIN);
buf[0] = (uint8_t) (op | 128);
if (len < 126) {
buf[1] = (unsigned char) len;
n = 2;
@ -122,7 +122,9 @@ size_t mg_ws_send(struct mg_connection *c, const char *buf, size_t len,
static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
void *fn_data) {
struct ws_msg msg;
size_t ofs = (size_t) c->pfn_data;
// assert(ofs < c->recv.len);
if (ev == MG_EV_READ) {
if (!c->is_websocket && c->is_client) {
int n = mg_http_get_request_len(c->recv.buf, c->recv.len);
@ -145,10 +147,14 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
}
}
while (ws_process(c->recv.buf, c->recv.len, &msg) > 0) {
char *s = (char *) c->recv.buf + msg.header_len;
while (ws_process(c->recv.buf + ofs, c->recv.len - ofs, &msg) > 0) {
char *s = (char *) c->recv.buf + ofs + msg.header_len;
struct mg_ws_message m = {{s, msg.data_len}, msg.flags};
switch (msg.flags & WEBSOCKET_FLAGS_MASK_OP) {
size_t len = msg.header_len + msg.data_len;
uint8_t final = msg.flags & 128, op = msg.flags & 15;
// LOG(LL_VERBOSE_DEBUG, ("fin %d op %d len %d [%.*s]", final, op,
// (int) m.data.len, (int) m.data.len, m.data.ptr));
switch (op) {
case WEBSOCKET_OP_CONTINUE:
mg_call(c, MG_EV_WS_CTL, &m);
break;
@ -162,7 +168,7 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
break;
case WEBSOCKET_OP_TEXT:
case WEBSOCKET_OP_BINARY:
mg_call(c, MG_EV_WS_MSG, &m);
if (final) mg_call(c, MG_EV_WS_MSG, &m);
break;
case WEBSOCKET_OP_CLOSE:
LOG(LL_ERROR, ("%lu Got WS CLOSE", c->id));
@ -171,10 +177,30 @@ static void mg_ws_cb(struct mg_connection *c, int ev, void *ev_data,
break;
default:
// Per RFC6455, close conn when an unknown op is recvd
mg_error(c, "unknown WS op %d", msg.flags & WEBSOCKET_FLAGS_MASK_OP);
mg_error(c, "unknown WS op %d", op);
break;
}
mg_iobuf_delete(&c->recv, msg.header_len + msg.data_len);
// Handle fragmented frames: strip header, keep in c->recv
if (final == 0 || op == 0) {
if (op) ofs++, len--, msg.header_len--; // First frame
mg_iobuf_del(&c->recv, ofs, msg.header_len); // Strip header
len -= msg.header_len;
ofs += len;
c->pfn_data = (void *) ofs;
// LOG(LL_INFO, ("FRAG %d [%.*s]", (int) ofs, (int) ofs, c->recv.buf));
}
// Remove non-fragmented frame
if (final && op) mg_iobuf_del(&c->recv, ofs, len);
// Last chunk of the fragmented frame
if (final && !op) {
m.flags = c->recv.buf[0];
m.data = mg_str_n((char *) &c->recv.buf[1], (size_t) (ofs - 1));
mg_call(c, MG_EV_WS_MSG, &m);
mg_iobuf_del(&c->recv, 0, ofs);
ofs = 0;
c->pfn_data = NULL;
}
}
}
(void) fn_data;
@ -214,7 +240,7 @@ struct mg_connection *mg_ws_connect(struct mg_mgr *mgr, const char *url,
if (buf1 != mem1) free(buf1);
if (buf2 != mem2) free(buf2);
c->pfn = mg_ws_cb;
c->fn_data = fn_data;
c->pfn_data = NULL;
}
return c;
}
@ -223,6 +249,7 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm,
const char *fmt, ...) {
struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key");
c->pfn = mg_ws_cb;
c->pfn_data = NULL;
if (wskey == NULL) {
mg_http_reply(c, 426, "", "WS upgrade expected\n");
c->is_draining = 1;

View File

@ -7,9 +7,6 @@
#define WEBSOCKET_OP_PING 9
#define WEBSOCKET_OP_PONG 10
#define WEBSOCKET_FLAGS_MASK_FIN 128
#define WEBSOCKET_FLAGS_MASK_OP 15
#include "http.h"
struct mg_ws_message {

View File

@ -1474,6 +1474,62 @@ static void test_check_ip_acl(void) {
ASSERT(mg_check_ip_acl(mg_str("-0.0.0.0/0,+1.0.0.0/16"), ip) == 0);
}
static void w3(struct mg_connection *c, int ev, void *ev_data, void *fn_data) {
LOG(LL_INFO, ("ev %d", ev));
if (ev == MG_EV_WS_OPEN) {
mg_ws_send(c, "hi there!", 9, WEBSOCKET_OP_TEXT);
} else if (ev == MG_EV_WS_MSG) {
struct mg_ws_message *wm = (struct mg_ws_message *) ev_data;
ASSERT(mg_strcmp(wm->data, mg_str("lebowski")) == 0);
((int *) fn_data)[0]++;
} else if (ev == MG_EV_CLOSE) {
((int *) fn_data)[0] += 10;
}
}
static void w2(struct mg_connection *c, int ev, void *ev_data, void *fn_data) {
struct mg_str msg = mg_str_n("lebowski", 8);
if (ev == MG_EV_HTTP_MSG) {
mg_ws_upgrade(c, (struct mg_http_message *) ev_data, NULL);
} else if (ev == MG_EV_WS_OPEN) {
mg_ws_send(c, "x", 1, WEBSOCKET_OP_PONG);
} else if (ev == MG_EV_POLL && c->is_websocket) {
size_t ofs, n = (size_t) fn_data;
if (n < msg.len) {
// Send "msg" char by char using fragmented frames
// mg_ws_send() sets the FIN flag in the WS header. Clean it
// to send fragmented packet. Insert PONG messages between frames
uint8_t op = n == 0 ? WEBSOCKET_OP_TEXT : WEBSOCKET_OP_CONTINUE;
mg_ws_send(c, ":->", 3, WEBSOCKET_OP_PING);
ofs = c->send.len;
mg_ws_send(c, &msg.ptr[n], 1, op);
if (n < msg.len - 1) c->send.buf[ofs] = op; // Clear FIN flag
c->fn_data = (void *) (n + 1); // Point to the next char
} else {
mg_ws_send(c, "", 0, WEBSOCKET_OP_CLOSE);
}
} else if (ev == MG_EV_WS_MSG) {
struct mg_ws_message *wm = (struct mg_ws_message *) ev_data;
ASSERT(mg_strcmp(wm->data, mg_str("hi there!")) == 0);
}
}
static void test_ws_fragmentation(void) {
const char *url = "ws://localhost:12357/ws";
struct mg_mgr mgr;
int i, done = 0;
mg_mgr_init(&mgr);
ASSERT(mg_http_listen(&mgr, url, w2, NULL) != NULL);
mg_ws_connect(&mgr, url, w3, &done, "%s", "Sec-WebSocket-Protocol: echo\r\n");
for (i = 0; i < 25; i++) mg_mgr_poll(&mgr, 1);
// LOG(LL_INFO, ("--> %d", done));
ASSERT(done == 11);
mg_mgr_free(&mgr);
ASSERT(mgr.conns == NULL);
}
int main(void) {
mg_log_set("3");
test_check_ip_acl();
@ -1497,6 +1553,7 @@ int main(void) {
test_http_get_var();
test_tls();
test_ws();
test_ws_fragmentation();
test_http_server();
test_http_client();
test_http_no_content_length();