1use std::{
22 borrow::Cow,
23 collections::HashMap,
24 fmt, io, mem,
25 net::IpAddr,
26 ops::DerefMut,
27 pin::Pin,
28 sync::Arc,
29 task::{Context, Poll},
30};
31
32use either::Either;
33use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
34use futures_rustls::{client, rustls::pki_types::ServerName, server};
35use libp2p_core::{
36 multiaddr::{Multiaddr, Protocol},
37 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
38 Transport,
39};
40use parking_lot::Mutex;
41use soketto::{
42 connection::{self, CloseReason},
43 handshake,
44};
45use url::Url;
46
47use crate::{error::Error, quicksink, tls};
48
/// Default maximum size, in bytes, of a websocket message/frame (256 MiB).
const MAX_DATA_SIZE: usize = 256 << 20;
51
52#[deprecated = "Use `Config` instead"]
56pub type WsConfig<T> = Config<T>;
57
/// Websocket transport layered on top of an inner [`Transport`] `T`.
///
/// Connections produced by this transport are [`Connection`]s, which expose
/// websocket messages through `Stream`/`Sink` implementations.
#[derive(Debug)]
pub struct Config<T> {
    /// The wrapped transport. It sits behind `Arc<Mutex<_>>` because dial futures
    /// hold their own handle to it.
    transport: Arc<Mutex<T>>,
    /// Maximum websocket message/frame size, in bytes, configured on the soketto
    /// connection builder.
    max_data_size: usize,
    /// TLS configuration used for `/tls/ws` and `/wss` addresses.
    tls_config: tls::Config,
    /// Maximum number of redirects followed while dialing.
    max_redirects: u8,
    /// Websocket suffix of each active listener. It is recorded in `listen_on` so
    /// that it can be re-appended to addresses reported by the inner transport,
    /// and removed when the listener closes.
    listener_protos: HashMap<ListenerId, WsListenProto<'static>>,
}
67
68impl<T> Config<T>
69where
70 T: Send,
71{
72 pub fn new(transport: T) -> Self {
74 Config {
75 transport: Arc::new(Mutex::new(transport)),
76 max_data_size: MAX_DATA_SIZE,
77 tls_config: tls::Config::client(),
78 max_redirects: 0,
79 listener_protos: HashMap::new(),
80 }
81 }
82
83 pub fn max_redirects(&self) -> u8 {
85 self.max_redirects
86 }
87
88 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
90 self.max_redirects = max;
91 self
92 }
93
94 pub fn max_data_size(&self) -> usize {
96 self.max_data_size
97 }
98
99 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
101 self.max_data_size = size;
102 self
103 }
104
105 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
107 self.tls_config = c;
108 self
109 }
110}
111
112type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
113
114impl<T> Transport for Config<T>
115where
116 T: Transport + Send + Unpin + 'static,
117 T::Error: Send + 'static,
118 T::Dial: Send + 'static,
119 T::ListenerUpgrade: Send + 'static,
120 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
121{
122 type Output = Connection<T::Output>;
123 type Error = Error<T::Error>;
124 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
125 type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
126
127 fn listen_on(
128 &mut self,
129 id: ListenerId,
130 addr: Multiaddr,
131 ) -> Result<(), TransportError<Self::Error>> {
132 let (inner_addr, proto) = parse_ws_listen_addr(&addr).ok_or_else(|| {
133 tracing::debug!(address=%addr, "Address is not a websocket multiaddr");
134 TransportError::MultiaddrNotSupported(addr.clone())
135 })?;
136
137 if proto.use_tls() && self.tls_config.server.is_none() {
138 tracing::debug!(
139 "{} address but TLS server support is not configured",
140 proto.prefix()
141 );
142 return Err(TransportError::MultiaddrNotSupported(addr));
143 }
144
145 match self.transport.lock().listen_on(id, inner_addr) {
146 Ok(()) => {
147 self.listener_protos.insert(id, proto);
148 Ok(())
149 }
150 Err(e) => Err(e.map(Error::Transport)),
151 }
152 }
153
154 fn remove_listener(&mut self, id: ListenerId) -> bool {
155 self.transport.lock().remove_listener(id)
156 }
157
158 fn dial(
159 &mut self,
160 addr: Multiaddr,
161 dial_opts: DialOpts,
162 ) -> Result<Self::Dial, TransportError<Self::Error>> {
163 self.do_dial(addr, dial_opts)
164 }
165
166 fn poll(
167 mut self: Pin<&mut Self>,
168 cx: &mut Context<'_>,
169 ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
170 let inner_event = {
171 let mut transport = self.transport.lock();
172 match Transport::poll(Pin::new(transport.deref_mut()), cx) {
173 Poll::Ready(ev) => ev,
174 Poll::Pending => return Poll::Pending,
175 }
176 };
177 let event = match inner_event {
178 TransportEvent::NewAddress {
179 listener_id,
180 mut listen_addr,
181 } => {
182 self.listener_protos
184 .get(&listener_id)
185 .expect("Protocol was inserted in Transport::listen_on.")
186 .append_on_addr(&mut listen_addr);
187 tracing::debug!(address=%listen_addr, "Listening on address");
188 TransportEvent::NewAddress {
189 listener_id,
190 listen_addr,
191 }
192 }
193 TransportEvent::AddressExpired {
194 listener_id,
195 mut listen_addr,
196 } => {
197 self.listener_protos
198 .get(&listener_id)
199 .expect("Protocol was inserted in Transport::listen_on.")
200 .append_on_addr(&mut listen_addr);
201 TransportEvent::AddressExpired {
202 listener_id,
203 listen_addr,
204 }
205 }
206 TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
207 listener_id,
208 error: Error::Transport(error),
209 },
210 TransportEvent::ListenerClosed {
211 listener_id,
212 reason,
213 } => {
214 self.listener_protos
215 .remove(&listener_id)
216 .expect("Protocol was inserted in Transport::listen_on.");
217 TransportEvent::ListenerClosed {
218 listener_id,
219 reason: reason.map_err(Error::Transport),
220 }
221 }
222 TransportEvent::Incoming {
223 listener_id,
224 upgrade,
225 mut local_addr,
226 mut send_back_addr,
227 } => {
228 let proto = self
229 .listener_protos
230 .get(&listener_id)
231 .expect("Protocol was inserted in Transport::listen_on.");
232 let use_tls = proto.use_tls();
233 proto.append_on_addr(&mut local_addr);
234 proto.append_on_addr(&mut send_back_addr);
235 let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
236 TransportEvent::Incoming {
237 listener_id,
238 upgrade,
239 local_addr,
240 send_back_addr,
241 }
242 }
243 };
244 Poll::Ready(event)
245 }
246}
247
248impl<T> Config<T>
249where
250 T: Transport + Send + Unpin + 'static,
251 T::Error: Send + 'static,
252 T::Dial: Send + 'static,
253 T::ListenerUpgrade: Send + 'static,
254 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
255{
256 fn do_dial(
257 &mut self,
258 addr: Multiaddr,
259 dial_opts: DialOpts,
260 ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
261 let mut addr = match parse_ws_dial_addr(addr) {
262 Ok(addr) => addr,
263 Err(Error::InvalidMultiaddr(a)) => {
264 return Err(TransportError::MultiaddrNotSupported(a))
265 }
266 Err(e) => return Err(TransportError::Other(e)),
267 };
268
269 let mut remaining_redirects = self.max_redirects;
271
272 let transport = self.transport.clone();
273 let tls_config = self.tls_config.clone();
274 let max_redirects = self.max_redirects;
275
276 let future = async move {
277 loop {
278 match Self::dial_once(transport.clone(), addr, tls_config.clone(), dial_opts).await
279 {
280 Ok(Either::Left(redirect)) => {
281 if remaining_redirects == 0 {
282 tracing::debug!(%max_redirects, "Too many redirects");
283 return Err(Error::TooManyRedirects);
284 }
285 remaining_redirects -= 1;
286 addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
287 }
288 Ok(Either::Right(conn)) => return Ok(conn),
289 Err(e) => return Err(e),
290 }
291 }
292 };
293
294 Ok(Box::pin(future))
295 }
296
297 async fn dial_once(
299 transport: Arc<Mutex<T>>,
300 addr: WsAddress,
301 tls_config: tls::Config,
302 dial_opts: DialOpts,
303 ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
304 tracing::trace!(address=?addr, "Dialing websocket address");
305
306 let dial = transport
307 .lock()
308 .dial(addr.tcp_addr, dial_opts)
309 .map_err(|e| match e {
310 TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
311 TransportError::Other(e) => Error::Transport(e),
312 })?;
313
314 let stream = dial.map_err(Error::Transport).await?;
315 tracing::trace!(port=%addr.host_port, "TCP connection established");
316
317 let stream = if addr.use_tls {
318 tracing::trace!(?addr.server_name, "Starting TLS handshake");
320 let stream = tls_config
321 .client
322 .connect(addr.server_name.clone(), stream)
323 .map_err(|e| {
324 tracing::debug!(?addr.server_name, "TLS handshake failed: {}", e);
325 Error::Tls(tls::Error::from(e))
326 })
327 .await?;
328
329 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
330 stream
331 } else {
332 future::Either::Right(stream)
334 };
335
336 tracing::trace!(port=%addr.host_port, "Sending websocket handshake");
337
338 let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
339
340 match client
341 .handshake()
342 .map_err(|e| Error::Handshake(Box::new(e)))
343 .await?
344 {
345 handshake::ServerResponse::Redirect {
346 status_code,
347 location,
348 } => {
349 tracing::debug!(
350 %status_code,
351 %location,
352 "received redirect"
353 );
354 Ok(Either::Left(location))
355 }
356 handshake::ServerResponse::Rejected { status_code } => {
357 let msg = format!("server rejected handshake; status code = {status_code}");
358 Err(Error::Handshake(msg.into()))
359 }
360 handshake::ServerResponse::Accepted { .. } => {
361 tracing::trace!(port=%addr.host_port, "websocket handshake successful");
362 Ok(Either::Right(Connection::new(client.into_builder())))
363 }
364 }
365 }
366
367 fn map_upgrade(
368 &self,
369 upgrade: T::ListenerUpgrade,
370 remote_addr: Multiaddr,
371 use_tls: bool,
372 ) -> <Self as Transport>::ListenerUpgrade {
373 let remote_addr2 = remote_addr.clone(); let tls_config = self.tls_config.clone();
375 let max_size = self.max_data_size;
376
377 async move {
378 let stream = upgrade.map_err(Error::Transport).await?;
379 tracing::trace!(address=%remote_addr, "incoming connection from address");
380
381 let stream = if use_tls {
382 let server = tls_config
384 .server
385 .expect("for use_tls we checked server is not none");
386
387 tracing::trace!(address=%remote_addr, "awaiting TLS handshake with address");
388
389 let stream = server
390 .accept(stream)
391 .map_err(move |e| {
392 tracing::debug!(address=%remote_addr, "TLS handshake with address failed: {}", e);
393 Error::Tls(tls::Error::from(e))
394 })
395 .await?;
396
397 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
398
399 stream
400 } else {
401 future::Either::Right(stream)
403 };
404
405 tracing::trace!(
406 address=%remote_addr2,
407 "receiving websocket handshake request from address"
408 );
409
410 let mut server = handshake::Server::new(stream);
411
412 let ws_key = {
413 let request = server
414 .receive_request()
415 .map_err(|e| Error::Handshake(Box::new(e)))
416 .await?;
417 request.key()
418 };
419
420 tracing::trace!(
421 address=%remote_addr2,
422 "accepting websocket handshake request from address"
423 );
424
425 let response = handshake::server::Response::Accept {
426 key: ws_key,
427 protocol: None,
428 };
429
430 server
431 .send_response(&response)
432 .map_err(|e| Error::Handshake(Box::new(e)))
433 .await?;
434
435 let conn = {
436 let mut builder = server.into_builder();
437 builder.set_max_message_size(max_size);
438 builder.set_max_frame_size(max_size);
439 Connection::new(builder)
440 };
441
442 Ok(conn)
443 }
444 .boxed()
445 }
446}
447
/// Websocket protocol suffix of a listen address, together with its HTTP path.
#[derive(Debug, PartialEq)]
pub(crate) enum WsListenProto<'a> {
    /// Plain `/ws`.
    Ws(Cow<'a, str>),
    /// `/wss` (websocket over TLS, single-protocol form).
    Wss(Cow<'a, str>),
    /// `/tls/ws` (websocket over TLS, two-protocol form).
    TlsWs(Cow<'a, str>),
}
454
455impl WsListenProto<'_> {
456 pub(crate) fn append_on_addr(&self, addr: &mut Multiaddr) {
457 match self {
458 WsListenProto::Ws(path) => {
459 addr.push(Protocol::Ws(path.clone()));
460 }
461 WsListenProto::Wss(path) => {
464 addr.push(Protocol::Wss(path.clone()));
465 }
466 WsListenProto::TlsWs(path) => {
467 addr.push(Protocol::Tls);
468 addr.push(Protocol::Ws(path.clone()));
469 }
470 }
471 }
472
473 pub(crate) fn use_tls(&self) -> bool {
474 match self {
475 WsListenProto::Ws(_) => false,
476 WsListenProto::Wss(_) => true,
477 WsListenProto::TlsWs(_) => true,
478 }
479 }
480
481 pub(crate) fn prefix(&self) -> &'static str {
482 match self {
483 WsListenProto::Ws(_) => "/ws",
484 WsListenProto::Wss(_) => "/wss",
485 WsListenProto::TlsWs(_) => "/tls/ws",
486 }
487 }
488}
489
/// Components of a websocket dial address, as produced by `parse_ws_dial_addr`.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` (IPv6 hosts in brackets), sent in the websocket handshake.
    host_port: String,
    /// HTTP path taken from the `/ws` or `/wss` component.
    path: String,
    /// Server name handed to the TLS client when `use_tls` is set.
    server_name: ServerName<'static>,
    /// Whether the address ended in `/tls/ws` or `/wss`.
    use_tls: bool,
    /// Address for the inner transport: the original address with the websocket
    /// component(s) removed and any trailing `/p2p` kept.
    tcp_addr: Multiaddr,
}
498
/// Splits a websocket dial address into the pieces needed to connect.
///
/// Host and port come from the first adjacent `<host>/tcp/<port>` pair, where
/// `<host>` is `/ip4`, `/ip6`, `/dns`, `/dns4` or `/dns6`. That pair provides
/// `host_port` and `server_name`.
///
/// Scheme and path come from the end of the address, which must be `/ws`,
/// `/tls/ws` or `/wss`, optionally followed by `/p2p/...`. That provides `path`
/// and `use_tls`. Removing the websocket part gives `tcp_addr`.
///
/// # Errors
///
/// Returns [`Error::InvalidMultiaddr`] in these cases:
/// - no supported host/TCP pair exists;
/// - the address does not end in a websocket protocol (ignoring `/p2p`);
/// - the address consists only of `/ws`.
///
/// Errors from `tls::dns_name_ref` are propagated.
fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
    // Slide a two-protocol window over the address until it sees a host followed by
    // `/tcp`. This is where the host and port for the handshake come from.
    let mut protocols = addr.iter();
    let mut ip = protocols.next();
    let mut tcp = protocols.next();
    let (host_port, server_name) = loop {
        match (ip, tcp) {
            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V4(ip).into());
                break (format!("{ip}:{port}"), server_name);
            }
            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
                let server_name = ServerName::IpAddress(IpAddr::V6(ip).into());
                break (format!("[{ip}]:{port}"), server_name);
            }
            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port))) => {
                break (format!("{h}:{port}"), tls::dns_name_ref(&h)?)
            }
            // Not a host/TCP pair: shift the window by one protocol.
            (Some(_), Some(p)) => {
                ip = Some(p);
                tcp = protocols.next();
            }
            // Ran out of protocols without finding a usable pair.
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // Strip the websocket protocol from the end of a copy of the address.
    // A trailing `/p2p` identifying the remote is set aside and re-appended below.
    // NOTE(review): if several `/p2p` components trail the websocket protocol, only
    // the last one popped (nearest the websocket protocol) is kept.
    let mut protocols = addr.clone();
    let mut p2p = None;
    let (use_tls, path) = loop {
        match protocols.pop() {
            p @ Some(Protocol::P2p(_)) => p2p = p,
            // `/ws` means TLS only when directly preceded by `/tls`; otherwise the
            // popped protocol belongs to the inner address and is pushed back.
            Some(Protocol::Ws(path)) => match protocols.pop() {
                Some(Protocol::Tls) => break (true, path.into_owned()),
                Some(p) => {
                    protocols.push(p);
                    break (false, path.into_owned());
                }
                None => return Err(Error::InvalidMultiaddr(addr)),
            },
            Some(Protocol::Wss(path)) => break (true, path.into_owned()),
            _ => return Err(Error::InvalidMultiaddr(addr)),
        }
    };

    // The remaining protocols (plus any `/p2p`) form the inner transport's address.
    let tcp_addr = match p2p {
        Some(p) => protocols.with(p),
        None => protocols,
    };

    Ok(WsAddress {
        host_port,
        server_name,
        path,
        use_tls,
        tcp_addr,
    })
}
570
571fn parse_ws_listen_addr(addr: &Multiaddr) -> Option<(Multiaddr, WsListenProto<'static>)> {
572 let mut inner_addr = addr.clone();
573
574 match inner_addr.pop()? {
575 Protocol::Wss(path) => Some((inner_addr, WsListenProto::Wss(path))),
576 Protocol::Ws(path) => match inner_addr.pop()? {
577 Protocol::Tls => Some((inner_addr, WsListenProto::TlsWs(path))),
578 p => {
579 inner_addr.push(p);
580 Some((inner_addr, WsListenProto::Ws(path)))
581 }
582 },
583 _ => None,
584 }
585}
586
587fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
589 match Url::parse(location) {
590 Ok(url) => {
591 let mut a = Multiaddr::empty();
592 match url.host() {
593 Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
594 Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
595 Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
596 None => return Err(Error::InvalidRedirectLocation),
597 }
598 if let Some(p) = url.port() {
599 a.push(Protocol::Tcp(p))
600 }
601 let s = url.scheme();
602 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
603 a.push(Protocol::Tls);
604 a.push(Protocol::Ws(url.path().into()));
605 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
606 a.push(Protocol::Ws(url.path().into()))
607 } else {
608 tracing::debug!(scheme=%s, "unsupported scheme");
609 return Err(Error::InvalidRedirectLocation);
610 }
611 Ok(a)
612 }
613 Err(e) => {
614 tracing::debug!("failed to parse url as multi-address: {:?}", e);
615 Err(Error::InvalidRedirectLocation)
616 }
617 }
618}
619
/// A websocket connection.
///
/// It is a `Stream` of [`Incoming`] items and a `Sink` of [`OutgoingData`] items.
pub struct Connection<T> {
    /// Incoming websocket items read from the soketto receiver.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Outgoing side; each `OutgoingData` item is forwarded to the soketto sender.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    /// Records the underlying stream type `T` without storing a value of it.
    _marker: std::marker::PhantomData<T>,
}
626
/// An item received on a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data (text or binary message).
    Data(Data),
    /// Payload of a PONG frame received from the remote.
    Pong(Vec<u8>),
    /// The remote sent a close frame; carries soketto's close reason.
    Closed(CloseReason),
}
637
/// Application data carried by a websocket message.
///
/// Variant order matters: the derived `Ord` sorts `Text` before `Binary`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Payload of a text message.
    Text(Vec<u8>),
    /// Payload of a binary message.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes `self` and returns the payload bytes, whichever the variant.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload bytes, whichever the variant.
    fn as_ref(&self) -> &[u8] {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}
