libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    borrow::Cow,
23    collections::HashMap,
24    fmt, io, mem,
25    net::IpAddr,
26    ops::DerefMut,
27    pin::Pin,
28    sync::Arc,
29    task::{Context, Poll},
30};
31
32use either::Either;
33use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
34use futures_rustls::{client, rustls::pki_types::ServerName, server};
35use libp2p_core::{
36    multiaddr::{Multiaddr, Protocol},
37    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
38    Transport,
39};
40use parking_lot::Mutex;
41use soketto::{
42    connection::{self, CloseReason},
43    handshake,
44};
45use url::Url;
46
47use crate::{error::Error, quicksink, tls};
48
49/// Max. number of payload bytes of a single frame.
50const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
51
52/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
53/// frame payloads which does not implement [`AsyncRead`] or
54/// [`AsyncWrite`]. See [`crate::Config`] if you require the latter.
55#[deprecated = "Use `Config` instead"]
56pub type WsConfig<T> = Config<T>;
57
58#[derive(Debug)]
59pub struct Config<T> {
60    transport: Arc<Mutex<T>>,
61    max_data_size: usize,
62    tls_config: tls::Config,
63    max_redirects: u8,
64    /// Websocket protocol of the inner listener.
65    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
66}
67
68impl<T> Config<T>
69where
70    T: Send,
71{
72    /// Create a new websocket transport based on another transport.
73    pub fn new(transport: T) -> Self {
74        Config {
75            transport: Arc::new(Mutex::new(transport)),
76            max_data_size: MAX_DATA_SIZE,
77            tls_config: tls::Config::client(),
78            max_redirects: 0,
79            listener_protos: HashMap::new(),
80        }
81    }
82
83    /// Return the configured maximum number of redirects.
84    pub fn max_redirects(&self) -> u8 {
85        self.max_redirects
86    }
87
88    /// Set max. number of redirects to follow.
89    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
90        self.max_redirects = max;
91        self
92    }
93
94    /// Get the max. frame data size we support.
95    pub fn max_data_size(&self) -> usize {
96        self.max_data_size
97    }
98
99    /// Set the max. frame data size we support.
100    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
101        self.max_data_size = size;
102        self
103    }
104
105    /// Set the TLS configuration if TLS support is desired.
106    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
107        self.tls_config = c;
108        self
109    }
110}
111
112type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
113
114impl<T> Transport for Config<T>
115where
116    T: Transport + Send + Unpin + 'static,
117    T::Error: Send + 'static,
118    T::Dial: Send + 'static,
119    T::ListenerUpgrade: Send + 'static,
120    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
121{
122    type Output = Connection<T::Output>;
123    type Error = Error<T::Error>;
124    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
125    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
126
127    fn listen_on(
128        &mut self,
129        id: ListenerId,
130        addr: Multiaddr,
131    ) -> Result<(), TransportError<Self::Error>> {
132        let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
133            tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
134            TransportError::MultiaddrNotSupported(addr.clone())
135        })?;
136
137        if proto.use_tls() && self.tls_config.server.is_none() {
138            tracing::debug!(
139                "{} address but TLS server support is not configured",
140                proto.prefix()
141            );
142            return Err(TransportError::MultiaddrNotSupported(addr));
143        }
144
145        match self.transport.lock().listen_on(id, inner_addr) {
146            Ok(()) => {
147                self.listener_protos.insert(id, proto);
148                Ok(())
149            }
150            Err(e) => Err(e.map(Error::Transport)),
151        }
152    }
153
154    fn remove_listener(&mut self, id: ListenerId) -> bool {
155        self.transport.lock().remove_listener(id)
156    }
157
158    fn dial(
159        &mut self,
160        addr: Multiaddr,
161        dial_opts: DialOpts,
162    ) -> Result<Self::Dial, TransportError<Self::Error>> {
163        self.do_dial(addr, dial_opts)
164    }
165
166    fn poll(
167        mut self: Pin<&mut Self>,
168        cx: &mut Context<'_>,
169    ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
170        let inner_event = {
171            let mut transport = self.transport.lock();
172            match Transport::poll(Pin::new(transport.deref_mut()), cx) {
173                Poll::Ready(ev) => ev,
174                Poll::Pending => return Poll::Pending,
175            }
176        };
177        let event = match inner_event {
178            TransportEvent::NewAddress {
179                listener_id,
180                mut listen_addr,
181            } => {
182                // Append the ws / wss protocol back to the inner address.
183                self.listener_protos
184                    .get(&listener_id)
185                    .expect("Protocol was inserted in Transport::listen_on.")
186                    .append_on_addr(&mut listen_addr);
187                tracing::debug!(address=%listen_addr, "Listening on address");
188                TransportEvent::NewAddress {
189                    listener_id,
190                    listen_addr,
191                }
192            }
193            TransportEvent::AddressExpired {
194                listener_id,
195                mut listen_addr,
196            } => {
197                self.listener_protos
198                    .get(&listener_id)
199                    .expect("Protocol was inserted in Transport::listen_on.")
200                    .append_on_addr(&mut listen_addr);
201                TransportEvent::AddressExpired {
202                    listener_id,
203                    listen_addr,
204                }
205            }
206            TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
207                listener_id,
208                error: Error::Transport(error),
209            },
210            TransportEvent::ListenerClosed {
211                listener_id,
212                reason,
213            } => {
214                self.listener_protos
215                    .remove(&listener_id)
216                    .expect("Protocol was inserted in Transport::listen_on.");
217                TransportEvent::ListenerClosed {
218                    listener_id,
219                    reason: reason.map_err(Error::Transport),
220                }
221            }
222            TransportEvent::Incoming {
223                listener_id,
224                upgrade,
225                mut local_addr,
226                mut send_back_addr,
227            } => {
228                let proto = self
229                    .listener_protos
230                    .get(&listener_id)
231                    .expect("Protocol was inserted in Transport::listen_on.");
232                let use_tls = proto.use_tls();
233                proto.append_on_addr(&mut local_addr);
234                proto.append_on_addr(&mut send_back_addr);
235                let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
236                TransportEvent::Incoming {
237                    listener_id,
238                    upgrade,
239                    local_addr,
240                    send_back_addr,
241                }
242            }
243        };
244        Poll::Ready(event)
245    }
246}
247
248impl<T> Config<T>
249where
250    T: Transport + Send + Unpin + 'static,
251    T::Error: Send + 'static,
252    T::Dial: Send + 'static,
253    T::ListenerUpgrade: Send + 'static,
254    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
255{
256    fn do_dial(
257        &mut self,
258        addr: Multiaddr,
259        dial_opts: DialOpts,
260    ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
261        let mut addr = match parse_ws_dial_addr(addr) {
262            Ok(addr) => addr,
263            Err(Error::InvalidMultiaddr(a)) => {
264                return Err(TransportError::MultiaddrNotSupported(a))
265            }
266            Err(e) => return Err(TransportError::Other(e)),
267        };
268
269        // We are looping here in order to follow redirects (if any):
270        let mut remaining_redirects = self.max_redirects;
271
272        let transport = self.transport.clone();
273        let tls_config = self.tls_config.clone();
274        let max_redirects = self.max_redirects;
275
276        let future = async move {
277            loop {
278                match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
279                {
280                    Ok(Either::Left(redirect)) => {
281                        if remaining_redirects == 0 {
282                            tracing::debug!(%max_redirects, "Too many redirects");
283                            return Err(Error::TooManyRedirects);
284                        }
285                        remaining_redirects -= 1;
286                        addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
287                    }
288                    Ok(Either::Right(conn)) => return Ok(conn),
289                    Err(e) => return Err(e),
290                }
291            }
292        };
293
294        Ok(Box::pin(future))
295    }
296
297    /// Attempts to dial the given address and perform a websocket handshake.
298    async fn dial_once(
299        transport: Arc<Mutex<T>>,
300        addr: WsAddress,
301        tls_config: tls::Config,
302        dial_opts: DialOpts,
303    ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
304        tracing::trace!(address=?addr, "Dialing websocket address");
305
306        let dial = transport
307            .lock()
308            .dial(addr.tcp_addr, dial_opts)
309            .map_err(|e| match e {
310                TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
311                TransportError::Other(e) => Error::Transport(e),
312            })?;
313
314        let stream = dial.map_err(Error::Transport).await?;
315        tracing::trace!(port=%addr.host_port, "TCP connection established");
316
317        let stream = if addr.use_tls {
318            // begin TLS session
319            tracing::trace!(?addr.server_name, "Starting TLS handshake");
320            let stream = tls_config
321                .client
322                .connect(addr.server_name.clone(), stream)
323                .map_err(|e| {
324                    tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
325                    Error::Tls(tls::Error::from(e))
326                })
327                .await?;
328
329            let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
330            stream
331        } else {
332            // continue with plain stream
333            future::Either::Right(stream)
334        };
335
336        tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
337
338        let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
339
340        match client
341            .handshake()
342            .map_err(|e| Error::Handshake(Box::new(e)))
343            .await?
344        {
345            handshake::ServerResponse::Redirect {
346                status_code,
347                location,
348            } => {
349                tracing::debug!(
350                    %status_code,
351                    %location,
352                    "received redirect"
353                );
354                Ok(Either::Left(location))
355            }
356            handshake::ServerResponse::Rejected { status_code } => {
357                let msg = format!("server rejected handshake; status code = {status_code}");
358                Err(Error::Handshake(msg.into()))
359            }
360            handshake::ServerResponse::Accepted { .. } => {
361                tracing::trace!(port=%addr.host_port, "websocket handshake successful");
362                Ok(Either::Right(Connection::new(client.into_builder())))
363            }
364        }
365    }
366
367    fn map_upgrade(
368        &self,
369        upgrade: T::ListenerUpgrade,
370        remote_addr: Multiaddr,
371        use_tls: bool,
372    ) -> <Self as Transport>::ListenerUpgrade {
373        let remote_addr2 = remote_addr.clone(); // used for logging
374        let tls_config = self.tls_config.clone();
375        let max_size = self.max_data_size;
376
377        async move {
378            let stream = upgrade.map_err(Error::Transport).await?;
379            tracing::trace!(address=%remote_addr, "incoming connection from address");
380
381            let stream = if use_tls {
382                // begin TLS session
383                let server = tls_config
384                    .server
385                    .expect("for use_tls we checked server is not none");
386
387                tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
388
389                let stream = server
390                    .accept(stream)
391                    .map_err(move |e| {
392                        tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
393                        Error::Tls(tls::Error::from(e))
394                    })
395                    .await?;
396
397                let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
398
399                stream
400            } else {
401                // continue with plain stream
402                future::Either::Right(stream)
403            };
404
405            tracing::trace!(
406                address=%remote_addr2,
407                "receiving websocket handshake request from address"
408            );
409
410            let mut server = handshake::Server::new(stream);
411
412            let ws_key = {
413                let request = server
414                    .receive_request()
415                    .map_err(|e| Error::Handshake(Box::new(e)))
416                    .await?;
417                request.key()
418            };
419
420            tracing::trace!(
421                address=%remote_addr2,
422                "accepting websocket handshake request from address"
423            );
424
425            let response = handshake::server::Response::Accept {
426                key: ws_key,
427                protocol: None,
428            };
429
430            server
431                .send_response(&response)
432                .map_err(|e| Error::Handshake(Box::new(e)))
433                .await?;
434
435            let conn = {
436                let mut builder = server.into_builder();
437                builder.set_max_message_size(max_size);
438                builder.set_max_frame_size(max_size);
439                Connection::new(builder)
440            };
441
442            Ok(conn)
443        }
444        .boxed()
445    }
446}
447
448#[derive(Debug, PartialEq)]
449pub(crate) enum WsListenProto<'a> {
450    Ws(Cow<'a, str>),
451    Wss(Cow<'a, str>),
452    TlsWs(Cow<'a, str>),
453}
454
455impl WsListenProto<'_> {
456    pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
457        match self {
458            WsListenProto::Ws(path) => {
459                addr.push(Protocol::Ws(path.clone()));
460            }
461            // `/tls/ws` and `/wss` are equivalend, however we regenerate
462            // the one that user passed at `listen_on` for backward compatibility.
463            WsListenProto::Wss(path) => {
464                addr.push(Protocol::Wss(path.clone()));
465            }
466            WsListenProto::TlsWs(path) => {
467                addr.push(Protocol::Tls);
468                addr.push(Protocol::Ws(path.clone()));
469            }
470        }
471    }
472
473    pub(crate) fn use_tls(&self) -> bool {
474        match self {
475            WsListenProto::Ws(_) => false,
476            WsListenProto::Wss(_) => true,
477            WsListenProto::TlsWs(_) => true,
478        }
479    }
480
481    pub(crate) fn prefix(&self) -> &'static str {
482        match self {
483            WsListenProto::Ws(_) => "/ws",
484            WsListenProto::Wss(_) => "/wss",
485            WsListenProto::TlsWs(_) => "/tls/ws",
486        }
487    }
488}
489
490#[derive(Debug)]
491struct WsAddress {
492    host_port: String,
493    path: String,
494    server_name: ServerName<'static>,
495    use_tls: bool,
496    tcp_addr: Multiaddr,
497}
498
499/// Tries to parse the given `Multiaddr` into a `WsAddress` used
500/// for dialing.
501///
502/// Fails if the given `Multiaddr` does not represent a TCP/IP-based
503/// websocket protocol stack.
504fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
505    // The encapsulating protocol must be based on TCP/IP, possibly via DNS.
506    // We peek at it in order to learn the hostname and port to use for
507    // the websocket handshake.
508    let mut protocols = addr.iter();
509    let mut ip = protocols.next();
510    let mut tcp = protocols.next();
511    let (host_port, server_name) = loop {
512        match (ip, tcp) {
513            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
514                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
515                break (format!("{ip}:{port}"), server_name);
516            }
517            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
518                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
519                break (format!("[{ip}]:{port}"), server_name);
520            }
521            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
522            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
523            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
524                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
525            }
526            (Some(_), Some(p)) => {
527                ip = Some(p);
528                tcp = protocols.next();
529            }
530            _ => return Err(Error::InvalidMultiaddr(addr)),
531        }
532    };
533
534    // Now consume the `Ws` / `Wss` protocol from the end of the address,
535    // preserving the trailing `P2p` protocol that identifies the remote,
536    // if any.
537    let mut protocols = addr.clone();
538    let mut p2p = None;
539    let (use_tls, path) = loop {
540        match protocols.pop() {
541            p @ Some(Protocol::P2p(_)) => p2p = p,
542            Some(Protocol::Ws(path)) => match protocols.pop() {
543                Some(Protocol::Tls) => break (true, path.into_owned()),
544                Some(p) => {
545                    protocols.push(p);
546                    break (false, path.into_owned());
547                }
548                None => return Err(Error::InvalidMultiaddr(addr)),
549            },
550            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
551            _ => return Err(Error::InvalidMultiaddr(addr)),
552        }
553    };
554
555    // The original address, stripped of the `/ws` and `/wss` protocols,
556    // makes up the address for the inner TCP-based transport.
557    let tcp_addr = match p2p {
558        Some(p) => protocols.with(p),
559        None => protocols,
560    };
561
562    Ok(WsAddress {
563        host_port,
564        server_name,
565        path,
566        use_tls,
567        tcp_addr,
568    })
569}
570
571fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
572    let mut inner_addr = addr.clone();
573
574    match inner_addr.pop()? {
575        Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
576        Protocol::Ws(path) => match inner_addr.pop()? {
577            Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
578            p => {
579                inner_addr.push(p);
580                Some((inner_addr, WsListenProto::Ws(path)))
581            }
582        },
583        _ => None,
584    }
585}
586
587// Given a location URL, build a new websocket [`Multiaddr`].
588fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
589    match Url::parse(location) {
590        Ok(url) => {
591            let mut a = Multiaddr::empty();
592            match url.host() {
593                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
594                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
595                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
596                None => return Err(Error::InvalidRedirectLocation),
597            }
598            if let Some(p) = url.port() {
599                a.push(Protocol::Tcp(p))
600            }
601            let s = url.scheme();
602            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
603                a.push(Protocol::Tls);
604                a.push(Protocol::Ws(url.path().into()));
605            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
606                a.push(Protocol::Ws(url.path().into()))
607            } else {
608                tracing::debug!(scheme=%s, "unsupported scheme");
609                return Err(Error::InvalidRedirectLocation);
610            }
611            Ok(a)
612        }
613        Err(e) => {
614            tracing::debug!("failed to parse url as multi-address: {:?}", e);
615            Err(Error::InvalidRedirectLocation)
616        }
617    }
618}
619
620/// The websocket connection.
621pub struct Connection<T> {
622    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
623    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
624    _marker: std::marker::PhantomData<T>,
625}
626
627/// Data or control information received over the websocket connection.
628#[derive(Debug, Clone)]
629pub enum Incoming {
630    /// Application data.
631    Data(Data),
632    /// PONG control frame data.
633    Pong(Vec<u8>),
634    /// Close reason.
635    Closed(CloseReason),
636}
637
638/// Application data received over the websocket connection
639#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
640pub enum Data {
641    /// UTF-8 encoded textual data.
642    Text(Vec<u8>),
643    /// Binary data.
644    Binary(Vec<u8>),
645}
646
647impl Data {
648    pub fn into_bytes(self) -> Vec<u8> {
649        match self {
650            Data::Text(d) => d,
651            Data::Binary(d) => d,
652        }
653    }
654}
655
656impl AsRef<[u8]> for Data {
657    fn as_ref(&self) -> &[u8] {
658        match self {
659            Data::Text(d) => d,
660            Data::Binary(d) => d,
661        }
662    }
663}
664
665impl Incoming {
666    pub fn is_data(&self) -> bool {
667        self.is_binary() || self.is_text()
668    }
669
670    pub fn is_binary(&self) -> bool {
671        matches!(self, Incoming::Data(Data::Binary(_)))
672    }
673
674    pub fn is_text(&self) -> bool {
675        matches!(self, Incoming::Data(Data::Text(_)))
676    }
677
678    pub fn is_pong(&self) -> bool {
679        matches!(self, Incoming::Pong(_))
680    }
681
682    pub fn is_close(&self) -> bool {
683        matches!(self, Incoming::Closed(_))
684    }
685}
686
687/// Data sent over the websocket connection.
688#[derive(Debug, Clone)]
689pub enum OutgoingData {
690    /// Send some bytes.
691    Binary(Vec<u8>),
692    /// Send a PING message.
693    Ping(Vec<u8>),
694    /// Send an unsolicited PONG message.
695    /// (Incoming PINGs are answered automatically.)
696    Pong(Vec<u8>),
697}
698
699impl<T> fmt::Debug for Connection<T> {
700    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701        f.write_str("Connection")
702    }
703}
704
705impl<T> Connection<T>
706where
707    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
708{
709    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
710        let (sender, receiver) = builder.finish();
711        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
712            match action {
713                quicksink::Action::Send(OutgoingData::Binary(x)) => {
714                    sender.send_binary_mut(x).await?
715                }
716                quicksink::Action::Send(OutgoingData::Ping(x)) => {
717                    let data = x[..].try_into().map_err(|_| {
718                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
719                    })?;
720                    sender.send_ping(data).await?
721                }
722                quicksink::Action::Send(OutgoingData::Pong(x)) => {
723                    let data = x[..].try_into().map_err(|_| {
724                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
725                    })?;
726                    sender.send_pong(data).await?
727                }
728                quicksink::Action::Flush => sender.flush().await?,
729                quicksink::Action::Close => sender.close().await?,
730            }
731            Ok(sender)
732        });
733        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
734            match receiver.receive(&mut data).await {
735                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
736                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
737                    (data, receiver),
738                )),
739                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
740                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
741                    (data, receiver),
742                )),
743                Ok(soketto::Incoming::Pong(pong)) => {
744                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
745                }
746                Ok(soketto::Incoming::Closed(reason)) => {
747                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
748                }
749                Err(connection::Error::Closed) => None,
750                Err(e) => Some((Err(e), (data, receiver))),
751            }
752        });
753        Connection {
754            receiver: stream.boxed(),
755            sender: Box::pin(sink),
756            _marker: std::marker::PhantomData,
757        }
758    }
759
760    /// Send binary application data to the remote.
761    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
762        self.send(OutgoingData::Binary(data))
763    }
764
765    /// Send a PING to the remote.
766    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
767        self.send(OutgoingData::Ping(data))
768    }
769
770    /// Send an unsolicited PONG to the remote.
771    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
772        self.send(OutgoingData::Pong(data))
773    }
774}
775
776impl<T> Stream for Connection<T>
777where
778    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
779{
780    type Item = io::Result<Incoming>;
781
782    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
783        let item = ready!(self.receiver.poll_next_unpin(cx));
784        let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
785        Poll::Ready(item)
786    }
787}
788
789impl<T> Sink<OutgoingData> for Connection<T>
790where
791    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
792{
793    type Error = io::Error;
794
795    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
796        Pin::new(&mut self.sender)
797            .poll_ready(cx)
798            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
799    }
800
801    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
802        Pin::new(&mut self.sender)
803            .start_send(item)
804            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
805    }
806
807    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
808        Pin::new(&mut self.sender)
809            .poll_flush(cx)
810            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
811    }
812
813    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
814        Pin::new(&mut self.sender)
815            .poll_close(cx)
816            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
817    }
818}
819
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` with `parse_ws_dial_addr` and checks every field of the
    /// resulting address info against the expected values. The expected path
    /// is `/` for every caller.
    ///
    /// `#[track_caller]` makes a failing assertion report the line of the
    /// calling test case rather than a line inside this helper.
    #[track_caller]
    fn assert_dial_info(
        addr: Multiaddr,
        host_port: &str,
        use_tls: bool,
        server_name: &'static str,
        tcp_addr: Multiaddr,
    ) {
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, host_port);
        assert_eq!(info.path, "/");
        assert_eq!(info.use_tls, use_tls);
        assert_eq!(info.server_name, server_name.try_into().unwrap());
        assert_eq!(info.tcp_addr, tcp_addr);
    }

    /// Checks that `parse_ws_listen_addr` splits a `/tls/ws`, `/wss` or `/ws`
    /// listen address into the TCP address plus the matching websocket
    /// protocol. Also checks that `append_on_addr` re-attaches the protocol
    /// to reproduce the input.
    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // (full listen address, expected websocket listen protocol)
        let cases = [
            // `/tls/ws`
            (
                tcp_addr
                    .clone()
                    .with(Protocol::Tls)
                    .with(Protocol::Ws("/".into())),
                WsListenProto::TlsWs("/".into()),
            ),
            // `/wss`
            (
                tcp_addr.clone().with(Protocol::Wss("/".into())),
                WsListenProto::Wss("/".into()),
            ),
            // `/ws`
            (
                tcp_addr.clone().with(Protocol::Ws("/".into())),
                WsListenProto::Ws("/".into()),
            ),
        ];

        for (addr, expected_proto) in cases {
            let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
            assert_eq!(inner_addr, tcp_addr);
            assert_eq!(proto, expected_proto);

            // Round trip: putting the protocol back must give the input again.
            let mut rebuilt = tcp_addr.clone();
            proto.append_on_addr(&mut rebuilt);
            assert_eq!(rebuilt, addr);
        }
    }

    /// Checks that `parse_ws_dial_addr` handles `/tls/ws`, `/wss` and `/ws` in
    /// several forms: with a DNS host, with a trailing `/p2p` component, and
    /// with IPv4 or IPv6 hosts. Also checks that it rejects `/dnsaddr` hosts
    /// and addresses without a websocket suffix.
    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();

        // `/tls/ws`
        assert_dial_info(
            "/dns4/example.com/tcp/2222/tls/ws".parse().unwrap(),
            "example.com:2222",
            true,
            "example.com",
            "/dns4/example.com/tcp/2222".parse().unwrap(),
        );

        // `/tls/ws` with `/p2p`
        assert_dial_info(
            format!("/dns4/example.com/tcp/2222/tls/ws/p2p/{peer_id}")
                .parse()
                .unwrap(),
            "example.com:2222",
            true,
            "example.com",
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap(),
        );

        // `/tls/ws` with `/ip4`
        assert_dial_info(
            "/ip4/127.0.0.1/tcp/2222/tls/ws".parse().unwrap(),
            "127.0.0.1:2222",
            true,
            "127.0.0.1",
            "/ip4/127.0.0.1/tcp/2222".parse().unwrap(),
        );

        // `/tls/ws` with `/ip6`
        assert_dial_info(
            "/ip6/::1/tcp/2222/tls/ws".parse().unwrap(),
            "[::1]:2222",
            true,
            "::1",
            "/ip6/::1/tcp/2222".parse().unwrap(),
        );

        // `/wss`
        assert_dial_info(
            "/dns4/example.com/tcp/2222/wss".parse().unwrap(),
            "example.com:2222",
            true,
            "example.com",
            "/dns4/example.com/tcp/2222".parse().unwrap(),
        );

        // `/wss` with `/p2p`
        assert_dial_info(
            format!("/dns4/example.com/tcp/2222/wss/p2p/{peer_id}")
                .parse()
                .unwrap(),
            "example.com:2222",
            true,
            "example.com",
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap(),
        );

        // `/wss` with `/ip4`
        assert_dial_info(
            "/ip4/127.0.0.1/tcp/2222/wss".parse().unwrap(),
            "127.0.0.1:2222",
            true,
            "127.0.0.1",
            "/ip4/127.0.0.1/tcp/2222".parse().unwrap(),
        );

        // `/wss` with `/ip6`
        assert_dial_info(
            "/ip6/::1/tcp/2222/wss".parse().unwrap(),
            "[::1]:2222",
            true,
            "::1",
            "/ip6/::1/tcp/2222".parse().unwrap(),
        );

        // `/ws`
        assert_dial_info(
            "/dns4/example.com/tcp/2222/ws".parse().unwrap(),
            "example.com:2222",
            false,
            "example.com",
            "/dns4/example.com/tcp/2222".parse().unwrap(),
        );

        // `/ws` with `/p2p`
        assert_dial_info(
            format!("/dns4/example.com/tcp/2222/ws/p2p/{peer_id}")
                .parse()
                .unwrap(),
            "example.com:2222",
            false,
            "example.com",
            format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                .parse()
                .unwrap(),
        );

        // `/ws` with `/ip4`
        assert_dial_info(
            "/ip4/127.0.0.1/tcp/2222/ws".parse().unwrap(),
            "127.0.0.1:2222",
            false,
            "127.0.0.1",
            "/ip4/127.0.0.1/tcp/2222".parse().unwrap(),
        );

        // `/ws` with `/ip6`
        assert_dial_info(
            "/ip6/::1/tcp/2222/ws".parse().unwrap(),
            "[::1]:2222",
            false,
            "::1",
            "/ip6/::1/tcp/2222".parse().unwrap(),
        );

        // Inputs `parse_ws_dial_addr` must reject: a `/dnsaddr` host, and a
        // plain TCP address with no websocket suffix.
        for rejected in ["/dnsaddr/example.com/tcp/2222/ws", "/ip4/127.0.0.1/tcp/2222"] {
            let addr = rejected.parse::<Multiaddr>().unwrap();
            parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
        }
    }
}