1#![allow(unexpected_cfgs)]
24
25mod web_context;
26
27use std::{
28 cmp::min,
29 pin::Pin,
30 rc::Rc,
31 sync::{
32 atomic::{AtomicBool, Ordering},
33 Mutex,
34 },
35 task::{Context, Poll},
36};
37
38use bytes::BytesMut;
39use futures::{future::Ready, io, prelude::*, task::AtomicWaker};
40use js_sys::Array;
41use libp2p_core::{
42 multiaddr::{Multiaddr, Protocol},
43 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
44};
45use send_wrapper::SendWrapper;
46use wasm_bindgen::prelude::*;
47use web_sys::{CloseEvent, Event, MessageEvent, WebSocket};
48
49use crate::web_context::WebContext;
50
/// A libp2p transport that dials WebSocket multiaddresses through the browser's `WebSocket` API.
///
/// It can only dial: every listen request is rejected as unsupported.
#[derive(Debug, Default)]
pub struct Transport {
    _private: (),
}
73
/// Per-connection cap (1 MiB). It applies both to received bytes waiting in the read buffer and
/// to bytes queued in the browser's send buffer (`bufferedAmount`).
const MAX_BUFFER: usize = 1 << 20;
76
77impl libp2p_core::Transport for Transport {
78 type Output = Connection;
79 type Error = Error;
80 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
81 type Dial = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;
82
83 fn listen_on(
84 &mut self,
85 _: ListenerId,
86 addr: Multiaddr,
87 ) -> Result<(), TransportError<Self::Error>> {
88 Err(TransportError::MultiaddrNotSupported(addr))
89 }
90
91 fn remove_listener(&mut self, _id: ListenerId) -> bool {
92 false
93 }
94
95 fn dial(
96 &mut self,
97 addr: Multiaddr,
98 dial_opts: DialOpts,
99 ) -> Result<Self::Dial, TransportError<Self::Error>> {
100 if dial_opts.role.is_listener() {
101 return Err(TransportError::MultiaddrNotSupported(addr));
102 }
103
104 let url =
105 extract_websocket_url(&addr).ok_or(TransportError::MultiaddrNotSupported(addr))?;
106
107 Ok(async move {
108 let socket = match WebSocket::new(&url) {
109 Ok(ws) => ws,
110 Err(_) => return Err(Error::invalid_websocket_url(&url)),
111 };
112
113 Ok(Connection::new(socket))
114 }
115 .boxed())
116 }
117
118 fn poll(
119 self: Pin<&mut Self>,
120 _cx: &mut Context<'_>,
121 ) -> std::task::Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
122 Poll::Pending
123 }
124}
125
126fn extract_websocket_url(addr: &Multiaddr) -> Option<String> {
128 let mut protocols = addr.iter();
129 let host_port = match (protocols.next(), protocols.next()) {
130 (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
131 format!("{ip}:{port}")
132 }
133 (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
134 format!("[{ip}]:{port}")
135 }
136 (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
137 | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
138 | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
139 format!("{}:{}", &h, port)
140 }
141 _ => return None,
142 };
143
144 let (scheme, wspath) = match (protocols.next(), protocols.next()) {
145 (Some(Protocol::Tls), Some(Protocol::Ws(path))) => ("wss", path.into_owned()),
146 (Some(Protocol::Ws(path)), _) => ("ws", path.into_owned()),
147 (Some(Protocol::Wss(path)), _) => ("wss", path.into_owned()),
148 _ => return None,
149 };
150
151 Some(format!("{scheme}://{host_port}{wspath}"))
152}
153
154#[derive(thiserror::Error, Debug)]
155#[error("{msg}")]
156pub struct Error {
157 msg: String,
158}
159
160impl Error {
161 fn invalid_websocket_url(url: &str) -> Self {
162 Self {
163 msg: format!("Invalid websocket url: {url}"),
164 }
165 }
166}
167
/// A WebSocket connection to a remote peer, exposed as a byte stream via [`AsyncRead`] and
/// [`AsyncWrite`].
pub struct Connection {
    // The `web_sys`/`js_sys` handles inside `Inner` are not `Send`, but the transport's `Dial`
    // future must be. `SendWrapper` bridges that gap; it panics if the value is accessed from a
    // thread other than the one that created it.
    inner: SendWrapper<Inner>,
}
172
/// State of a [`Connection`]. It is only reachable through `SendWrapper`.
///
/// Field order matters: fields are dropped in declaration order, after `Connection::drop` has
/// detached the socket's event handlers.
struct Inner {
    /// The underlying browser socket.
    socket: WebSocket,