664
665impl Incoming {
666 pub fn is_data(&self) -> bool {
667 self.is_binary() || self.is_text()
668 }
669
670 pub fn is_binary(&self) -> bool {
671 matches!(self, Incoming::Data(Data::Binary(_)))
672 }
673
674 pub fn is_text(&self) -> bool {
675 matches!(self, Incoming::Data(Data::Text(_)))
676 }
677
678 pub fn is_pong(&self) -> bool {
679 matches!(self, Incoming::Pong(_))
680 }
681
682 pub fn is_close(&self) -> bool {
683 matches!(self, Incoming::Closed(_))
684 }
685}
686
/// An item to send on a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Send a binary message.
    Binary(Vec<u8>),
    /// Send a PING frame. The payload size is validated when the item is sent.
    Ping(Vec<u8>),
    /// Send a PONG frame. The payload size is validated when the item is sent.
    Pong(Vec<u8>),
}
698
699impl<T> fmt::Debug for Connection<T> {
700 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
701 f.write_str("Connection")
702 }
703}
704
impl<T> Connection<T>
where
    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
    /// Builds a `Connection` from a soketto connection builder.
    ///
    /// The soketto sender is wrapped in a `quicksink` sink, which maps each
    /// [`OutgoingData`] item and each flush/close request onto the matching sender
    /// call. The receiver is turned into a stream of [`Incoming`] items.
    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
        let (sender, receiver) = builder.finish();
        // Each sink action runs this closure, which returns the sender for the next action.
        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
            match action {
                quicksink::Action::Send(OutgoingData::Binary(x)) => {
                    sender.send_binary_mut(x).await?
                }
                quicksink::Action::Send(OutgoingData::Ping(x)) => {
                    // The conversion fails for payloads soketto will not accept for a
                    // control frame; that is reported as `InvalidInput`.
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
                    })?;
                    sender.send_ping(data).await?
                }
                quicksink::Action::Send(OutgoingData::Pong(x)) => {
                    let data = x[..].try_into().map_err(|_| {
                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
                    })?;
                    sender.send_pong(data).await?
                }
                quicksink::Action::Flush => sender.flush().await?,
                quicksink::Action::Close => sender.close().await?,
            }
            Ok(sender)
        });
        // `data` is the receive buffer threaded through the unfold state. `mem::take`
        // moves the received payload out and leaves an empty `Vec` for the next call.
        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
            match receiver.receive(&mut data).await {
                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
                    (data, receiver),
                )),
                Ok(soketto::Incoming::Pong(pong)) => {
                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
                }
                Ok(soketto::Incoming::Closed(reason)) => {
                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
                }
                // `Closed` ends the stream; any other error is yielded and the stream
                // keeps polling the receiver afterwards.
                Err(connection::Error::Closed) => None,
                Err(e) => Some((Err(e), (data, receiver))),
            }
        });
        Connection {
            receiver: stream.boxed(),
            sender: Box::pin(sink),
            _marker: std::marker::PhantomData,
        }
    }

    /// Sends `data` as a binary message; returns the `SinkExt::send` future.
    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Binary(data))
    }

    /// Sends a PING frame with payload `data`; returns the `SinkExt::send` future.
    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Ping(data))
    }

    /// Sends a PONG frame with payload `data`; returns the `SinkExt::send` future.
    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
        self.send(OutgoingData::Pong(data))
    }
}
775
776impl<T> Stream for Connection<T>
777where
778 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
779{
780 type Item = io::Result<Incoming>;
781
782 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
783 let item = ready!(self.receiver.poll_next_unpin(cx));
784 let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
785 Poll::Ready(item)
786 }
787}
788
789impl<T> Sink<OutgoingData> for Connection<T>
790where
791 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
792{
793 type Error = io::Error;
794
795 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
796 Pin::new(&mut self.sender)
797 .poll_ready(cx)
798 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
799 }
800
801 fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
802 Pin::new(&mut self.sender)
803 .start_send(item)
804 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
805 }
806
807 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
808 Pin::new(&mut self.sender)
809 .poll_flush(cx)
810 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
811 }
812
813 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
814 Pin::new(&mut self.sender)
815 .poll_close(cx)
816 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
817 }
818}
819
#[cfg(test)]
mod tests {
    use std::io;

