From c2c6ff11ed11b5c65aa5aa85cfcd78576e352a26 Mon Sep 17 00:00:00 2001 From: Sergey Lyubka Date: Wed, 4 Aug 2021 11:38:32 +0100 Subject: [PATCH] Fix #1329 - send MG_EV_WS_OPEN for server connections --- Makefile | 2 +- mongoose.c | 20 ++++++++++++++------ src/http.c | 2 +- src/ws.c | 18 +++++++++++++----- test/unit_test.c | 27 +++++++++++++++++---------- 5 files changed, 46 insertions(+), 23 deletions(-) diff --git a/Makefile b/Makefile index cf1c8a85..d4802fa0 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ SRCS = mongoose.c test/unit_test.c test/packed_fs.c HDRS = $(wildcard src/*.h) -DEFS ?= -DMG_MAX_HTTP_HEADERS=5 -DMG_ENABLE_LINES -DMG_ENABLE_PACKED_FS=1 +DEFS ?= -DMG_MAX_HTTP_HEADERS=7 -DMG_ENABLE_LINES -DMG_ENABLE_PACKED_FS=1 WARN ?= -W -Wall -Werror -Wshadow -Wdouble-promotion -fno-common -Wconversion OPTS ?= -O3 -g3 INCS ?= -Isrc -I. diff --git a/mongoose.c b/mongoose.c index 9561409e..99d8072e 100644 --- a/mongoose.c +++ b/mongoose.c @@ -1511,7 +1511,7 @@ void mg_http_serve_dir(struct mg_connection *c, struct mg_http_message *hm, int flags = uri_to_path(c, hm, opts, root, sizeof(root), path, sizeof(path)); if (flags == 0) return; - LOG(LL_DEBUG, ("root [%s], path [%s] %d", root, path, flags)); + // LOG(LL_DEBUG, ("root [%s], path [%s] %d", root, path, flags)); if (flags & MG_FS_DIR) { listdir(c, hm, opts, path); } else if (opts->ssi_pattern != NULL && @@ -4458,15 +4458,16 @@ struct ws_msg { size_t data_len; }; -static void ws_handshake(struct mg_connection *c, const char *key, - size_t key_len, const char *fmt, va_list ap) { +static void ws_handshake(struct mg_connection *c, const struct mg_str *wskey, + const struct mg_str *wsproto, const char *fmt, + va_list ap) { const char *magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; unsigned char sha[20], b64_sha[30]; char mem[128], *buf = mem; mg_sha1_ctx sha_ctx; mg_sha1_init(&sha_ctx); - mg_sha1_update(&sha_ctx, (unsigned char *) key, key_len); + mg_sha1_update(&sha_ctx, (unsigned char *) wskey->ptr, wskey->len); mg_sha1_update(&sha_ctx, (unsigned char *) magic, 36); mg_sha1_final(sha, &sha_ctx); mg_base64_encode(sha, sizeof(sha), (char *) b64_sha); @@ -4477,9 +4478,14 @@ static void ws_handshake(struct mg_connection *c, const char *key, "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" - "%s\r\n", + "%s", b64_sha, buf); if (buf != mem) free(buf); + if (wsproto != NULL) { + mg_printf(c, "Sec-WebSocket-Protocol: %.*s\r\n", (int) wsproto->len, + wsproto->ptr); + } + mg_send(c, "\r\n", 2); } static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) { @@ -4665,11 +4671,13 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, mg_http_reply(c, 426, "", "WS upgrade expected\n"); c->is_draining = 1; } else { + struct mg_str *wsproto = mg_http_get_header(hm, "Sec-WebSocket-Protocol"); va_list ap; va_start(ap, fmt); - ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); + ws_handshake(c, wskey, wsproto, fmt, ap); va_end(ap); c->is_websocket = 1; + mg_call(c, MG_EV_WS_OPEN, &hm); } } diff --git a/src/http.c b/src/http.c index 3d89c665..dfc30827 100644 --- a/src/http.c +++ b/src/http.c @@ -740,7 +740,7 @@ void mg_http_serve_dir(struct mg_connection *c, struct mg_http_message *hm, int flags = uri_to_path(c, hm, opts, root, sizeof(root), path, sizeof(path)); if (flags == 0) return; - LOG(LL_DEBUG, ("root [%s], path [%s] %d", root, path, flags)); + // LOG(LL_DEBUG, ("root [%s], path [%s] %d", root, path, flags)); if (flags & MG_FS_DIR) { listdir(c, hm, opts, path); } else if (opts->ssi_pattern != NULL && diff --git a/src/ws.c b/src/ws.c index 28448339..1e189f5a 100644 --- a/src/ws.c +++ b/src/ws.c @@ -14,15 +14,16 @@ struct ws_msg { size_t data_len; }; -static void ws_handshake(struct mg_connection *c, const char *key, - size_t key_len, const char *fmt, va_list ap) { +static void ws_handshake(struct mg_connection *c, const struct mg_str *wskey, + const struct mg_str *wsproto, const char *fmt, + va_list ap) { const char *magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; unsigned char sha[20], b64_sha[30]; char mem[128], *buf = mem; mg_sha1_ctx sha_ctx; mg_sha1_init(&sha_ctx); - mg_sha1_update(&sha_ctx, (unsigned char *) key, key_len); + mg_sha1_update(&sha_ctx, (unsigned char *) wskey->ptr, wskey->len); mg_sha1_update(&sha_ctx, (unsigned char *) magic, 36); mg_sha1_final(sha, &sha_ctx); mg_base64_encode(sha, sizeof(sha), (char *) b64_sha); @@ -33,9 +34,14 @@ static void ws_handshake(struct mg_connection *c, const char *key, "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" "Sec-WebSocket-Accept: %s\r\n" - "%s\r\n", + "%s", b64_sha, buf); if (buf != mem) free(buf); + if (wsproto != NULL) { + mg_printf(c, "Sec-WebSocket-Protocol: %.*s\r\n", (int) wsproto->len, + wsproto->ptr); + } + mg_send(c, "\r\n", 2); } static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) { @@ -221,11 +227,13 @@ void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, mg_http_reply(c, 426, "", "WS upgrade expected\n"); c->is_draining = 1; } else { + struct mg_str *wsproto = mg_http_get_header(hm, "Sec-WebSocket-Protocol"); va_list ap; va_start(ap, fmt); - ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); + ws_handshake(c, wskey, wsproto, fmt, ap); va_end(ap); c->is_websocket = 1; + mg_call(c, MG_EV_WS_OPEN, &hm); } } diff --git a/test/unit_test.c b/test/unit_test.c index 0e9cb4e9..0474b5b2 100644 --- a/test/unit_test.c +++ b/test/unit_test.c @@ -377,6 +377,8 @@ static void eh1(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { sopts.extra_headers = "C: D\r\n"; mg_http_serve_dir(c, hm, &sopts); } + } else if (ev == MG_EV_WS_OPEN) { + mg_ws_send(c, "opened", 6, WEBSOCKET_OP_BINARY); } else if (ev == MG_EV_WS_MSG) { struct mg_ws_message *wm = (struct mg_ws_message *) ev_data; mg_ws_send(c, wm->data.ptr, wm->data.len, WEBSOCKET_OP_BINARY); @@ -441,29 +443,34 @@ static int cmpbody(const char *buf, const char *str) { static void wcb(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { if (ev == MG_EV_WS_OPEN) { + struct mg_http_message *hm = (struct mg_http_message *) ev_data; + struct mg_str *wsproto = mg_http_get_header(hm, "Sec-WebSocket-Protocol"); + ASSERT(wsproto != NULL); mg_ws_send(c, "boo", 3, WEBSOCKET_OP_BINARY); mg_ws_send(c, "", 0, WEBSOCKET_OP_PING); + ((int *) fn_data)[0] += 100; } else if (ev == MG_EV_WS_MSG) { struct mg_ws_message *wm = (struct mg_ws_message *) ev_data; - ASSERT(mg_strstr(wm->data, mg_str("boo"))); - mg_ws_send(c, "", 0, WEBSOCKET_OP_CLOSE); // Ask server to close - *(int *) fn_data = 1; + if (mg_strstr(wm->data, mg_str("boo"))) + mg_ws_send(c, "", 0, WEBSOCKET_OP_CLOSE); + ((int *) fn_data)[0]++; } else if (ev == MG_EV_CLOSE) { - *(int *) fn_data = 2; + ((int *) fn_data)[0] += 10; } } static void test_ws(void) { char buf[FETCH_BUF_SIZE]; - const char *url = "ws://LOCALHOST:12343"; + const char *url = "ws://LOCALHOST:12343/ws"; struct mg_mgr mgr; int i, done = 0; mg_mgr_init(&mgr); ASSERT(mg_http_listen(&mgr, url, eh1, NULL) != NULL); - mg_ws_connect(&mgr, url, wcb, &done, "%s", ""); - for (i = 0; i < 20; i++) mg_mgr_poll(&mgr, 1); - ASSERT(done == 2); + mg_ws_connect(&mgr, url, wcb, &done, "%s", "Sec-WebSocket-Protocol: meh\r\n"); + for (i = 0; i < 30; i++) mg_mgr_poll(&mgr, 1); + // LOG(LL_INFO, ("--> %d", done)); + ASSERT(done == 112); // Test that non-WS requests fail ASSERT(fetch(&mgr, buf, url, "GET /ws HTTP/1.0\r\n\n") == 426); @@ -879,11 +886,11 @@ static void test_http_parse(void) { } { - static const char *s = "a b c\na:1\nb:2\nc:3\nd:4\ne:5\nf:6\n\n"; + static const char *s = "a b c\na:1\nb:2\nc:3\nd:4\ne:5\nf:6\ng:7\nh:8\n\n"; ASSERT(mg_http_parse(s, strlen(s), &req) == (int) strlen(s)); ASSERT((v = mg_http_get_header(&req, "e")) != NULL); ASSERT(mg_vcmp(v, "5") == 0); - ASSERT((v = mg_http_get_header(&req, "f")) == NULL); + ASSERT((v = mg_http_get_header(&req, "h")) == NULL); } {