libp2p_websocket_websys/
lib.rs

1// Copyright (C) 2023 Vince Vasta
2//
3// Permission is hereby granted, free of charge, to any person obtaining a copy
4// of this software and associated documentation files (the "Software"), to deal
5// in the Software without restriction, including without limitation the rights
6// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7// copies of the Software, and to permit persons to whom the Software is
8// furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in all
11// copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19// SOFTWARE.
20
21//! Libp2p websocket transports built on [web-sys](https://rustwasm.github.io/wasm-bindgen/web-sys/index.html).
22
23#![allow(unexpected_cfgs)]
24
25mod web_context;
26
27use std::{
28    cmp::min,
29    pin::Pin,
30    rc::Rc,
31    sync::{
32        atomic::{AtomicBool, Ordering},
33        Mutex,
34    },
35    task::{Context, Poll},
36};
37
38use bytes::BytesMut;
39use futures::{future::Ready, io, prelude::*, task::AtomicWaker};
40use js_sys::Array;
41use libp2p_core::{
42    multiaddr::{Multiaddr, Protocol},
43    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
44};
45use send_wrapper::SendWrapper;
46use wasm_bindgen::prelude::*;
47use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
48
49use crate::web_context::WebContext;
50
/// A WebSocket transport for `wasm` targets, backed by the browser's (or worker's)
/// native `WebSocket` object.
///
/// It can only dial; listening is not possible from inside a browser.
///
/// ## Example
///
/// To create an authenticated transport instance with Noise protocol and Yamux:
///
/// ```
/// # use libp2p_core::{upgrade::Version, Transport};
/// # use libp2p_identity::Keypair;
/// # use libp2p_yamux as yamux;
/// # use libp2p_noise as noise;
/// let local_key = Keypair::generate_ed25519();
/// let transport = libp2p_websocket_websys::Transport::default()
///     .upgrade(Version::V1)
///     .authenticate(noise::Config::new(&local_key).unwrap())
///     .multiplex(yamux::Config::default())
///     .boxed();
/// ```
#[derive(Default)]
pub struct Transport {
    // Keeps the struct non-constructible outside this crate except via `Default`.
    _private: (),
}
73
/// Per-connection buffering limit in bytes (1 MiB, chosen arbitrarily).
///
/// It bounds both directions. Incoming bytes that have not been read yet may not exceed it;
/// otherwise the connection is flagged as errored. Outgoing bytes still queued in the
/// socket's `bufferedAmount` may not exceed it either, which throttles the writer.
const MAX_BUFFER: usize = 1 << 20;
76
77impl libp2p_core::Transport for Transport {
78    type Output = Connection;
79    type Error = Error;
80    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
81    type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
82
83    fn listen_on(
84        &mut self,
85        _: ListenerId,
86        addr: Multiaddr,
87    ) -> Result<(), TransportError<Self::Error>> {
88        Err(TransportError::MultiaddrNotSupported(addr))
89    }
90
91    fn remove_listener(&mut self, _id: ListenerId) -> bool {
92        false
93    }
94
95    fn dial(
96        &mut self,
97        addr: Multiaddr,
98        dial_opts: DialOpts,
99    ) -> Result<Self::Dial, TransportError<Self::Error>> {
100        if dial_opts.role.is_listener() {
101            return Err(TransportError::MultiaddrNotSupported(addr));
102        }
103
104        let url =
105            extract_websocket_url(&addr).ok_or(TransportError::MultiaddrNotSupported(addr))?;
106
107        Ok(async move {
108            let socket = match WebSocket::new(&url) {
109                Ok(ws) => ws,
110                Err(_) => return Err(Error::invalid_websocket_url(&url)),
111            };
112
113            Ok(Connection::new(socket))
114        }
115        .boxed())
116    }
117
118    fn poll(
119        self: Pin<&mut Self>,
120        _cx: &mut Context<'_>,
121    ) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
122        Poll::Pending
123    }
124}
125
126// Try to convert Multiaddr to a Websocket url.
127fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
128    let mut protocols = addr.iter();
129    let host_port = match (protocols.next(), protocols.next()) {
130        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
131            format!("{ip}:{port}")
132        }
133        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
134            format!("[{ip}]:{port}")
135        }
136        (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
137        | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
138        | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
139            format!("{}:{}", &h, port)
140        }
141        _ => return None,
142    };
143
144    let (scheme, wspath) = match (protocols.next(), protocols.next()) {
145        (Some(Protocol::Tls), Some(Protocol::Ws(path))) => ("wss", path.into_owned()),
146        (Some(Protocol::Ws(path)), _) => ("ws", path.into_owned()),
147        (Some(Protocol::Wss(path)), _) => ("wss", path.into_owned()),
148        _ => return None,
149    };
150
151    Some(format!("{scheme}://{host_port}{wspath}"))
152}
153
154#[derive(thiserror::Error, Debug)]
155#[error("{msg}")]
156pub struct Error {
157    msg: String,
158}
159
160impl Error {
161    fn invalid_websocket_url(url: &str) -> Self {
162        Self {
163            msg: format!("Invalid websocket url: {url}"),
164        }
165    }
166}
167
/// A Websocket connection created by the [`Transport`].
///
/// Implements [`AsyncRead`] and [`AsyncWrite`] on top of a browser `WebSocket`.
pub struct Connection {
    // The JS objects inside are not `Send`; `SendWrapper` makes the connection `Send` and
    // (per the `send_wrapper` crate) panics if accessed from a thread other than the one that
    // created it.
    inner: SendWrapper<Inner>,
}

