1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25pub mod error;
26pub mod framed;
27mod quicksink;
28pub mod tls;
29
30use std::{
31 io,
32 pin::Pin,
33 task::{Context, Poll},
34};
35
36use error::Error;
37use framed::{Connection, Incoming};
38use futures::{future::BoxFuture, prelude::*, ready};
39use libp2p_core::{
40 connection::ConnectedPoint,
41 multiaddr::Multiaddr,
42 transport::{map::MapFuture, DialOpts, ListenerId, TransportError, TransportEvent},
43 Transport,
44};
45use rw_stream_sink::RwStreamSink;
46
47#[deprecated = "Use `Config` instead"]
140pub type WsConfig<Transport> = Config<Transport>;
141
142#[derive(Debug)]
143pub struct Config<T: Transport>
144where
145 T: Transport,
146 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
147{
148 transport: libp2p_core::transport::map::Map<framed::Config<T>, WrapperFn<T::Output>>,
149}
150
151impl<T: Transport> Config<T>
152where
153 T: Transport + Send + Unpin + 'static,
154 T::Error: Send + 'static,
155 T::Dial: Send + 'static,
156 T::ListenerUpgrade: Send + 'static,
157 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
158{
159 pub fn new(transport: T) -> Self {
168 Self {
169 transport: framed::Config::new(transport).map(wrap_connection as WrapperFn<T::Output>),
170 }
171 }
172
173 pub fn max_redirects(&self) -> u8 {
175 self.transport.inner().max_redirects()
176 }
177
178 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
180 self.transport.inner_mut().set_max_redirects(max);
181 self
182 }
183
184 pub fn max_data_size(&self) -> usize {
186 self.transport.inner().max_data_size()
187 }
188
189 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
191 self.transport.inner_mut().set_max_data_size(size);
192 self
193 }
194
195 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
197 self.transport.inner_mut().set_tls_config(c);
198 self
199 }
200}
201
202impl<T> Transport for Config<T>
203where
204 T: Transport + Send + Unpin + 'static,
205 T::Error: Send + 'static,
206 T::Dial: Send + 'static,
207 T::ListenerUpgrade: Send + 'static,
208 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
209{
210 type Output = RwStreamSink<BytesConnection<T::Output>>;
211 type Error = Error<T::Error>;
212 type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
213 type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
214
215 fn listen_on(
216 &mut self,
217 id: ListenerId,
218 addr: Multiaddr,
219 ) -> Result<(), TransportError<Self::Error>> {
220 self.transport.listen_on(id, addr)
221 }
222
223 fn remove_listener(&mut self, id: ListenerId) -> bool {
224 self.transport.remove_listener(id)
225 }
226
227 fn dial(
228 &mut self,
229 addr: Multiaddr,
230 opts: DialOpts,
231 ) -> Result<Self::Dial, TransportError<Self::Error>> {
232 self.transport.dial(addr, opts)
233 }
234
235 fn poll(
236 mut self: Pin<&mut Self>,
237 cx: &mut Context<'_>,
238 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
239 Pin::new(&mut self.transport).poll(cx)
240 }
241}
242
243pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
245
246pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
248
249fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
252where
253 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
254{
255 RwStreamSink::new(BytesConnection(c))
256}
257
258#[derive(Debug)]
260pub struct BytesConnection<T>(Connection<T>);
261
262impl<T> Stream for BytesConnection<T>
263where
264 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
265{
266 type Item = io::Result<Vec<u8>>;
267
268 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
269 loop {
270 if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
271 if let Incoming::Data(payload) = item {
272 return Poll::Ready(Some(Ok(payload.into_bytes())));
273 }
274 } else {
275 return Poll::Ready(None);
276 }
277 }
278 }
279}
280
281impl<T> Sink<Vec<u8>> for BytesConnection<T>
282where
283 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
284{
285 type Error = io::Error;
286
287 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
288 Pin::new(&mut self.0).poll_ready(cx)
289 }
290
291 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
292 Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
293 }
294
295 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
296 Pin::new(&mut self.0).poll_flush(cx)
297 }
298
299 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
300 Pin::new(&mut self.0).poll_close(cx)
301 }
302}
303
#[cfg(test)]
mod tests {
    use futures::prelude::*;
    use libp2p_core::{
        multiaddr::Protocol,
        transport::{DialOpts, ListenerId, PortUse},
        Endpoint, Multiaddr, Transport,
    };
    use libp2p_identity::PeerId;
    use libp2p_tcp as tcp;

    use super::Config;

    /// Builds a `Config` over an async-io TCP transport created from the default TCP config.
    fn new_ws_config() -> Config<tcp::async_io::Transport> {
        let tcp_transport = tcp::async_io::Transport::new(tcp::Config::default());
        Config::new(tcp_transport)
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        let listen_addr: Multiaddr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        let listen_addr: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    /// Listens on `listen_addr`, dials the address the listener reports, and requires both
    /// the inbound upgrade and the outbound dial to succeed.
    async fn connect(listen_addr: Multiaddr) {
        let mut listener = new_ws_config().boxed();
        listener
            .listen_on(ListenerId::next(), listen_addr)
            .expect("listener");

        // The first event must announce the address the listener is actually bound to.
        let first_event = listener.next().await.expect("no error");
        let addr = first_event.into_new_address().expect("listen address");

        // The third component must be `/ws` and the second must not be TCP port 0.
        assert_eq!(Some(Protocol::Ws("/".into())), addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1));

        let inbound = async move {
            let event = listener.select_next_some().await;
            let (upgrade, _addr) = event.into_incoming().unwrap();
            upgrade.await
        };

        let dial_opts = DialOpts {
            role: Endpoint::Dialer,
            port_use: PortUse::New,
        };
        let dial_addr = addr.with(Protocol::P2p(PeerId::random()));
        let outbound = new_ws_config()
            .boxed()
            .dial(dial_addr, dial_opts)
            .unwrap();

        // Drive both sides concurrently; either failing fails the test.
        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}