    use libp2p_identity::PeerId;

    use super::*;

    /// Parses `addr` as a listen address and checks the split.
    ///
    /// It also checks that re-appending the parsed suffix to `tcp_addr` reproduces
    /// `addr`.
    fn check_listen(addr: Multiaddr, tcp_addr: &Multiaddr, expected: WsListenProto<'static>) {
        let (inner_addr, proto) = parse_ws_listen_addr(&addr).unwrap();
        assert_eq!(&inner_addr, tcp_addr);
        assert_eq!(proto, expected);

        let mut rebuilt = tcp_addr.clone();
        proto.append_on_addr(&mut rebuilt);
        assert_eq!(rebuilt, addr);
    }

    /// Parses `addr` as a dial address and checks every derived field.
    ///
    /// The path is always expected to be `/`.
    fn check_dial(
        addr: Multiaddr,
        host_port: &str,
        use_tls: bool,
        server_name: &'static str,
        tcp_addr: Multiaddr,
    ) {
        let info = parse_ws_dial_addr::<io::Error>(addr).unwrap();
        assert_eq!(info.host_port, host_port);
        assert_eq!(info.path, "/");
        assert_eq!(info.use_tls, use_tls);
        assert_eq!(info.server_name, server_name.try_into().unwrap());
        assert_eq!(info.tcp_addr, tcp_addr);
    }