/// Shared state between the [`Connection`] and the JS event handlers installed on `socket`.
struct Inner {
    /// The underlying browser WebSocket.
    socket: WebSocket,

    /// Waker for a reader waiting for incoming data; woken by the `onmessage` handler.
    new_data_waker: Rc<AtomicWaker>,
    /// Received payload bytes that have not been handed to a reader yet (at most
    /// [`MAX_BUFFER`] bytes).
    read_buffer: Rc<Mutex<BytesMut>>,

    /// Waker for when we are waiting for the WebSocket to be opened.
    open_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting to write (again) to the WebSocket because we previously
    /// exceeded the [`MAX_BUFFER`] threshold.
    write_waker: Rc<AtomicWaker>,

    /// Waker for when we are waiting for the WebSocket to be closed.
    close_waker: Rc<AtomicWaker>,

    /// Whether the connection errored (set by `onerror`, or when the remote overflows the
    /// read buffer).
    errored: Rc<AtomicBool>,

    // Keep the JS closures alive for as long as the socket may invoke them; `Drop for
    // Connection` unsets the handlers before these are dropped.
    // NOTE(review): these used to be documented as wrapped in an [`Rc`] "so we can implement
    // [`Clone`]", but no `Clone` impl is visible for `Inner` or `Connection` — confirm whether
    // the `Rc` is still needed.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    /// Handle of the `setInterval` timer driving `_on_buffered_amount_low_closure`; cleared on
    /// drop.
    buffered_amount_low_interval: i32,
}
201
202impl Inner {
203    fn ready_state(&self) -> ReadyState {
204        match self.socket.ready_state() {
205            0 => ReadyState::Connecting,
206            1 => ReadyState::Open,
207            2 => ReadyState::Closing,
208            3 => ReadyState::Closed,
209            unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
210        }
211    }
212
213    fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
214        match self.ready_state() {
215            ReadyState::Connecting => {
216                self.open_waker.register(cx.waker());
217                Poll::Pending
218            }
219            ReadyState::Open => Poll::Ready(Ok(())),
220            ReadyState::Closed | ReadyState::Closing => {
221                Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
222            }
223        }
224    }
225
226    fn error_barrier(&self) -> io::Result<()> {
227        if self.errored.load(Ordering::SeqCst) {
228            return Err(io::ErrorKind::BrokenPipe.into());
229        }
230
231        Ok(())
232    }
233}
234
/// The state of the WebSocket, mirroring the numeric values of its `readyState`.
///
/// See <https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/readyState>.
#[derive(PartialEq)]
enum ReadyState {
    /// `readyState == 0`: the connection is not open yet.
    Connecting = 0,
    /// `readyState == 1`: the connection is open and ready to communicate.
    Open = 1,
    /// `readyState == 2`: the connection is in the process of closing.
    Closing = 2,
    /// `readyState == 3`: the connection is closed or could not be opened.
    Closed = 3,
}
245
246impl Connection {
247    fn new(socket: WebSocket) -> Self {
248        socket.set_binary_type(web_sys::BinaryType::Arraybuffer);
249
250        let open_waker = Rc::new(AtomicWaker::new());
251        let onopen_closure = Closure::<dyn FnMut(_)>::new({
252            let open_waker = open_waker.clone();
253            move |_| {
254                open_waker.wake();
255            }
256        });
257        socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
258
259        let close_waker = Rc::new(AtomicWaker::new());
260        let onclose_closure = Closure::<dyn FnMut(_)>::new({
261            let close_waker = close_waker.clone();
262            move |_| {
263                close_waker.wake();
264            }
265        });
266        socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
267
268        let errored = Rc::new(AtomicBool::new(false));
269        let onerror_closure = Closure::<dyn FnMut(_)>::new({
270            let errored = errored.clone();
271            move |_| {
272                errored.store(true, Ordering::SeqCst);
273            }
274        });
275        socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
276
277        let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
278        let new_data_waker = Rc::new(AtomicWaker::new());
279        let onmessage_closure = Closure::<dyn FnMut(_)>::new({
280            let read_buffer = read_buffer.clone();
281            let new_data_waker = new_data_waker.clone();
282            let errored = errored.clone();
283            move |e: MessageEvent| {
284                let data = js_sys::Uint8Array::new(&e.data());
285
286                let mut read_buffer = read_buffer.lock().unwrap();
287
288                if read_buffer.len() + data.length() as usize > MAX_BUFFER {
289                    tracing::warn!("Remote is overloading us with messages, closing connection");
290                    errored.store(true, Ordering::SeqCst);
291
292                    return;
293                }
294
295                read_buffer.extend_from_slice(&data.to_vec());
296                new_data_waker.wake();
297            }
298        });
299        socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
300
301        let write_waker = Rc::new(AtomicWaker::new());
302        let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
303            let write_waker = write_waker.clone();
304            let socket = socket.clone();
305            move |_| {
306                if socket.buffered_amount() == 0 {
307                    write_waker.wake();
308                }
309            }
310        });
311        let buffered_amount_low_interval = WebContext::new()
312            .expect("to have a window or worker context")
313            .set_interval_with_callback_and_timeout_and_arguments(
314                on_buffered_amount_low_closure.as_ref().unchecked_ref(),
315                // Chosen arbitrarily and likely worth tuning. Due to low impact of the /ws
316                // transport, no further effort was invested at the time.
317                100,
318                &Array::new(),
319            )
320            .expect("to be able to set an interval");
321
322        Self {
323            inner: SendWrapper::new(Inner {
324                socket,
325                new_data_waker,
326                read_buffer,
327                open_waker,
328                write_waker,
329                close_waker,
330                errored,
331                _on_open_closure: Rc::new(onopen_closure),
332                _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
333                _on_close_closure: Rc::new(onclose_closure),
334                _on_error_closure: Rc::new(onerror_closure),
335                _on_message_closure: Rc::new(onmessage_closure),
336                buffered_amount_low_interval,
337            }),
338        }
339    }
340
341    fn buffered_amount(&self) -> usize {
342        self.inner.socket.buffered_amount() as usize
343    }
344}
345
346impl AsyncRead for Connection {
347    fn poll_read(
348        self: Pin<&mut Self>,
349        cx: &mut Context<'_>,
350        buf: &mut [u8],
351    ) -> Poll<Result<usize, io::Error>> {
352        let this = self.get_mut();
353        this.inner.error_barrier()?;
354        futures::ready!(this.inner.poll_open(cx))?;
355
356        let mut read_buffer = this.inner.read_buffer.lock().unwrap();
357
358        if read_buffer.is_empty() {
359            this.inner.new_data_waker.register(cx.waker());
360            return Poll::Pending;
361        }
362
363        // Ensure that we:
364        // - at most return what the caller can read (`buf.len()`)
365        // - at most what we have (`read_buffer.len()`)
366        let split_index = min(buf.len(), read_buffer.len());
367
368        let bytes_to_return = read_buffer.split_to(split_index);
369        let len = bytes_to_return.len();
370        buf[..len].copy_from_slice(&bytes_to_return);
371
372        Poll::Ready(Ok(len))
373    }
374}
375
376impl AsyncWrite for Connection {
377    fn poll_write(
378        self: Pin<&mut Self>,
379        cx: &mut Context<'_>,
380        buf: &[u8],
381    ) -> Poll<io::Result<usize>> {
382        let this = self.get_mut();
383
384        this.inner.error_barrier()?;
385        futures::ready!(this.inner.poll_open(cx))?;
386
387        debug_assert!(this.buffered_amount() <= MAX_BUFFER);
388        let remaining_space = MAX_BUFFER - this.buffered_amount();
389
390        if remaining_space == 0 {
391            this.inner.write_waker.register(cx.waker());
392            return Poll::Pending;
393        }
394
395        let bytes_to_send = min(buf.len(), remaining_space);
396
397        if this
398            .inner
399            .socket
400            .send_with_u8_array(&buf[..bytes_to_send])
401            .is_err()
402        {
403            return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
404        }
405
406        Poll::Ready(Ok(bytes_to_send))
407    }
408
409    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
410        if self.buffered_amount() == 0 {
411            return Poll::Ready(Ok(()));
412        }
413
414        self.inner.error_barrier()?;
415
416        self.inner.write_waker.register(cx.waker());
417        Poll::Pending
418    }
419
420    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
421        const REGULAR_CLOSE: u16 = 1000; // See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
422
423        if self.inner.ready_state() == ReadyState::Closed {
424            return Poll::Ready(Ok(()));
425        }
426
427        self.inner.error_barrier()?;
428
429        if self.inner.ready_state() != ReadyState::Closing {
430            let _ = self
431                .inner
432                .socket
433                .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
434        }
435
436        self.inner.close_waker.register(cx.waker());
437        Poll::Pending
438    }
439}
440
441impl Drop for Connection {
442    fn drop(&mut self) {
443        // Unset event listeners, as otherwise they will be called by JS after the handlers have
444        // already been dropped.
445        self.inner.socket.set_onclose(None);
446        self.inner.socket.set_onerror(None);
447        self.inner.socket.set_onopen(None);
448        self.inner.socket.set_onmessage(None);
449
450        // In browsers, userland code is not allowed to use any other status code than 1000: https://websockets.spec.whatwg.org/#dom-websocket-close
451        const REGULAR_CLOSE: u16 = 1000; // See https://www.rfc-editor.org/rfc/rfc6455.html#section-7.4.1.
452
453        if let ReadyState::Connecting | ReadyState::Open = self.inner.ready_state() {
454            let _ = self
455                .inner
456                .socket
457                .close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
458        }
459
460        WebContext::new()
461            .expect("to have a window or worker context")
462            .clear_interval_with_handle(self.inner.buffered_amount_low_interval);
463    }
464}
465
#[cfg(test)]
mod tests {
    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a [`Multiaddr`] and runs it through [`extract_websocket_url`].
    fn url_for(addr: &str) -> Option<String> {
        let addr = addr.parse::<Multiaddr>().unwrap();
        extract_websocket_url(&addr)
    }

