1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25pub mod error;
26pub mod framed;
27mod quicksink;
28pub mod tls;
29
30use std::{
31 io,
32 pin::Pin,
33 task::{Context, Poll},
34};
35
36use error::Error;
37use framed::{Connection, Incoming};
38use futures::{future::BoxFuture, prelude::*, ready};
39use libp2p_core::{
40 connection::ConnectedPoint,
41 multiaddr::Multiaddr,
42 transport::{map::MapFuture, DialOpts, ListenerId, TransportError, TransportEvent},
43 Transport,
44};
45use rw_stream_sink::RwStreamSink;
46
47#[deprecated = "Use `Config` instead"]
138pub type WsConfig<Transport> = Config<Transport>;
139
140#[derive(Debug)]
141pub struct Config<T: Transport>
142where
143 T: Transport,
144 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
145{
146 transport: libp2p_core::transport::map::Map<framed::Config<T>, WrapperFn<T::Output>>,
147}
148
149impl<T: Transport> Config<T>
150where
151 T: Transport + Send + Unpin + 'static,
152 T::Error: Send + 'static,
153 T::Dial: Send + 'static,
154 T::ListenerUpgrade: Send + 'static,
155 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
156{
157 pub fn new(transport: T) -> Self {
166 Self {
167 transport: framed::Config::new(transport).map(wrap_connection as WrapperFn<T::Output>),
168 }
169 }
170
171 pub fn max_redirects(&self) -> u8 {
173 self.transport.inner().max_redirects()
174 }
175
176 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
178 self.transport.inner_mut().set_max_redirects(max);
179 self
180 }
181
182 pub fn max_data_size(&self) -> usize {
184 self.transport.inner().max_data_size()
185 }
186
187 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
189 self.transport.inner_mut().set_max_data_size(size);
190 self
191 }
192
193 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
195 self.transport.inner_mut().set_tls_config(c);
196 self
197 }
198}
199
200impl<T> Transport for Config<T>
201where
202 T: Transport + Send + Unpin + 'static,
203 T::Error: Send + 'static,
204 T::Dial: Send + 'static,
205 T::ListenerUpgrade: Send + 'static,
206 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
207{
208 type Output = RwStreamSink<BytesConnection<T::Output>>;
209 type Error = Error<T::Error>;
210 type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
211 type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
212
213 fn listen_on(
214 &mut self,
215 id: ListenerId,
216 addr: Multiaddr,
217 ) -> Result<(), TransportError<Self::Error>> {
218 self.transport.listen_on(id, addr)
219 }
220
221 fn remove_listener(&mut self, id: ListenerId) -> bool {
222 self.transport.remove_listener(id)
223 }
224
225 fn dial(
226 &mut self,
227 addr: Multiaddr,
228 opts: DialOpts,
229 ) -> Result<Self::Dial, TransportError<Self::Error>> {
230 self.transport.dial(addr, opts)
231 }
232
233 fn poll(
234 mut self: Pin<&mut Self>,
235 cx: &mut Context<'_>,
236 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
237 Pin::new(&mut self.transport).poll(cx)
238 }
239}
240
241pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
243
244pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
246
247fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
250where
251 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
252{
253 RwStreamSink::new(BytesConnection(c))
254}
255
256#[derive(Debug)]
258pub struct BytesConnection<T>(Connection<T>);
259
260impl<T> Stream for BytesConnection<T>
261where
262 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
263{
264 type Item = io::Result<Vec<u8>>;
265
266 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
267 loop {
268 if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
269 if let Incoming::Data(payload) = item {
270 return Poll::Ready(Some(Ok(payload.into_bytes())));
271 }
272 } else {
273 return Poll::Ready(None);
274 }
275 }
276 }
277}
278
279impl<T> Sink<Vec<u8>> for BytesConnection<T>
280where
281 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
282{
283 type Error = io::Error;
284
285 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
286 Pin::new(&mut self.0).poll_ready(cx)
287 }
288
289 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
290 Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
291 }
292
293 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
294 Pin::new(&mut self.0).poll_flush(cx)
295 }
296
297 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
298 Pin::new(&mut self.0).poll_close(cx)
299 }
300}
301
#[cfg(test)]
mod tests {
    use futures::prelude::*;
    use libp2p_core::{
        multiaddr::Protocol,
        transport::{DialOpts, ListenerId, PortUse},
        Endpoint, Multiaddr, Transport,
    };
    use libp2p_identity::PeerId;
    use libp2p_tcp as tcp;

    use super::Config;

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        let listen_addr: Multiaddr = "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr));
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        let listen_addr: Multiaddr = "/ip6/::1/tcp/0/ws".parse().unwrap();
        futures::executor::block_on(connect(listen_addr));
    }

    /// Builds a WebSocket transport over the async-io TCP transport with default settings.
    fn new_ws_config() -> Config<tcp::async_io::Transport> {
        let tcp_transport = tcp::async_io::Transport::new(tcp::Config::default());
        Config::new(tcp_transport)
    }

    /// Listens on `listen_addr`, dials the reported listen address from a second
    /// transport, and asserts that both sides of the connection complete successfully.
    async fn connect(listen_addr: Multiaddr) {
        let mut listener = new_ws_config().boxed();
        listener
            .listen_on(ListenerId::next(), listen_addr)
            .expect("listener");

        let addr = listener
            .next()
            .await
            .expect("no error")
            .into_new_address()
            .expect("listen address");

        // The `/ws` component must carry path "/" and the port must have been assigned.
        assert_eq!(Some(Protocol::Ws("/".into())), addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1));

        let inbound = async move {
            let (upgrade, _send_back_addr) = listener
                .select_next_some()
                .map(|event| event.into_incoming())
                .await
                .unwrap();
            upgrade.await
        };

        let dial_addr = addr.with(Protocol::P2p(PeerId::random()));
        let dial_opts = DialOpts {
            role: Endpoint::Dialer,
            port_use: PortUse::New,
        };
        let outbound = new_ws_config().boxed().dial(dial_addr, dial_opts).unwrap();

        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}