    #[test]
    fn listen_addr() {
        let tcp_addr = "/ip4/0.0.0.0/tcp/2222".parse::<Multiaddr>().unwrap();

        check_listen(
            tcp_addr
                .clone()
                .with(Protocol::Tls)
                .with(Protocol::Ws("/".into())),
            &tcp_addr,
            WsListenProto::TlsWs("/".into()),
        );
        check_listen(
            tcp_addr.clone().with(Protocol::Wss("/".into())),
            &tcp_addr,
            WsListenProto::Wss("/".into()),
        );
        check_listen(
            tcp_addr.clone().with(Protocol::Ws("/".into())),
            &tcp_addr,
            WsListenProto::Ws("/".into()),
        );
    }

    #[test]
    fn dial_addr() {
        let peer_id = PeerId::random();

        // `/tls/ws` and `/wss` both enable TLS; plain `/ws` does not.
        for (suffix, use_tls) in [("tls/ws", true), ("wss", true), ("ws", false)] {
            check_dial(
                format!("/dns4/example.com/tcp/2222/{suffix}").parse().unwrap(),
                "example.com:2222",
                use_tls,
                "example.com",
                "/dns4/example.com/tcp/2222".parse().unwrap(),
            );
            check_dial(
                format!("/dns4/example.com/tcp/2222/{suffix}/p2p/{peer_id}")
                    .parse()
                    .unwrap(),
                "example.com:2222",
                use_tls,
                "example.com",
                format!("/dns4/example.com/tcp/2222/p2p/{peer_id}")
                    .parse()
                    .unwrap(),
            );
            check_dial(
                format!("/ip4/127.0.0.1/tcp/2222/{suffix}").parse().unwrap(),
                "127.0.0.1:2222",
                use_tls,
                "127.0.0.1",
                "/ip4/127.0.0.1/tcp/2222".parse().unwrap(),
            );
            check_dial(
                format!("/ip6/::1/tcp/2222/{suffix}").parse().unwrap(),
                "[::1]:2222",
                use_tls,
                "::1",
                "/ip6/::1/tcp/2222".parse().unwrap(),
            );
        }

        // `/dnsaddr` is not a supported host protocol.
        let unsupported_host = "/dnsaddr/example.com/tcp/2222/ws"
            .parse::<Multiaddr>()
            .unwrap();
        parse_ws_dial_addr::<io::Error>(unsupported_host).unwrap_err();

        // Missing websocket protocol.
        let no_ws = "/ip4/127.0.0.1/tcp/2222".parse::<Multiaddr>().unwrap();
        parse_ws_dial_addr::<io::Error>(no_ws).unwrap_err();
    }
}