libp2p_mdns/behaviour/
socket.rs

// Copyright 2018 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
20
21use std::{
22    io::Error,
23    net::{SocketAddr, UdpSocket},
24    task::{Context, Poll},
25};
26
/// Interface that must be implemented by the different runtimes to use the [`UdpSocket`] in async
/// mode.
///
/// The `poll_*` methods follow the standard [`Future::poll`](std::future::Future::poll) contract:
/// when they return [`Poll::Pending`], the implementation must have arranged for the waker in `cx`
/// to be woken once the operation may be able to make progress.
#[allow(unreachable_pub)] // Users should not depend on this.
pub trait AsyncSocket: Unpin + Send + 'static {
    /// Creates the async socket from a [`std::net::UdpSocket`].
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while handing the socket over to the runtime.
    fn from_std(socket: UdpSocket) -> std::io::Result<Self>
    where
        Self: Sized;

    /// Attempts to receive a single datagram into `buf`.
    ///
    /// On success, returns the number of bytes written to `buf` together with the address of the
    /// peer that sent the datagram. The socket does not need to be connected: every received
    /// datagram reports its own source address.
    fn poll_read(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<(usize, SocketAddr), Error>>;

    /// Attempts to send `packet` as a single datagram to the address `to`.
    ///
    /// Resolves to `Ok(())` once the socket has accepted the datagram; the number of bytes written
    /// is not reported.
    fn poll_write(
        &mut self,
        cx: &mut Context<'_>,
        packet: &[u8],
        to: SocketAddr,
    ) -> Poll<Result<(), Error>>;
}
52
#[cfg(feature = "async-io")]
pub(crate) mod asio {
    use async_io::Async;

    use super::*;

    /// `async-io` backed UDP socket used as an [`AsyncSocket`].
    pub(crate) type AsyncUdpSocket = Async<UdpSocket>;

    impl AsyncSocket for AsyncUdpSocket {
        fn from_std(socket: UdpSocket) -> std::io::Result<Self> {
            // `Async::new` switches the socket to non-blocking mode and registers it with the
            // `async-io` reactor.
            Async::new(socket)
        }

        fn poll_read(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<(usize, SocketAddr), Error>> {
            loop {
                // The wrapped std socket is non-blocking, so this returns immediately.
                match self.get_ref().recv_from(buf) {
                    Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
                    res => return Poll::Ready(res),
                }
                // Nothing to read yet. `poll_readable` registers `cx`'s waker before returning
                // `Pending`, so we never yield without a pending wakeup. If it instead reports
                // readiness, retry the receive; after a spurious readiness event, the next
                // iteration re-registers the waker rather than returning `Pending` unregistered.
                futures::ready!(self.poll_readable(cx))?;
            }
        }

        fn poll_write(
            &mut self,
            cx: &mut Context<'_>,
            packet: &[u8],
            to: SocketAddr,
        ) -> Poll<Result<(), Error>> {
            loop {
                // The wrapped std socket is non-blocking, so this returns immediately.
                match self.get_ref().send_to(packet, to) {
                    Ok(_) => return Poll::Ready(Ok(())),
                    Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => {}
                    Err(err) => return Poll::Ready(Err(err)),
                }
                // Socket buffer full: wait for writability. `poll_writable` registers the waker
                // before returning `Pending`, which avoids lost wakeups on spurious readiness.
                futures::ready!(self.poll_writable(cx))?;
            }
        }
    }
}
95
#[cfg(feature = "tokio")]
pub(crate) mod tokio {
    use ::tokio::{io::ReadBuf, net::UdpSocket as TkUdpSocket};

    use super::*;

    /// Tokio-backed UDP socket used as an [`AsyncSocket`].
    pub(crate) type TokioUdpSocket = TkUdpSocket;

    impl AsyncSocket for TokioUdpSocket {
        fn from_std(socket: UdpSocket) -> std::io::Result<Self> {
            // Tokio leaves it to the caller to put the std socket into non-blocking mode.
            socket.set_nonblocking(true)?;
            // Inherent associated functions take precedence, so this is Tokio's own
            // constructor rather than a recursive call into this trait method.
            TkUdpSocket::from_std(socket)
        }

        fn poll_read(
            &mut self,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<Result<(usize, SocketAddr), Error>> {
            let mut datagram = ReadBuf::new(buf);
            // On success, pair the sender's address with the number of bytes actually filled.
            self.poll_recv_from(cx, &mut datagram)
                .map_ok(|sender| (datagram.filled().len(), sender))
        }

        fn poll_write(
            &mut self,
            cx: &mut Context<'_>,
            packet: &[u8],
            to: SocketAddr,
        ) -> Poll<Result<(), Error>> {
            // The byte count is discarded; callers only learn whether the send succeeded.
            self.poll_send_to(cx, packet, to).map_ok(|_sent| ())
        }
    }
}