libp2p_mdns/behaviour/
iface.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21mod dns;
22mod query;
23
24use std::{
25    collections::VecDeque,
26    future::Future,
27    io,
28    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket},
29    pin::Pin,
30    sync::{Arc, RwLock},
31    task::{Context, Poll},
32    time::{Duration, Instant},
33};
34
35use futures::{channel::mpsc, SinkExt, StreamExt};
36use libp2p_core::Multiaddr;
37use libp2p_identity::PeerId;
38use libp2p_swarm::ListenAddresses;
39use socket2::{Domain, Socket, Type};
40
41use self::{
42    dns::{build_query, build_query_response, build_service_discovery_response},
43    query::MdnsPacket,
44};
45use crate::{
46    behaviour::{socket::AsyncSocket, timer::Builder},
47    Config,
48};
49
/// Starting interval of the probing phase; [`ProbeState::default`] begins with this value.
const INITIAL_TIMEOUT_INTERVAL: Duration = Duration::from_millis(500);

/// Phase of an interface's query timer.
///
/// Both variants carry the interval the timer should use while in that phase.
#[derive(Debug, Clone)]
enum ProbeState {
    /// Still probing for peers.
    Probing(Duration),
    /// Probing is over; the steady-state query interval is in use.
    Finished(Duration),
}

impl Default for ProbeState {
    /// Starts in the probing phase with [`INITIAL_TIMEOUT_INTERVAL`].
    fn default() -> Self {
        Self::Probing(INITIAL_TIMEOUT_INTERVAL)
    }
}

impl ProbeState {
    /// Returns the timer interval carried by the current phase, whichever it is.
    fn interval(&self) -> &Duration {
        let (Self::Probing(period) | Self::Finished(period)) = self;
        period
    }
}
73
/// An mDNS instance for a networking interface. To discover all peers when having multiple
/// interfaces an [`InterfaceState`] is required for each interface.
///
/// `U` is the async UDP socket type and `T` the timer type; the `impl` blocks bound them by
/// `AsyncSocket` and `Builder + futures::Stream` respectively.
#[derive(Debug)]
pub(crate) struct InterfaceState<U, T> {
    /// Address this instance is bound to.
    addr: IpAddr,
    /// Receive socket.
    recv_socket: U,
    /// Send socket.
    send_socket: U,

    /// Shared set of local listen addresses; read when answering incoming queries so that
    /// they can be advertised in the response.
    listen_addresses: Arc<RwLock<ListenAddresses>>,

    /// Channel through which entries of `discovered` are forwarded. The `Instant` is the
    /// value produced by `extract_discovered` (presumably an expiry time — TODO confirm).
    query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,

