1use std::{convert::Infallible, marker::PhantomData, sync::Arc};
2
3use libp2p_core::{
4    upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade},
5    Transport,
6};
7#[cfg(feature = "relay")]
8use libp2p_core::{Negotiated, UpgradeInfo};
9#[cfg(feature = "relay")]
10use libp2p_identity::PeerId;
11
12use super::*;
13use crate::SwarmBuilder;
14
/// Builder phase in which additional, caller-supplied transports may be added.
///
/// The phase carries the transport assembled so far; the methods implemented
/// on `SwarmBuilder<_, OtherTransportPhase<_>>` either extend it or hand it on
/// to a later phase.
pub struct OtherTransportPhase<T> {
    /// Transport accumulated up to this point of the build.
    pub(crate) transport: T,
}
18
19impl<Provider, T: AuthenticatedMultiplexedTransport>
20    SwarmBuilder<Provider, OtherTransportPhase<T>>
21{
22    pub fn with_other_transport<
23        Muxer: libp2p_core::muxing::StreamMuxer + Send + 'static,
24        OtherTransport: Transport<Output = (libp2p_identity::PeerId, Muxer)> + Send + Unpin + 'static,
25        R: TryIntoTransport<OtherTransport>,
26    >(
27        self,
28        constructor: impl FnOnce(&libp2p_identity::Keypair) -> R,
29    ) -> Result<
30        SwarmBuilder<Provider, OtherTransportPhase<impl AuthenticatedMultiplexedTransport>>,
31        R::Error,
32    >
33    where
34        <OtherTransport as Transport>::Error: Send + Sync + 'static,
35        <OtherTransport as Transport>::Dial: Send,
36        <OtherTransport as Transport>::ListenerUpgrade: Send,
37        <Muxer as libp2p_core::muxing::StreamMuxer>::Substream: Send,
38        <Muxer as libp2p_core::muxing::StreamMuxer>::Error: Send + Sync,
39    {
40        Ok(SwarmBuilder {
41            phase: OtherTransportPhase {
42                transport: self
43                    .phase
44                    .transport
45                    .or_transport(
46                        constructor(&self.keypair)
47                            .try_into_transport()?
48                            .map(|(peer_id, conn), _| (peer_id, StreamMuxerBox::new(conn))),
49                    )
50                    .map(|either, _| either.into_inner()),
51            },
52            keypair: self.keypair,
53            phantom: PhantomData,
54        })
55    }
56
57    pub(crate) fn without_any_other_transports(self) -> SwarmBuilder<Provider, DnsPhase<T>> {
58        SwarmBuilder {
59            keypair: self.keypair,
60            phantom: PhantomData,
61            phase: DnsPhase {
62                transport: self.phase.transport,
63            },
64        }
65    }
66}
67
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio", feature = "dns"))]
impl<T> SwarmBuilder<super::provider::Tokio, OtherTransportPhase<T>>
where
    T: AuthenticatedMultiplexedTransport,
{
    /// Shortcut: finishes this phase without adding a transport and forwards
    /// to the DNS phase's `with_dns`.
    ///
    /// # Errors
    ///
    /// Propagates the `std::io::Error` returned by the delegated `with_dns`.
    pub fn with_dns(
        self,
    ) -> Result<
        SwarmBuilder<
            super::provider::Tokio,
            WebsocketPhase<impl AuthenticatedMultiplexedTransport>,
        >,
        std::io::Error,
    > {
        let dns_phase = self.without_any_other_transports();
        dns_phase.with_dns()
    }
}
#[cfg(all(not(target_arch = "wasm32"), feature = "tokio", feature = "dns"))]
impl<T> SwarmBuilder<super::provider::Tokio, OtherTransportPhase<T>>
where
    T: AuthenticatedMultiplexedTransport,
{
    /// Shortcut: finishes this phase without adding a transport and forwards
    /// `cfg` and `opts` to the DNS phase's `with_dns_config`.
    pub fn with_dns_config(
        self,
        cfg: libp2p_dns::ResolverConfig,
        opts: libp2p_dns::ResolverOpts,
    ) -> SwarmBuilder<super::provider::Tokio, WebsocketPhase<impl AuthenticatedMultiplexedTransport>>
    {
        let dns_phase = self.without_any_other_transports();
        dns_phase.with_dns_config(cfg, opts)
    }
}
#[cfg(feature = "relay")]
impl<T, Provider> SwarmBuilder<Provider, OtherTransportPhase<T>>
where
    T: AuthenticatedMultiplexedTransport,
{
    /// Shortcut: finishes this phase without adding a transport, passes
    /// through `without_dns` and `without_websocket`, and then forwards both
    /// upgrades to the relay phase's `with_relay_client`.
    ///
    /// # Errors
    ///
    /// Propagates the `SecUpgrade::Error` returned by the delegated
    /// `with_relay_client`.
    pub fn with_relay_client<SecUpgrade, SecStream, SecError, MuxUpgrade, MuxStream, MuxError>(
        self,
        security_upgrade: SecUpgrade,
        multiplexer_upgrade: MuxUpgrade,
    ) -> Result<
        SwarmBuilder<
            Provider,
            BandwidthMetricsPhase<
                impl AuthenticatedMultiplexedTransport,
                libp2p_relay::client::Behaviour,
            >,
        >,
        SecUpgrade::Error,
    >
    where
        // Bounds on the security upgrade applied to relayed connections.
        SecStream: futures::AsyncRead + futures::AsyncWrite + Unpin + Send + 'static,
        SecError: std::error::Error + Send + Sync + 'static,
        SecUpgrade: IntoSecurityUpgrade<libp2p_relay::client::Connection>,
        SecUpgrade::Upgrade: InboundConnectionUpgrade<
                Negotiated<libp2p_relay::client::Connection>,
                Output = (PeerId, SecStream),
                Error = SecError,
            > + OutboundConnectionUpgrade<
                Negotiated<libp2p_relay::client::Connection>,
                Output = (PeerId, SecStream),
                Error = SecError,
            > + Clone
            + Send
            + 'static,
        <SecUpgrade::Upgrade as InboundConnectionUpgrade<
            Negotiated<libp2p_relay::client::Connection>,
        >>::Future: Send,
        <SecUpgrade::Upgrade as OutboundConnectionUpgrade<
            Negotiated<libp2p_relay::client::Connection>,
        >>::Future: Send,
        <<<SecUpgrade as IntoSecurityUpgrade<libp2p_relay::client::Connection>>::Upgrade as UpgradeInfo>::InfoIter as IntoIterator>::IntoIter:
            Send,
        <<SecUpgrade as IntoSecurityUpgrade<libp2p_relay::client::Connection>>::Upgrade as UpgradeInfo>::Info:
            Send,

        // Bounds on the multiplexer upgrade applied on top of the secured stream.
        MuxStream: libp2p_core::muxing::StreamMuxer + Send + 'static,
        MuxStream::Substream: Send + 'static,
        MuxStream::Error: Send + Sync + 'static,
        MuxUpgrade: IntoMultiplexerUpgrade<SecStream>,
        MuxUpgrade::Upgrade: InboundConnectionUpgrade<
                Negotiated<SecStream>,
                Output = MuxStream,
                Error = MuxError,
            > + OutboundConnectionUpgrade<
                Negotiated<SecStream>,
                Output = MuxStream,
                Error = MuxError,
            > + Clone
            + Send
            + 'static,
        <MuxUpgrade::Upgrade as InboundConnectionUpgrade<Negotiated<SecStream>>>::Future: Send,
        <MuxUpgrade::Upgrade as OutboundConnectionUpgrade<Negotiated<SecStream>>>::Future: Send,
        MuxError: std::error::Error + Send + Sync + 'static,
        <<<MuxUpgrade as IntoMultiplexerUpgrade<SecStream>>::Upgrade as UpgradeInfo>::InfoIter as IntoIterator>::IntoIter:
            Send,
        <<MuxUpgrade as IntoMultiplexerUpgrade<SecStream>>::Upgrade as UpgradeInfo>::Info: Send,
    {
        let relay_phase = self
            .without_any_other_transports()
            .without_dns()
            .without_websocket();

        relay_phase.with_relay_client(security_upgrade, multiplexer_upgrade)
    }
}
#[cfg(feature = "metrics")]
impl<Provider, T> SwarmBuilder<Provider, OtherTransportPhase<T>>
where
    T: AuthenticatedMultiplexedTransport,
{
    /// Shortcut: finishes this phase without adding a transport, passes
    /// through `without_dns`, `without_websocket` and `without_relay`, and
    /// then forwards `registry` to the later phase's `with_bandwidth_metrics`.
    pub fn with_bandwidth_metrics(
        self,
        registry: &mut libp2p_metrics::Registry,
    ) -> SwarmBuilder<
        Provider,
        BehaviourPhase<impl AuthenticatedMultiplexedTransport, NoRelayBehaviour>,
    > {
        let metrics_phase = self
            .without_any_other_transports()
            .without_dns()
            .without_websocket()
            .without_relay();

        metrics_phase.with_bandwidth_metrics(registry)
    }
}
160impl<Provider, T: AuthenticatedMultiplexedTransport>
161    SwarmBuilder<Provider, OtherTransportPhase<T>>
162{
163    pub fn with_behaviour<B, R: TryIntoBehaviour<B>>(
164        self,
165        constructor: impl FnOnce(&libp2p_identity::Keypair) -> R,
166    ) -> Result<SwarmBuilder<Provider, SwarmPhase<T, B>>, R::Error> {
167        self.without_any_other_transports()
168            .without_dns()
169            .without_websocket()
170            .without_relay()
171            .with_behaviour(constructor)
172    }
173}
174
/// Conversion of a transport constructor's return value into a transport `T`.
///
/// This is the bound placed on the return type of the closure passed to
/// `with_other_transport`. The supertrait `private::Sealed` is declared in a
/// non-public module, so the trait can only be implemented where that module
/// is nameable.
pub trait TryIntoTransport<T>: private::Sealed<Self::Error> {
    /// Error produced when the conversion fails.
    type Error;

