libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    borrow::Cow,
23    collections::HashMap,
24    fmt, io, mem,
25    net::IpAddr,
26    ops::DerefMut,
27    pin::Pin,
28    sync::Arc,
29    task::{Context, Poll},
30};
31
32use either::Either;
33use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
34use futures_rustls::{client, rustls::pki_types::ServerName, server};
35use libp2p_core::{
36    multiaddr::{Multiaddr, Protocol},
37    transport::{DialOpts, ListenerId, TransportError, TransportEvent},
38    Transport,
39};
40use parking_lot::Mutex;
41use soketto::{
42    connection::{self, CloseReason},
43    handshake,
44};
45use url::Url;
46
47use crate::{error::Error, quicksink, tls};
48
/// Default maximum payload size, in bytes, of a single websocket frame (256 MiB).
const MAX_DATA_SIZE: usize = 256 << 20;
51
52/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
53/// frame payloads which does not implement [`AsyncRead`] or
54/// [`AsyncWrite`]. See [`crate::Config`] if you require the latter.
55#[deprecated = "Use `Config` instead"]
56pub type WsConfig<T> = Config<T>;
57
58#[derive(Debug)]
59pub struct Config<T> {
60    transport: Arc<Mutex<T>>,
61    max_data_size: usize,
62    tls_config: tls::Config,
63    max_redirects: u8,
64    /// Websocket protocol of the inner listener.
65    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
66}
67
68impl<T> Config<T>
69where
70    T: Send,
71{
72    /// Create a new websocket transport based on another transport.
73    pub fn new(transport: T) -> Self {
74        Config {
75            transport: Arc::new(Mutex::new(transport)),
76            max_data_size: MAX_DATA_SIZE,
77            tls_config: tls::Config::client(),
78            max_redirects: 0,
79            listener_protos: HashMap::new(),
80        }
81    }
82
83    /// Return the configured maximum number of redirects.
84    pub fn max_redirects(&self) -> u8 {
85        self.max_redirects
86    }
87
88    /// Set max. number of redirects to follow.
89    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
90        self.max_redirects = max;
91        self
92    }
93
94    /// Get the max. frame data size we support.
95    pub fn max_data_size(&self) -> usize {
96        self.max_data_size
97    }
98
99    /// Set the max. frame data size we support.
100    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
101        self.max_data_size = size;
102        self
103    }
104
105    /// Set the TLS configuration if TLS support is desired.
106    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
107        self.tls_config = c;
108        self
109    }
110}
111
112type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
113
114impl<T> Transport for Config<T>
115where
116    T: Transport + Send + Unpin + 'static,
117    T::Error: Send + 'static,
118    T::Dial: Send + 'static,
119    T::ListenerUpgrade: Send + 'static,
120    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
121{
122    type Output = Connection<T::Output>;
123    type Error = Error<T::Error>;
124    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
125    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
126
127    fn listen_on(
128        &mut self,
129        id: ListenerId,
130        addr: Multiaddr,
131    ) -> Result<(), TransportError<Self::Error>> {
132        let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
133            tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
134            TransportError::MultiaddrNotSupported(addr.clone())
135        })?;
136
137        if proto.use_tls() && self.tls_config.server.is_none() {
138            tracing::debug!(
139                "{} address but TLS server support is not configured",
140                proto.prefix()
141            );
142            return Err(TransportError::MultiaddrNotSupported(addr));
143        }
144
145        match self.transport.lock().listen_on(id, inner_addr) {
146            Ok(()) => {
147                self.listener_protos.insert(id, proto);
148                Ok(())
149            }
150            Err(e) => Err(e.map(Error::Transport)),
151        }
152    }
153
154    fn remove_listener(&mut self, id: ListenerId) -> bool {
155        self.transport.lock().remove_listener(id)
156    }
157
158    fn dial(
159        &mut self,
160        addr: Multiaddr,
161        dial_opts: DialOpts,
162    ) -> Result<Self::Dial, TransportError<Self::Error>> {
163        self.do_dial(addr, dial_opts)
164    }
165
166    fn poll(
167        mut self: Pin<&mut Self>,
168        cx: &mut Context<'_>,
169    ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
170        let inner_event = {
171            let mut transport = self.transport.lock();
172            match Transport::poll(Pin::new(transport.deref_mut()), cx) {
173                Poll::Ready(ev) => ev,
174                Poll::Pending => return Poll::Pending,
175            }
176        };
177        let event = match inner_event {
178            TransportEvent::NewAddress {
179                listener_id,
180                mut listen_addr,
181            } => {
182                // Append the ws / wss protocol back to the inner address.
183                self.listener_protos
184                    .get(&listener_id)
185                    .expect("Protocol was inserted in Transport::listen_on.")
186                    .append_on_addr(&mut listen_addr);
187                tracing::debug!(address=%listen_addr, "Listening on address");
188                TransportEvent::NewAddress {
189                    listener_id,
190                    listen_addr,
191                }
192            }
193            TransportEvent::AddressExpired {
194                listener_id,
195                mut listen_addr,
196            } => {
197                self.listener_protos
198                    .get(&listener_id)
199                    .expect("Protocol was inserted in Transport::listen_on.")
200                    .append_on_addr(&mut listen_addr);
201                TransportEvent::AddressExpired {
202                    listener_id,
203                    listen_addr,
204                }
205            }
206            TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
207                listener_id,
208                error: Error::Transport(error),
209            },
210            TransportEvent::ListenerClosed {
211                listener_id,
212                reason,
213            } => {
214                self.listener_protos
215                    .remove(&listener_id)
216                    .expect("Protocol was inserted in Transport::listen_on.");
217                TransportEvent::ListenerClosed {
218                    listener_id,
219                    reason: reason.map_err(Error::Transport),
220                }
221            }
222            TransportEvent::Incoming {
223                listener_id,
224                upgrade,
225                mut local_addr,
226                mut send_back_addr,
227            } => {
228                let proto = self
229                    .listener_protos
230                    .get(&listener_id)
231                    .expect("Protocol was inserted in Transport::listen_on.");
232                let use_tls = proto.use_tls();
233                proto.append_on_addr(&mut local_addr);
234                proto.append_on_addr(&mut send_back_addr);
235                let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
236                TransportEvent::Incoming {
237                    listener_id,
238                    upgrade,
239                    local_addr,
240                    send_back_addr,
241                }
242            }
243        };
244        Poll::Ready(event)
245    }
246}
247
248impl<T> Config<T>
249where
250    T: Transport + Send + Unpin + 'static,
251    T::Error: Send + 'static,
252    T::Dial: Send + 'static,
253    T::ListenerUpgrade: Send + 'static,
254    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
255{
256    fn do_dial(
257        &mut self,
258        addr: Multiaddr,
259        dial_opts: DialOpts,
260    ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
261        let mut addr = match parse_ws_dial_addr(addr) {
262            Ok(addr) => addr,
263            Err(Error::InvalidMultiaddr(a)) => {
264                return Err(TransportError::MultiaddrNotSupported(a))
265            }
266            Err(e) => return Err(TransportError::Other(e)),
267        };
268
269        // We are looping here in order to follow redirects (if any):
270        let mut remaining_redirects = self.max_redirects;
271
272        let transport = self.transport.clone();
273        let tls_config = self.tls_config.clone();
274        let max_redirects = self.max_redirects;
275
276        let future = async move {
277            loop {
278                match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
279                {
280                    Ok(Either::Left(redirect)) => {
281                        if remaining_redirects == 0 {
282                            tracing::debug!(%max_redirects, "Too many redirects");
283                            return Err(Error::TooManyRedirects);
284                        }
285                        remaining_redirects -= 1;
286                        addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
287                    }
288                    Ok(Either::Right(conn)) => return Ok(conn),
289                    Err(e) => return Err(e),
290                }
291            }
292        };
293
294        Ok(Box::pin(future))
295    }
296
297    /// Attempts to dial the given address and perform a websocket handshake.
298    async fn dial_once(
299        transport: Arc<Mutex<T>>,
300        addr: WsAddress,
301        tls_config: tls::Config,
302        dial_opts: DialOpts,
303    ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
304        tracing::trace!(address=?addr, "Dialing websocket address");
305
306        let dial = transport
307            .lock()
308            .dial(addr.tcp_addr, dial_opts)
309            .map_err(|e| match e {
310                TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
311                TransportError::Other(e) => Error::Transport(e),
312            })?;
313
314        let stream = dial.map_err(Error::Transport).await?;
315        tracing::trace!(port=%addr.host_port, "TCP connection established");
316
317        let stream = if addr.use_tls {
318            // begin TLS session
319            tracing::trace!(?addr.server_name, "Starting TLS handshake");
320            let stream = tls_config
321                .client
322                .connect(addr.server_name.clone(), stream)
323                .map_err(|e| {
324                    tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
325                    Error::Tls(tls::Error::from(e))
326                })
327                .await?;
328
329            let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
330            stream
331        } else {
332            // continue with plain stream
333            future::Either::Right(stream)
334        };
335
336        tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
337
338        let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
339
340        match client
341            .handshake()
342            .map_err(|e| Error::Handshake(Box::new(e)))
343            .await?
344        {
345            handshake::ServerResponse::Redirect {
346                status_code,
347                location,
348            } => {
349                tracing::debug!(
350                    %status_code,
351                    %location,
352                    "received redirect"
353                );
354                Ok(Either::Left(location))
355            }
356            handshake::ServerResponse::Rejected { status_code } => {
357                let msg = format!("server rejected handshake; status code = {status_code}");
358                Err(Error::Handshake(msg.into()))
359            }
360            handshake::ServerResponse::Accepted { .. } => {
361                tracing::trace!(port=%addr.host_port, "websocket handshake successful");
362                Ok(Either::Right(Connection::new(client.into_builder())))
363            }
364        }
365    }
366
367    fn map_upgrade(
368        &self,
369        upgrade: T::ListenerUpgrade,
370        remote_addr: Multiaddr,
371        use_tls: bool,
372    ) -> <Self as Transport>::ListenerUpgrade {
373        let remote_addr2 = remote_addr.clone(); // used for logging
374        let tls_config = self.tls_config.clone();
375        let max_size = self.max_data_size;
376
377        async move {
378            let stream = upgrade.map_err(Error::Transport).await?;
379            tracing::trace!(address=%remote_addr, "incoming connection from address");
380
381            let stream = if use_tls {
382                // begin TLS session
383                let server = tls_config
384                    .server
385                    .expect("for use_tls we checked server is not none");
386
387                tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
388
389                let stream = server
390                    .accept(stream)
391                    .map_err(move |e| {
392                        tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
393                        Error::Tls(tls::Error::from(e))
394                    })
395                    .await?;
396
397                let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
398
399                stream
400            } else {
401                // continue with plain stream
402                future::Either::Right(stream)
403            };
404
405            tracing::trace!(
406                address=%remote_addr2,
407                "receiving websocket handshake request from address"
408            );
409
410            let mut server = handshake::Server::new(stream);
411
412            let ws_key = {
413                let request = server
414                    .receive_request()
415                    .map_err(|e| Error::Handshake(Box::new(e)))
416                    .await?;
417                request.key()
418            };
419
420            tracing::trace!(
421                address=%remote_addr2,
422                "accepting websocket handshake request from address"
423            );
424
425            let response = handshake::server::Response::Accept {
426                key: ws_key,
427                protocol: None,
428            };
429
430            server
431                .send_response(&response)
432                .map_err(|e| Error::Handshake(Box::new(e)))
433                .await?;
434
435            let conn = {
436                let mut builder = server.into_builder();
437                builder.set_max_message_size(max_size);
438                builder.set_max_frame_size(max_size);
439                Connection::new(builder)
440            };
441
442            Ok(conn)
443        }
444        .boxed()
445    }
446}
447
/// The websocket flavour a listener was registered with, carrying the path
/// of the `/ws` or `/wss` protocol.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// Plain `/ws`.
    Ws(Cow<'a, str>),
    /// `/wss`.
    Wss(Cow<'a, str>),
    /// `/tls/ws`.
    TlsWs(Cow<'a, str>),
}
454
455impl WsListenProto<'_> {
456    pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
457        match self {
458            WsListenProto::Ws(path) => {
459                addr.push(Protocol::Ws(path.clone()));
460            }
461            // `/tls/ws` and `/wss` are equivalend, however we regenerate
462            // the one that user passed at `listen_on` for backward compatibility.
463            WsListenProto::Wss(path) => {
464                addr.push(Protocol::Wss(path.clone()));
465            }
466            WsListenProto::TlsWs(path) => {
467                addr.push(Protocol::Tls);
468                addr.push(Protocol::Ws(path.clone()));
469            }
470        }
471    }
472
473    pub(crate) fn use_tls(&self) -> bool {
474        match self {
475            WsListenProto::Ws(_) => false,
476            WsListenProto::Wss(_) => true,
477            WsListenProto::TlsWs(_) => true,
478        }
479    }
480
481    pub(crate) fn prefix(&self) -> &'static str {
482        match self {
483            WsListenProto::Ws(_) => "/ws",
484            WsListenProto::Wss(_) => "/wss",
485            WsListenProto::TlsWs(_) => "/tls/ws",
486        }
487    }
488}
489
490#[derive(Debug)]
491struct WsAddress {
492    host_port: String,
493    path: String,
494    server_name: ServerName<'static>,
495    use_tls: bool,
496    tcp_addr: Multiaddr,
497}
498
499/// Tries to parse the given `Multiaddr` into a `WsAddress` used
500/// for dialing.
501///
502/// Fails if the given `Multiaddr` does not represent a TCP/IP-based
503/// websocket protocol stack.
504fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
505    // The encapsulating protocol must be based on TCP/IP, possibly via DNS.
506    // We peek at it in order to learn the hostname and port to use for
507    // the websocket handshake.
508    let mut protocols = addr.iter();
509    let mut ip = protocols.next();
510    let mut tcp = protocols.next();
511    let (host_port, mut server_name) = loop {
512        match (ip, tcp) {
513            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
514                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
515                break (format!("{ip}:{port}"), server_name);
516            }
517            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
518                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
519                break (format!("[{ip}]:{port}"), server_name);
520            }
521            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
522            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
523            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
524                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
525            }
526            (Some(_), Some(p)) => {
527                ip = Some(p);
528                tcp = protocols.next();
529            }
530            _ => return Err(Error::InvalidMultiaddr(addr)),
531        }
532    };
533
534    // Will hold a value if the multiaddr carries `/tls/sni/<host>`.
535    let mut sni_override: Option<ServerName<'static>> = None;
536    // Now consume the `Ws` / `Wss` protocol from the end of the address,
537    // preserving the trailing `P2p` protocol that identifies the remote,
538    // if any.
539    let mut protocols = addr.clone();
540    let mut p2p = None;
541    let (use_tls, path) = loop {
542        match protocols.pop() {
543            p @ Some(Protocol::P2p(_)) => p2p = p,
544            Some(Protocol::Ws(path)) => match protocols.pop() {
545                Some(Protocol::Sni(domain)) => match protocols.pop() {
546                    Some(Protocol::Tls) => {
547                        sni_override = Some(tls::dns_name_ref(&domain)?);
548                        break (true, path.into_owned());
549                    }
550                    _ => return Err(Error::InvalidMultiaddr(addr)),
551                },
552                Some(Protocol::Tls) => break (true, path.into_owned()),
553                Some(p) => {
554                    protocols.push(p);
555                    break (false, path.into_owned());
556                }
557                None => return Err(Error::InvalidMultiaddr(addr)),
558            },
559            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
560            _ => return Err(Error::InvalidMultiaddr(addr)),
561        }
562    };
563
564    // The original address, stripped of the `/ws` and `/wss` protocols,
565    // makes up the address for the inner TCP-based transport.
566    let tcp_addr = match p2p {
567        Some(p) => protocols.with(p),
568        None => protocols,
569    };
570
571    if let Some(name) = sni_override {
572        server_name = name;
573    }
574    Ok(WsAddress {
575        host_port,
576        server_name,
577        path,
578        use_tls,
579        tcp_addr,
580    })
581}
582
583fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
584    let mut inner_addr = addr.clone();
585
586    match inner_addr.pop()? {
587        Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
588        Protocol::Ws(path) => match inner_addr.pop()? {
589            Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
590            p => {
591                inner_addr.push(p);
592                Some((inner_addr, WsListenProto::Ws(path)))
593            }
594        },
595        _ => None,
596    }
597}
598
599// Given a location URL, build a new websocket [`Multiaddr`].
600fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
601    match Url::parse(location) {
602        Ok(url) => {
603            let mut a = Multiaddr::empty();
604            match url.host() {
605                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
606                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
607                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
608                None => return Err(Error::InvalidRedirectLocation),
609            }
610            if let Some(p) = url.port() {
611                a.push(Protocol::Tcp(p))
612            }
613            let s = url.scheme();
614            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
615                a.push(Protocol::Tls);
616                a.push(Protocol::Ws(url.path().into()));
617            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
618                a.push(Protocol::Ws(url.path().into()))
619            } else {
620                tracing::debug!(scheme=%s, "unsupported scheme");
621                return Err(Error::InvalidRedirectLocation);
622            }
623            Ok(a)
624        }
625        Err(e) => {
626            tracing::debug!("failed to parse url as multi-address: {:?}", e);
627            Err(Error::InvalidRedirectLocation)
628        }
629    }
630}
631
632/// The websocket connection.
633pub struct Connection<T> {
634    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
635    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
636    _marker: std::marker::PhantomData<T>,
637}
638
639/// Data or control information received over the websocket connection.
640#[derive(Debug, Clone)]
641pub enum Incoming {
642    /// Application data.
643    Data(Data),
644    /// PONG control frame data.
645    Pong(Vec<u8>),
646    /// Close reason.
647    Closed(CloseReason),
648}
649
/// Application data received over the websocket connection.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// UTF-8 encoded textual data.
    Text(Vec<u8>),
    /// Binary data.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the message and returns its payload bytes, whether it was
    /// sent as text or as binary.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload bytes regardless of the message kind.
    fn as_ref(&self) -> &[u8] {
        match self {
            Data::Text(bytes) | Data::Binary(bytes) => bytes.as_slice(),
        }
    }
}
676
677impl Incoming {
678    pub fn is_data(&self) -> bool {
679        self.is_binary() || self.is_text()
680    }
681
682    pub fn is_binary(&self) -> bool {
683        matches!(self, Incoming::Data(Data::Binary(_)))
684    }
685
686    pub fn is_text(&self) -> bool {
687        matches!(self, Incoming::Data(Data::Text(_)))
688    }
689
690    pub fn is_pong(&self) -> bool {
691        matches!(self, Incoming::Pong(_))
692    }
693
694    pub fn is_close(&self) -> bool {
695        matches!(self, Incoming::Closed(_))
696    }
697}
698
/// Data sent over the websocket connection.
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// A binary data message.
    Binary(Vec<u8>),
    /// A PING control frame with the given payload.
    Ping(Vec<u8>),
    /// An unsolicited PONG control frame with the given payload.
    /// (Incoming PINGs are answered automatically.)
    Pong(Vec<u8>),
}
710
711impl<T> fmt::Debug for Connection<T> {
712    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
713        f.write_str("Connection")
714    }
715}
716
717impl<T> Connection<T>
718where
719    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
720{
721    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
722        let (sender, receiver) = builder.finish();
723        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
724            match action {
725                quicksink::Action::Send(OutgoingData::Binary(x)) => {
726                    sender.send_binary_mut(x).await?
727                }
728                quicksink::Action::Send(OutgoingData::Ping(x)) => {
729                    let data = x[..].try_into().map_err(|_| {
730                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
731                    })?;
732                    sender.send_ping(data).await?
733                }
734                quicksink::Action::Send(OutgoingData::Pong(x)) => {
735                    let data = x[..].try_into().map_err(|_| {
736                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
737                    })?;
738                    sender.send_pong(data).await?
739                }
740                quicksink::Action::Flush => sender.flush().await?,
741                quicksink::Action::Close => sender.close().await?,
742            }
743            Ok(sender)
744        });
745        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
746            match receiver.receive(&mut data).await {
747                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
748                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
749                    (data, receiver),
750                )),
751                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
752                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
753                    (data, receiver),
754                )),
755                Ok(soketto::Incoming::Pong(pong)) => {
756                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
757                }
758                Ok(soketto::Incoming::Closed(reason)) => {
759                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
760                }
761                Err(connection::Error::Closed) => None,
762                Err(e) => Some((Err(e), (data, receiver))),
763            }
764        });
765        Connection {
766            receiver: stream.boxed(),
767            sender: Box::pin(sink),
768            _marker: std::marker::PhantomData,
769        }
770    }
771
772    /// Send binary application data to the remote.
773    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
774        self.send(OutgoingData::Binary(data))
775    }
776
777    /// Send a PING to the remote.
778    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
779        self.send(OutgoingData::Ping(data))
780    }
781
782    /// Send an unsolicited PONG to the remote.
783    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
784        self.send(OutgoingData::Pong(data))
785    }
786}
787
788impl<T> Stream for Connection<T>
789where
790    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
791{
792    type Item = io::Result<Incoming>;
793
794    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
795        let item = ready!(self.receiver.poll_next_unpin(cx));
796        let item = item.map(|result| result.map_err(io::Error::other));
797        Poll::Ready(item)
798    }
799}
800
801impl<T> Sink<OutgoingData> for Connection<T>
802where
803    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
804{
805    type Error = io::Error;
806
807    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
808        Pin::new(&mut self.sender)
809            .poll_ready(cx)
810            .map_err(io::Error::other)
811    }
812
813    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
814        Pin::new(&mut self.sender)
815            .start_send(item)
816            .map_err(io::Error::other)
817    }
818
819    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
820        Pin::new(&mut self.sender)
821            .poll_flush(cx)
822            .map_err(io::Error::other)
823    }
824
825    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
826        Pin::new(&mut self.sender)
827            .poll_close(cx)
828            .map_err(io::Error::other)
829    }
830}
831
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a websocket dial address and asserts that every field of the
    /// result matches the expectation.
    ///
    /// Every address accepted in these tests uses the default `/` websocket path, so the
    /// path is always expected to be `/`.
    fn assert_dial_ok(
        addr: &str,
        host_port: &str,
        use_tls: bool,
        server_name: &'static str,
        tcp_addr: &str,
    ) {
        let addr = addr.parse::<Multiaddr>().unwrap();
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, host_port);
        assert_eq!(info.path, "/");
        assert_eq!(info.use_tls, use_tls);
        assert_eq!(info.server_name, server_name.try_into().unwrap());
        assert_eq!(info.tcp_addr, tcp_addr.parse().unwrap());
    }

    /// Asserts that `addr` is not accepted as a websocket dial address.
    fn assert_dial_rejected(addr: &str) {
        let addr = addr.parse::<Multiaddr>().unwrap();
        parse_ws_dial_addr::<io::Error>(addr).unwrap_err();
    }

    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        // Each listen address paired with the websocket protocol it should parse into,
        // in the order `/tls/ws`, `/wss`, `/ws`.
        let cases = [
            (
                tcp_addr
                    .clone()
                    .with(Protocol::Tls)
                    .with(Protocol::Ws("/".into())),
                WsListenProto::TlsWs("/".into()),
            ),
            (
                tcp_addr.clone().with(Protocol::Wss("/".into())),
                WsListenProto::Wss("/".into()),
            ),
            (
                tcp_addr.clone().with(Protocol::Ws("/".into())),
                WsListenProto::Ws("/".into()),
            ),
        ];

        for (addr, expected_proto) in cases {
            let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
            assert_eq!(inner_addr, tcp_addr);
            assert_eq!(proto, expected_proto);

            // Appending the parsed protocol back onto the TCP address must round-trip.
            let mut rebuilt = tcp_addr.clone();
            proto.append_on_addr(&mut rebuilt);
            assert_eq!(rebuilt, addr);
        }
    }

    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();

        // Every websocket flavour, combined with every supported host kind.
        for (ws, use_tls) in [("tls/ws", true), ("wss", true), ("ws", false)] {
            // DNS host.
            assert_dial_ok(
                &format!("/dns4/example.com/tcp/2222/{ws}"),
                "example.com:2222",
                use_tls,
                "example.com",
                "/dns4/example.com/tcp/2222",
            );
            // DNS host with a trailing `/p2p`, which ends up on the TCP address.
            assert_dial_ok(
                &format!("/dns4/example.com/tcp/2222/{ws}/p2p/{peer_id}"),
                "example.com:2222",
                use_tls,
                "example.com",
                &format!("/dns4/example.com/tcp/2222/p2p/{peer_id}"),
            );
            // IPv4 host.
            assert_dial_ok(
                &format!("/ip4/127.0.0.1/tcp/2222/{ws}"),
                "127.0.0.1:2222",
                use_tls,
                "127.0.0.1",
                "/ip4/127.0.0.1/tcp/2222",
            );
            // IPv6 host; the `host:port` form brackets the address.
            assert_dial_ok(
                &format!("/ip6/::1/tcp/2222/{ws}"),
                "[::1]:2222",
                use_tls,
                "::1",
                "/ip6/::1/tcp/2222",
            );
        }

        // `/dnsaddr` hosts are rejected.
        assert_dial_rejected("/dnsaddr/example.com/tcp/2222/ws");

        // Addresses without any websocket protocol are rejected.
        assert_dial_rejected("/ip4/127.0.0.1/tcp/2222");

        // `/tls/sni/.../ws` with `/dns4`.
        assert_dial_ok(
            "/dns4/example.com/tcp/2222/tls/sni/example.com/ws",
            "example.com:2222",
            true,
            "example.com",
            "/dns4/example.com/tcp/2222",
        );

        // `/tls/sni/.../ws` with `/ip4`: the SNI value is used as the server name.
        assert_dial_ok(
            "/ip4/127.0.0.1/tcp/2222/tls/sni/example.test/ws",
            "127.0.0.1:2222",
            true,
            "example.test",
            "/ip4/127.0.0.1/tcp/2222",
        );

        // `/tls/sni/.../ws` with a trailing `/p2p`.
        assert_dial_ok(
            &format!("/dns4/example.com/tcp/2222/tls/sni/example.com/ws/p2p/{peer_id}"),
            "example.com:2222",
            true,
            "example.com",
            &format!("/dns4/example.com/tcp/2222/p2p/{peer_id}"),
        );

        // `/tls/sni/...` *without* `/ws` is rejected.
        assert_dial_rejected("/dns4/example.com/tcp/2222/tls/sni/example.com");
    }
}