    /// Buffer used for receiving data from the main socket.
    /// RFC6762 discourages packets larger than the interface MTU, but allows sizes of up to 9000
    /// bytes, if it can be ensured that all participating devices can handle such large packets.
    /// For computers with several interfaces and IP addresses responses can easily reach sizes in
    /// the range of 3000 bytes, so 4096 seems sensible for now. For more information see
    /// [rfc6762](https://tools.ietf.org/html/rfc6762#page-46).
    recv_buffer: [u8; 4096],
    /// Buffers pending to send on the main socket.
    send_buffer: VecDeque<Vec<u8>>,
    /// Steady-state discovery interval: the configured query interval plus a random jitter of
    /// less than 100 ms.
    query_interval: Duration,
    /// Discovery timer; each tick queues a new query.
    timeout: T,
    /// mDNS multicast group address matching the address family of `addr`.
    multicast_addr: IpAddr,
    /// Discovered addresses waiting to be forwarded through `query_response_sender`.
    discovered: VecDeque<(PeerId, Multiaddr, Instant)>,
    /// TTL passed to the query and service-discovery responses built by this instance.
    ttl: Duration,
    /// Current probing phase; its interval is what `reset_timer` uses for the next timer.
    probe_state: ProbeState,
    /// Our own peer id; put into query responses and passed to `extract_discovered`
    /// (presumably so that our own records are skipped).
    local_peer_id: PeerId,
}
111
112impl<U, T> InterfaceState<U, T>
113where
114    U: AsyncSocket,
115    T: Builder + futures::Stream,
116{
117    /// Builds a new [`InterfaceState`].
118    pub(crate) fn new(
119        addr: IpAddr,
120        config: Config,
121        local_peer_id: PeerId,
122        listen_addresses: Arc<RwLock<ListenAddresses>>,
123        query_response_sender: mpsc::Sender<(PeerId, Multiaddr, Instant)>,
124    ) -> io::Result<Self> {
125        tracing::info!(address=%addr, "creating instance on iface address");
126        let recv_socket = match addr {
127            IpAddr::V4(addr) => {
128                let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(socket2::Protocol::UDP))?;
129                socket.set_reuse_address(true)?;
130                #[cfg(unix)]
131                socket.set_reuse_port(true)?;
132                socket.bind(&SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 5353).into())?;
133                socket.set_multicast_loop_v4(true)?;
134                socket.set_multicast_ttl_v4(255)?;
135                socket.join_multicast_v4(&crate::IPV4_MDNS_MULTICAST_ADDRESS, &addr)?;
136                U::from_std(UdpSocket::from(socket))?
137            }
138            IpAddr::V6(_) => {
139                let socket = Socket::new(Domain::IPV6, Type::DGRAM, Some(socket2::Protocol::UDP))?;
140                socket.set_reuse_address(true)?;
141                #[cfg(unix)]
142                socket.set_reuse_port(true)?;
143                socket.bind(&SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353).into())?;
144                socket.set_multicast_loop_v6(true)?;
145                // TODO: find interface matching addr.
146                socket.join_multicast_v6(&crate::IPV6_MDNS_MULTICAST_ADDRESS, 0)?;
147                U::from_std(UdpSocket::from(socket))?
148            }
149        };
150        let bind_addr = match addr {
151            IpAddr::V4(_) => SocketAddr::new(addr, 0),
152            IpAddr::V6(_addr) => {
153                // TODO: if-watch should return the scope_id of an address
154                // as a workaround we bind to unspecified, which means that
155                // this probably won't work when using multiple interfaces.
156                // SocketAddr::V6(SocketAddrV6::new(addr, 0, 0, scope_id))
157                SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
158            }
159        };
160        let send_socket = U::from_std(UdpSocket::bind(bind_addr)?)?;
161
162        // randomize timer to prevent all converging and firing at the same time.
163        let query_interval = {
164            use rand::Rng;
165            let mut rng = rand::thread_rng();
166            let jitter = rng.gen_range(0..100);
167            config.query_interval + Duration::from_millis(jitter)
168        };
169        let multicast_addr = match addr {
170            IpAddr::V4(_) => IpAddr::V4(crate::IPV4_MDNS_MULTICAST_ADDRESS),
171            IpAddr::V6(_) => IpAddr::V6(crate::IPV6_MDNS_MULTICAST_ADDRESS),
172        };
173        Ok(Self {
174            addr,
175            recv_socket,
176            send_socket,
177            listen_addresses,
178            query_response_sender,
179            recv_buffer: [0; 4096],
180            send_buffer: Default::default(),
181            discovered: Default::default(),
182            query_interval,
183            timeout: T::interval_at(Instant::now(), INITIAL_TIMEOUT_INTERVAL),
184            multicast_addr,
185            ttl: config.ttl,
186            probe_state: Default::default(),
187            local_peer_id,
188        })
189    }
190
191    pub(crate) fn reset_timer(&mut self) {
192        tracing::trace!(address=%self.addr, probe_state=?self.probe_state, "reset timer");
193        let interval = *self.probe_state.interval();
194        self.timeout = T::interval(interval);
195    }
196
197    fn mdns_socket(&self) -> SocketAddr {
198        SocketAddr::new(self.multicast_addr, 5353)
199    }
200}
201
impl<U, T> Future for InterfaceState<U, T>
where
    U: AsyncSocket,
    T: Builder + futures::Stream,
{
    /// The future completes only when the instance shuts down. That happens when forwarding a
    /// discovered peer fails because `query_response_sender` is disconnected, or when reading
    /// from the receive socket fails with an error other than `WouldBlock`.
    type Output = ();

    /// Drives the instance: queues periodic queries, flushes outgoing packets, forwards
    /// discovered peers and answers incoming mDNS packets.
    ///
    /// Work runs in a fixed priority order inside a loop. Branches that made progress usually
    /// `continue`, so higher-priority work is re-checked before `Poll::Pending` is returned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();

        loop {
            // 1st priority: Low latency: Create packet ASAP after timeout.
            if this.timeout.poll_next_unpin(cx).is_ready() {
                tracing::trace!(address=%this.addr, "sending query on iface");
                this.send_buffer.push_back(build_query());
                tracing::trace!(address=%this.addr, probe_state=?this.probe_state, "tick");

                // While probing, double the interval after every tick. Stop probing once the
                // doubled interval reaches the steady-state query interval.
                if let ProbeState::Probing(interval) = this.probe_state {
                    let interval = interval * 2;
                    this.probe_state = if interval >= this.query_interval {
                        ProbeState::Finished(this.query_interval)
                    } else {
                        ProbeState::Probing(interval)
                    };
                }

                // Replace the timer so it uses the (possibly updated) probe-state interval.
                // The new timer is first polled on a later loop iteration or wake-up.
                this.reset_timer();
            }

            // 2nd priority: Keep local buffers small: Send packets to remote.
            if let Some(packet) = this.send_buffer.pop_front() {
                match this.send_socket.poll_write(cx, &packet, this.mdns_socket()) {
                    Poll::Ready(Ok(_)) => {
                        tracing::trace!(address=%this.addr, "sent packet on iface address");
                        continue;
                    }
                    Poll::Ready(Err(err)) => {
                        // The failed packet is dropped; the remaining ones are still tried.
                        tracing::error!(address=%this.addr, "error sending packet on iface address {}", err);
                        continue;
                    }
                    Poll::Pending => {
                        // Put the packet back so it is retried once the socket is writable.
                        this.send_buffer.push_front(packet);
                    }
                }
            }

            // 3rd priority: Keep local buffers small: Return discovered addresses.
            // `is_ready()` is also true for `Ready(Err(_))`. The send attempt below then
            // reports the disconnection, which ends this future.
            if this.query_response_sender.poll_ready_unpin(cx).is_ready() {
                if let Some(discovered) = this.discovered.pop_front() {
                    match this.query_response_sender.try_send(discovered) {
                        Ok(()) => {}
                        Err(e) if e.is_disconnected() => {
                            return Poll::Ready(());
                        }
                        Err(e) => {
                            // Channel is full: keep the entry at the front for the next attempt.
                            this.discovered.push_front(e.into_inner());
                        }
                    }

                    continue;
                }
            }

            // 4th priority: Remote work: Answer incoming requests.
            match this
                .recv_socket
                .poll_read(cx, &mut this.recv_buffer)
                .map_ok(|(len, from)| MdnsPacket::new_from_bytes(&this.recv_buffer[..len], from))
            {
                Poll::Ready(Ok(Ok(Some(MdnsPacket::Query(query))))) => {
                    tracing::trace!(
                        address=%this.addr,
                        remote_address=%query.remote_addr(),
                        "received query from remote address on address"
                    );

                    // Answer with our peer id and the current listen addresses. A poisoned
                    // lock is recovered via `into_inner` instead of propagating the panic.
                    this.send_buffer.extend(build_query_response(
                        query.query_id(),
                        this.local_peer_id,
                        this.listen_addresses
                            .read()
                            .unwrap_or_else(|e| e.into_inner())
                            .iter(),
                        this.ttl,
                    ));
                    continue;
                }
                Poll::Ready(Ok(Ok(Some(MdnsPacket::Response(response))))) => {
                    tracing::trace!(
                        address=%this.addr,
                        remote_address=%response.remote_addr(),
                        "received response from remote address on address"
                    );

                    this.discovered
                        .extend(response.extract_discovered(Instant::now(), this.local_peer_id));

                    // Stop probing when we have a valid response.
                    // NOTE(review): this tests the whole queue, not just this response's
                    // entries. It also runs after probing has finished, so every such response
                    // pushes the next query back by a full `query_interval`.
                    if !this.discovered.is_empty() {
                        this.probe_state = ProbeState::Finished(this.query_interval);
                        this.reset_timer();
                    }
                    continue;
                }
                Poll::Ready(Ok(Ok(Some(MdnsPacket::ServiceDiscovery(disc))))) => {
                    tracing::trace!(
                        address=%this.addr,
                        remote_address=%disc.remote_addr(),
                        "received service discovery from remote address on address"
                    );

                    this.send_buffer
                        .push_back(build_service_discovery_response(disc.query_id(), this.ttl));
                    continue;
                }
                Poll::Ready(Err(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
                    // No more bytes available on the socket to read
                    continue;
                }
                Poll::Ready(Err(err)) => {
                    // Any other read error terminates this instance.
                    tracing::error!("failed reading datagram: {}", err);
                    return Poll::Ready(());
                }
                Poll::Ready(Ok(Err(err))) => {
                    // Malformed packets are logged and ignored.
                    tracing::debug!("Parsing mdns packet failed: {:?}", err);
                    continue;
                }
                // `new_from_bytes` yielded no packet; presumably one that is not relevant to
                // this service — TODO confirm against `query.rs`.
                Poll::Ready(Ok(Ok(None))) => continue,
                Poll::Pending => {}
            }

            // Every source is pending and has registered a waker (or is empty).
            return Poll::Pending;
        }
    }
}