libp2p_swarm_derive/
lib.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21#![recursion_limit = "256"]
22#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
23
24mod syn_ext;
25
26use heck::ToUpperCamelCase;
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{parse_macro_input, punctuated::Punctuated, Data, DataStruct, DeriveInput, Meta, Token};
30
31use crate::syn_ext::RequireStrLit;
32
33/// Generates a delegating `NetworkBehaviour` implementation for the struct this is used for. See
34/// the trait documentation for better description.
35#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
36pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
37    let ast = parse_macro_input!(input as DeriveInput);
38    build(&ast).unwrap_or_else(|e| e.to_compile_error().into())
39}
40
41/// The actual implementation.
42fn build(ast: &DeriveInput) -> syn::Result<TokenStream> {
43    match ast.data {
44        Data::Struct(ref s) => build_struct(ast, s),
45        Data::Enum(_) => Err(syn::Error::new_spanned(
46            ast,
47            "Cannot derive `NetworkBehaviour` on enums",
48        )),
49        Data::Union(_) => Err(syn::Error::new_spanned(
50            ast,
51            "Cannot derive `NetworkBehaviour` on union",
52        )),
53    }
54}
55
56/// The version for structs
57fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> syn::Result<TokenStream> {
58    let name = &ast.ident;
59    let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
60    let BehaviourAttributes {
61        prelude_path,
62        user_specified_out_event,
63    } = parse_attributes(ast)?;
64
65    let multiaddr = quote! { #prelude_path::Multiaddr };
66    let trait_to_impl = quote! { #prelude_path::NetworkBehaviour };
67    let either_ident = quote! { #prelude_path::Either };
68    let network_behaviour_action = quote! { #prelude_path::ToSwarm };
69    let connection_handler = quote! { #prelude_path::ConnectionHandler };
70    let proto_select_ident = quote! { #prelude_path::ConnectionHandlerSelect };
71    let peer_id = quote! { #prelude_path::PeerId };
72    let connection_id = quote! { #prelude_path::ConnectionId };
73    let from_swarm = quote! { #prelude_path::FromSwarm };
74    let t_handler = quote! { #prelude_path::THandler };
75    let t_handler_in_event = quote! { #prelude_path::THandlerInEvent };
76    let t_handler_out_event = quote! { #prelude_path::THandlerOutEvent };
77    let endpoint = quote! { #prelude_path::Endpoint };
78    let connection_denied = quote! { #prelude_path::ConnectionDenied };
79    let port_use = quote! { #prelude_path::PortUse };
80
81    // Build the generics.
82    let impl_generics = {
83        let tp = ast.generics.type_params();
84        let lf = ast.generics.lifetimes();
85        let cst = ast.generics.const_params();
86        quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>}
87    };
88
89    let (out_event_name, out_event_definition, out_event_from_clauses) = {
90        // If we find a `#[behaviour(to_swarm = "Foo")]` attribute on the
91        // struct, we set `Foo` as the out event. If not, the `ToSwarm` is
92        // generated.
93        match user_specified_out_event {
94            // User provided `ToSwarm`.
95            Some(name) => {
96                let definition = None;
97                let from_clauses = data_struct
98                    .fields
99                    .iter()
100                    .map(|field| {
101                        let ty = &field.ty;
102                        quote! {#name: From< <#ty as #trait_to_impl>::ToSwarm >}
103                    })
104                    .collect::<Vec<_>>();
105                (name, definition, from_clauses)
106            }
107            // User did not provide `ToSwarm`. Generate it.
108            None => {
109                let enum_name_str = ast.ident.to_string() + "Event";
110                let enum_name: syn::Type =
111                    syn::parse_str(&enum_name_str).expect("ident + `Event` is a valid type");
112                let definition = {
113                    let fields = data_struct.fields.iter().map(|field| {
114                        let variant: syn::Variant = syn::parse_str(
115                            &field
116                                .ident
117                                .clone()
118                                .expect("Fields of NetworkBehaviour implementation to be named.")
119                                .to_string()
120                                .to_upper_camel_case(),
121                        )
122                        .expect("uppercased field name to be a valid enum variant");
123                        let ty = &field.ty;
124                        (variant, ty)
125                    });
126
127                    let enum_variants = fields
128                        .clone()
129                        .map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::ToSwarm)});
130
131                    let visibility = &ast.vis;
132
133                    let additional = fields
134                        .clone()
135                        .map(|(_variant, tp)| quote! { #tp : #trait_to_impl })
136                        .collect::<Vec<_>>();
137
138                    let additional_debug = fields
139                        .clone()
140                        .map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::ToSwarm : ::core::fmt::Debug })
141                        .collect::<Vec<_>>();
142
143                    let where_clause = {
144                        if let Some(where_clause) = where_clause {
145                            if where_clause.predicates.trailing_punct() {
146                                Some(quote! {#where_clause #(#additional),* })
147                            } else {
148                                Some(quote! {#where_clause, #(#additional),*})
149                            }
150                        } else if additional.is_empty() {
151                            None
152                        } else {
153                            Some(quote! {where #(#additional),*})
154                        }
155                    };
156
157                    let where_clause_debug = where_clause
158                        .as_ref()
159                        .map(|where_clause| quote! {#where_clause, #(#additional_debug),*});
160
161                    let match_variants = fields.map(|(variant, _ty)| variant);
162                    let msg = format!("`NetworkBehaviour::ToSwarm` produced by {name}.");
163
164                    Some(quote! {
165                        #[doc = #msg]
166                        #visibility enum #enum_name #impl_generics
167                            #where_clause
168                        {
169                            #(#enum_variants),*
170                        }
171
172                        impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug {
173                            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
174                                match &self {
175                                    #(#enum_name::#match_variants(event) => {
176                                        write!(f, "{}: {:?}", #enum_name_str, event)
177                                    }),*
178                                }
179                            }
180                        }
181                    })
182                };
183                let from_clauses = vec![];
184                (enum_name, definition, from_clauses)
185            }
186        }
187    };
188
189    // Build the `where ...` clause of the trait implementation.
190    let where_clause = {
191        let additional = data_struct
192            .fields
193            .iter()
194            .map(|field| {
195                let ty = &field.ty;
196                quote! {#ty: #trait_to_impl}
197            })
198            .chain(out_event_from_clauses)
199            .collect::<Vec<_>>();
200
201        if let Some(where_clause) = where_clause {
202            if where_clause.predicates.trailing_punct() {
203                Some(quote! {#where_clause #(#additional),* })
204            } else {
205                Some(quote! {#where_clause, #(#additional),*})
206            }
207        } else {
208            Some(quote! {where #(#additional),*})
209        }
210    };
211
212    // Build the list of statements to put in the body of `on_swarm_event()`.
213    let on_swarm_event_stmts = {
214        data_struct
215            .fields
216            .iter()
217            .enumerate()
218            .map(|(field_n, field)| match field.ident {
219                Some(ref i) => quote! {
220                    self.#i.on_swarm_event(event);
221                },
222                None => quote! {
223                    self.#field_n.on_swarm_event(event);
224                },
225            })
226    };
227
228    // Build the list of variants to put in the body of `on_connection_handler_event()`.
229    //
230    // The event type is a construction of nested `#either_ident`s of the events of the children.
231    // We call `on_connection_handler_event` on the corresponding child.
232    let on_node_event_stmts =
233        data_struct
234            .fields
235            .iter()
236            .enumerate()
237            .enumerate()
238            .map(|(enum_n, (field_n, field))| {
239                let mut elem = if enum_n != 0 {
240                    quote! { #either_ident::Right(ev) }
241                } else {
242                    quote! { ev }
243                };
244
245                for _ in 0..data_struct.fields.len() - 1 - enum_n {
246                    elem = quote! { #either_ident::Left(#elem) };
247                }
248
249                Some(match field.ident {
250                    Some(ref i) => quote! { #elem => {
251                    #trait_to_impl::on_connection_handler_event(&mut self.#i, peer_id, connection_id, ev) }},
252                    None => quote! { #elem => {
253                    #trait_to_impl::on_connection_handler_event(&mut self.#field_n, peer_id, connection_id, ev) }},
254                })
255            });
256
257    // The [`ConnectionHandler`] associated type.
258    let connection_handler_ty = {
259        let mut ph_ty = None;
260        for field in data_struct.fields.iter() {
261            let ty = &field.ty;
262            let field_info = quote! { #t_handler<#ty> };
263            match ph_ty {
264                Some(ev) => ph_ty = Some(quote! { #proto_select_ident<#ev, #field_info> }),
265                ref mut ev @ None => *ev = Some(field_info),
266            }
267        }
268        // ph_ty = Some(quote! )
269        ph_ty.unwrap_or(quote! {()}) // TODO: `!` instead
270    };
271
272    // The content of `handle_pending_inbound_connection`.
273    let handle_pending_inbound_connection_stmts =
274        data_struct
275            .fields
276            .iter()
277            .enumerate()
278            .map(|(field_n, field)| {
279                match field.ident {
280                    Some(ref i) => quote! {
281                        #trait_to_impl::handle_pending_inbound_connection(&mut self.#i, connection_id, local_addr, remote_addr)?;
282                    },
283                    None => quote! {
284                        #trait_to_impl::handle_pending_inbound_connection(&mut self.#field_n, connection_id, local_addr, remote_addr)?;
285                    }
286                }
287            });
288
289    // The content of `handle_established_inbound_connection`.
290    let handle_established_inbound_connection = {
291        let mut out_handler = None;
292
293        for (field_n, field) in data_struct.fields.iter().enumerate() {
294            let field_name = match field.ident {
295                Some(ref i) => quote! { self.#i },
296                None => quote! { self.#field_n },
297            };
298
299            let builder = quote! {
300                #field_name.handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)?
301            };
302
303            match out_handler {
304                Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
305                ref mut h @ None => *h = Some(builder),
306            }
307        }
308
309        out_handler.unwrap_or(quote! {()}) // TODO: See test `empty`.
310    };
311
312    // The content of `handle_pending_outbound_connection`.
313    let handle_pending_outbound_connection = {
314        let extend_stmts =
315            data_struct
316                .fields
317                .iter()
318                .enumerate()
319                .map(|(field_n, field)| {
320                    match field.ident {
321                        Some(ref i) => quote! {
322                            combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#i, connection_id, maybe_peer, addresses, effective_role)?);
323                        },
324                        None => quote! {
325                            combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#field_n, connection_id, maybe_peer, addresses, effective_role)?);
326                        }
327                    }
328                });
329
330        quote! {
331            let mut combined_addresses = vec![];
332
333            #(#extend_stmts)*
334
335            Ok(combined_addresses)
336        }
337    };
338
339    // The content of `handle_established_outbound_connection`.
340    let handle_established_outbound_connection = {
341        let mut out_handler = None;
342
343        for (field_n, field) in data_struct.fields.iter().enumerate() {
344            let field_name = match field.ident {
345                Some(ref i) => quote! { self.#i },
346                None => quote! { self.#field_n },
347            };
348
349            let builder = quote! {
350                #field_name.handle_established_outbound_connection(connection_id, peer, addr, role_override, port_use)?
351            };
352
353            match out_handler {
354                Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
355                ref mut h @ None => *h = Some(builder),
356            }
357        }
358
359        out_handler.unwrap_or(quote! {()}) // TODO: See test `empty`.
360    };
361
362    // List of statements to put in `poll()`.
363    //
364    // We poll each child one by one and wrap around the output.
365    let poll_stmts = data_struct
366        .fields
367        .iter()
368        .enumerate()
369        .map(|(field_n, field)| {
370            let field = field
371                .ident
372                .clone()
373                .expect("Fields of NetworkBehaviour implementation to be named.");
374
375            let mut wrapped_event = if field_n != 0 {
376                quote! { #either_ident::Right(event) }
377            } else {
378                quote! { event }
379            };
380            for _ in 0..data_struct.fields.len() - 1 - field_n {
381                wrapped_event = quote! { #either_ident::Left(#wrapped_event) };
382            }
383
384            // If the `NetworkBehaviour`'s `ToSwarm` is generated by the derive macro, wrap the sub
385            // `NetworkBehaviour` `ToSwarm` in the variant of the generated `ToSwarm`. If the
386            // `NetworkBehaviour`'s `ToSwarm` is provided by the user, use the corresponding `From`
387            // implementation.
388            let map_out_event = if out_event_definition.is_some() {
389                let event_variant: syn::Variant =
390                    syn::parse_str(&field.to_string().to_upper_camel_case())
391                        .expect("uppercased field name to be a valid enum variant name");
392                quote! { #out_event_name::#event_variant }
393            } else {
394                quote! { |e| e.into() }
395            };
396
397            let map_in_event = quote! { |event| #wrapped_event };
398
399            quote! {
400                match #trait_to_impl::poll(&mut self.#field, cx) {
401                    std::task::Poll::Ready(e) => return std::task::Poll::Ready(e.map_out(#map_out_event).map_in(#map_in_event)),
402                    std::task::Poll::Pending => {},
403                }
404            }
405        });
406
407    let out_event_reference = if out_event_definition.is_some() {
408        quote! { #out_event_name #ty_generics }
409    } else {
410        quote! { #out_event_name }
411    };
412
413    // Now the magic happens.
414    let final_quote = quote! {
415        #out_event_definition
416
417        impl #impl_generics #trait_to_impl for #name #ty_generics
418        #where_clause
419        {
420            type ConnectionHandler = #connection_handler_ty;
421            type ToSwarm = #out_event_reference;
422
423            #[allow(clippy::needless_question_mark)]
424            fn handle_pending_inbound_connection(
425                &mut self,
426                connection_id: #connection_id,
427                local_addr: &#multiaddr,
428                remote_addr: &#multiaddr,
429            ) -> std::result::Result<(), #connection_denied> {
430                #(#handle_pending_inbound_connection_stmts)*
431
432                Ok(())
433            }
434
435            #[allow(clippy::needless_question_mark)]
436            fn handle_established_inbound_connection(
437                &mut self,
438                connection_id: #connection_id,
439                peer: #peer_id,
440                local_addr: &#multiaddr,
441                remote_addr: &#multiaddr,
442            ) -> std::result::Result<#t_handler<Self>, #connection_denied> {
443                Ok(#handle_established_inbound_connection)
444            }
445
446            #[allow(clippy::needless_question_mark)]
447            fn handle_pending_outbound_connection(
448                &mut self,
449                connection_id: #connection_id,
450                maybe_peer: Option<#peer_id>,
451                addresses: &[#multiaddr],
452                effective_role: #endpoint,
453            ) -> std::result::Result<::std::vec::Vec<#multiaddr>, #connection_denied> {
454                #handle_pending_outbound_connection
455            }
456
457            #[allow(clippy::needless_question_mark)]
458            fn handle_established_outbound_connection(
459                &mut self,
460                connection_id: #connection_id,
461                peer: #peer_id,
462                addr: &#multiaddr,
463                role_override: #endpoint,
464                port_use: #port_use,
465            ) -> std::result::Result<#t_handler<Self>, #connection_denied> {
466                Ok(#handle_established_outbound_connection)
467            }
468
469            fn on_connection_handler_event(
470                &mut self,
471                peer_id: #peer_id,
472                connection_id: #connection_id,
473                event: #t_handler_out_event<Self>
474            ) {
475                match event {
476                    #(#on_node_event_stmts),*
477                }
478            }
479
480            fn poll(&mut self, cx: &mut std::task::Context) -> std::task::Poll<#network_behaviour_action<Self::ToSwarm, #t_handler_in_event<Self>>> {
481                #(#poll_stmts)*
482                std::task::Poll::Pending
483            }
484
485            fn on_swarm_event(&mut self, event: #from_swarm) {
486                #(#on_swarm_event_stmts)*
487            }
488        }
489    };
490
491    Ok(final_quote.into())
492}
493
494struct BehaviourAttributes {
495    prelude_path: syn::Path,
496    user_specified_out_event: Option<syn::Type>,
497}
498
499/// Parses the `value` of a key=value pair in the `#[behaviour]` attribute into the requested type.
500fn parse_attributes(ast: &DeriveInput) -> syn::Result<BehaviourAttributes> {
501    let mut attributes = BehaviourAttributes {
502        prelude_path: syn::parse_quote! { ::libp2p::swarm::derive_prelude },
503        user_specified_out_event: None,
504    };
505
506    for attr in ast
507        .attrs
508        .iter()
509        .filter(|attr| attr.path().is_ident("behaviour"))
510    {
511        let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
512
513        for meta in nested {
514            if meta.path().is_ident("prelude") {
515                let value = meta.require_name_value()?.value.require_str_lit()?;
516
517                attributes.prelude_path = syn::parse_str(&value)?;
518
519                continue;
520            }
521
522            if meta.path().is_ident("to_swarm") || meta.path().is_ident("out_event") {
523                let value = meta.require_name_value()?.value.require_str_lit()?;
524
525                attributes.user_specified_out_event = Some(syn::parse_str(&value)?);
526
527                continue;
528            }
529        }
530    }
531
532    Ok(attributes)
533}