    /// Consumes `self` and returns the transport, or the conversion error.
    fn try_into_transport(self) -> Result<T, Self::Error>;
}
180
181impl<T: Transport> TryIntoTransport<T> for T {
182    type Error = Infallible;
183
184    fn try_into_transport(self) -> Result<T, Self::Error> {
185        Ok(self)
186    }
187}
188
189impl<T: Transport> TryIntoTransport<T> for Result<T, Box<dyn std::error::Error + Send + Sync>> {
190    type Error = TransportError;
191
192    fn try_into_transport(self) -> Result<T, Self::Error> {
193        self.map_err(TransportError)
194    }
195}
196
// Not `pub`: the trait inside is only nameable from this module and its
// descendants.
mod private {
    /// Marker supertrait of `TryIntoTransport`, parameterised by the
    /// conversion's error type. It declares no methods.
    pub trait Sealed<Error> {}
}
200
201impl<T: Transport> private::Sealed<Infallible> for T {}
202
203impl<T: Transport> private::Sealed<TransportError>
204    for Result<T, Box<dyn std::error::Error + Send + Sync>>
205{
206}
207
/// Error wrapping the boxed error of a transport constructor that returned `Err`.
///
/// Its `Display` output is the wrapped error's message prefixed with
/// `failed to build transport: `. The inner field is private, so values are
/// only constructed inside this module.
#[derive(Debug, thiserror::Error)]
#[error("failed to build transport: {0}")]
pub struct TransportError(Box<dyn std::error::Error + Send + Sync + 'static>);