    #[test]
    fn extract_url() {
        let peer_id = PeerId::random();

        let accepted = [
            // `/tls/ws`: DNS, DNS with `/p2p`, IPv4 and IPv6 hosts.
            ("/dns4/example.com/tcp/2222/tls/ws".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/tls/ws".to_owned(), "wss://[::1]:2222/"),
            // `/wss`: same host variants.
            ("/dns4/example.com/tcp/2222/wss".to_owned(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/wss".to_owned(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/wss".to_owned(), "wss://[::1]:2222/"),
            // `/ws`: same host variants.
            ("/dns4/example.com/tcp/2222/ws".to_owned(), "ws://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}"),
                "ws://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/ws".to_owned(), "ws://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/ws".to_owned(), "ws://[::1]:2222/"),
        ];

        for (addr, expected) in &accepted {
            assert_eq!(url_for(addr).as_deref(), Some(*expected), "address: {addr}");
        }

        let rejected = [
            // `/tls/wss` is not a valid combination.
            "/ip4/127.0.0.1/tcp/2222/tls/wss",
            // `/dnsaddr` is not supported.
            "/dnsaddr/example.com/tcp/2222/ws",
            // No WebSocket protocol at all.
            "/ip4/127.0.0.1/tcp/2222",
        ];

        for addr in rejected {
            assert!(url_for(addr).is_none(), "address: {addr}");
        }
    }
}