    /// Waker for a task waiting for new bytes in `read_buffer`.
    new_data_waker: Rc<AtomicWaker>,
    /// Bytes received from the socket but not yet consumed by `poll_read`. The `message`
    /// handler keeps it at most `MAX_BUFFER` bytes long.
    read_buffer: Rc<Mutex<BytesMut>>,

    /// Waker for a task waiting for the socket to leave the connecting state.
    open_waker: Rc<AtomicWaker>,

    /// Waker for a task waiting for the browser's send buffer (`bufferedAmount`) to drain.
    write_waker: Rc<AtomicWaker>,

    /// Waker for a task waiting in `poll_close` for the socket to finish closing.
    close_waker: Rc<AtomicWaker>,

    /// Set by the `error` handler, or when the remote overflows `read_buffer`. Once set,
    /// `error_barrier` makes I/O fail with `BrokenPipe`.
    errored: Rc<AtomicBool>,

    // The JS side (socket event handlers / the interval timer) calls into these closures, so
    // they must stay alive for as long as they are registered.
    _on_open_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_buffered_amount_low_closure: Rc<Closure<dyn FnMut(Event)>>,
    _on_close_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_error_closure: Rc<Closure<dyn FnMut(CloseEvent)>>,
    _on_message_closure: Rc<Closure<dyn FnMut(MessageEvent)>>,
    /// Handle of the interval running `_on_buffered_amount_low_closure`; cleared on drop.
    buffered_amount_low_interval: i32,
}
201
202impl Inner {
203 fn ready_state(&self) -> ReadyState {
204 match self.socket.ready_state() {
205 0 => ReadyState::Connecting,
206 1 => ReadyState::Open,
207 2 => ReadyState::Closing,
208 3 => ReadyState::Closed,
209 unknown => unreachable!("invalid `ReadyState` value: {unknown}"),
210 }
211 }
212
213 fn poll_open(&mut self, cx: &Context<'_>) -> Poll<io::Result<()>> {
214 match self.ready_state() {
215 ReadyState::Connecting => {
216 self.open_waker.register(cx.waker());
217 Poll::Pending
218 }
219 ReadyState::Open => Poll::Ready(Ok(())),
220 ReadyState::Closed | ReadyState::Closing => {
221 Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
222 }
223 }
224 }
225
226 fn error_barrier(&self) -> io::Result<()> {
227 if self.errored.load(Ordering::SeqCst) {
228 return Err(io::ErrorKind::BrokenPipe.into());
229 }
230
231 Ok(())
232 }
233}
234
/// Typed form of the four numeric `WebSocket.readyState` values (0 to 3).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ReadyState {
    Connecting,
    Open,
    Closing,
    Closed,
}
245
246impl Connection {
247 fn new(socket: WebSocket) -> Self {
248 socket.set_binary_type(web_sys::BinaryType::Arraybuffer);
249
250 let open_waker = Rc::new(AtomicWaker::new());
251 let onopen_closure = Closure::<dyn FnMut(_)>::new({
252 let open_waker = open_waker.clone();
253 move |_| {
254 open_waker.wake();
255 }
256 });
257 socket.set_onopen(Some(onopen_closure.as_ref().unchecked_ref()));
258
259 let close_waker = Rc::new(AtomicWaker::new());
260 let onclose_closure = Closure::<dyn FnMut(_)>::new({
261 let close_waker = close_waker.clone();
262 move |_| {
263 close_waker.wake();
264 }
265 });
266 socket.set_onclose(Some(onclose_closure.as_ref().unchecked_ref()));
267
268 let errored = Rc::new(AtomicBool::new(false));
269 let onerror_closure = Closure::<dyn FnMut(_)>::new({
270 let errored = errored.clone();
271 move |_| {
272 errored.store(true, Ordering::SeqCst);
273 }
274 });
275 socket.set_onerror(Some(onerror_closure.as_ref().unchecked_ref()));
276
277 let read_buffer = Rc::new(Mutex::new(BytesMut::new()));
278 let new_data_waker = Rc::new(AtomicWaker::new());
279 let onmessage_closure = Closure::<dyn FnMut(_)>::new({
280 let read_buffer = read_buffer.clone();
281 let new_data_waker = new_data_waker.clone();
282 let errored = errored.clone();
283 move |e: MessageEvent| {
284 let data = js_sys::Uint8Array::new(&e.data());
285
286 let mut read_buffer = read_buffer.lock().unwrap();
287
288 if read_buffer.len() + data.length() as usize > MAX_BUFFER {
289 tracing::warn!("Remote is overloading us with messages, closing connection");
290 errored.store(true, Ordering::SeqCst);
291
292 return;
293 }
294
295 read_buffer.extend_from_slice(&data.to_vec());
296 new_data_waker.wake();
297 }
298 });
299 socket.set_onmessage(Some(onmessage_closure.as_ref().unchecked_ref()));
300
301 let write_waker = Rc::new(AtomicWaker::new());
302 let on_buffered_amount_low_closure = Closure::<dyn FnMut(_)>::new({
303 let write_waker = write_waker.clone();
304 let socket = socket.clone();
305 move |_| {
306 if socket.buffered_amount() == 0 {
307 write_waker.wake();
308 }
309 }
310 });
311 let buffered_amount_low_interval = WebContext::new()
312 .expect("to have a window or worker context")
313 .set_interval_with_callback_and_timeout_and_arguments(
314 on_buffered_amount_low_closure.as_ref().unchecked_ref(),
315 100,
318 &Array::new(),
319 )
320 .expect("to be able to set an interval");
321
322 Self {
323 inner: SendWrapper::new(Inner {
324 socket,
325 new_data_waker,
326 read_buffer,
327 open_waker,
328 write_waker,
329 close_waker,
330 errored,
331 _on_open_closure: Rc::new(onopen_closure),
332 _on_buffered_amount_low_closure: Rc::new(on_buffered_amount_low_closure),
333 _on_close_closure: Rc::new(onclose_closure),
334 _on_error_closure: Rc::new(onerror_closure),
335 _on_message_closure: Rc::new(onmessage_closure),
336 buffered_amount_low_interval,
337 }),
338 }
339 }
340
341 fn buffered_amount(&self) -> usize {
342 self.inner.socket.buffered_amount() as usize
343 }
344}
345
346impl AsyncRead for Connection {
347 fn poll_read(
348 self: Pin<&mut Self>,
349 cx: &mut Context<'_>,
350 buf: &mut [u8],
351 ) -> Poll<Result<usize, io::Error>> {
352 let this = self.get_mut();
353 this.inner.error_barrier()?;
354 futures::ready!(this.inner.poll_open(cx))?;
355
356 let mut read_buffer = this.inner.read_buffer.lock().unwrap();
357
358 if read_buffer.is_empty() {
359 this.inner.new_data_waker.register(cx.waker());
360 return Poll::Pending;
361 }
362
363 let split_index = min(buf.len(), read_buffer.len());
367
368 let bytes_to_return = read_buffer.split_to(split_index);
369 let len = bytes_to_return.len();
370 buf[..len].copy_from_slice(&bytes_to_return);
371
372 Poll::Ready(Ok(len))
373 }
374}
375
376impl AsyncWrite for Connection {
377 fn poll_write(
378 self: Pin<&mut Self>,
379 cx: &mut Context<'_>,
380 buf: &[u8],
381 ) -> Poll<io::Result<usize>> {
382 let this = self.get_mut();
383
384 this.inner.error_barrier()?;
385 futures::ready!(this.inner.poll_open(cx))?;
386
387 debug_assert!(this.buffered_amount() <= MAX_BUFFER);
388 let remaining_space = MAX_BUFFER - this.buffered_amount();
389
390 if remaining_space == 0 {
391 this.inner.write_waker.register(cx.waker());
392 return Poll::Pending;
393 }
394
395 let bytes_to_send = min(buf.len(), remaining_space);
396
397 if this
398 .inner
399 .socket
400 .send_with_u8_array(&buf[..bytes_to_send])
401 .is_err()
402 {
403 return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
404 }
405
406 Poll::Ready(Ok(bytes_to_send))
407 }
408
409 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
410 if self.buffered_amount() == 0 {
411 return Poll::Ready(Ok(()));
412 }
413
414 self.inner.error_barrier()?;
415
416 self.inner.write_waker.register(cx.waker());
417 Poll::Pending
418 }
419
420 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
421 const REGULAR_CLOSE: u16 = 1000; if self.inner.ready_state() == ReadyState::Closed {
424 return Poll::Ready(Ok(()));
425 }
426
427 self.inner.error_barrier()?;
428
429 if self.inner.ready_state() != ReadyState::Closing {
430 let _ = self
431 .inner
432 .socket
433 .close_with_code_and_reason(REGULAR_CLOSE, "user initiated");
434 }
435
436 self.inner.close_waker.register(cx.waker());
437 Poll::Pending
438 }
439}
440
441impl Drop for Connection {
442 fn drop(&mut self) {
443 self.inner.socket.set_onclose(None);
446 self.inner.socket.set_onerror(None);
447 self.inner.socket.set_onopen(None);
448 self.inner.socket.set_onmessage(None);
449
450 const REGULAR_CLOSE: u16 = 1000; if let ReadyState::Connecting | ReadyState::Open = self.inner.ready_state() {
454 let _ = self
455 .inner
456 .socket
457 .close_with_code_and_reason(REGULAR_CLOSE, "connection dropped");
458 }
459
460 WebContext::new()
461 .expect("to have a window or worker context")
462 .clear_interval_with_handle(self.inner.buffered_amount_low_interval);
463 }
464}
465
#[cfg(test)]
mod tests {
    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a [`Multiaddr`] and converts it with [`extract_websocket_url`].
    fn url_of(addr: &str) -> Option<String> {
        let addr = addr.parse::<Multiaddr>().unwrap();
        extract_websocket_url(&addr)
    }

    #[test]
    fn extract_url() {
        let peer_id = PeerId::random();

        // (multiaddress, expected WebSocket URL)
        let convertible: Vec<(String, &str)> = vec![
            ("/dns4/example.com/tcp/2222/tls/ws".into(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/tls/ws".into(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/tls/ws".into(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/wss".into(), "wss://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}"),
                "wss://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/wss".into(), "wss://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/wss".into(), "wss://[::1]:2222/"),
            ("/dns4/example.com/tcp/2222/ws".into(), "ws://example.com:2222/"),
            (
                format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}"),
                "ws://example.com:2222/",
            ),
            ("/ip4/127.0.0.1/tcp/2222/ws".into(), "ws://127.0.0.1:2222/"),
            ("/ip6/::1/tcp/2222/ws".into(), "ws://[::1]:2222/"),
            ("/ip4/127.0.0.1/tcp/2222/ws".into(), "ws://127.0.0.1:2222/"),
        ];
        for (addr, expected) in &convertible {
            assert_eq!(url_of(addr).as_deref(), Some(*expected), "{addr}");
        }

        let unsupported = [
            "/ip4/127.0.0.1/tcp/2222/tls/wss",
            "/dnsaddr/example.com/tcp/2222/ws",
            "/ip4/127.0.0.1/tcp/2222",
        ];
        for addr in unsupported {
            assert!(url_of(addr).is_none(), "{addr}");
        }
    }
}