1#![recursion_limit = "256"]
22#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
23
24mod syn_ext;
25
26use heck::ToUpperCamelCase;
27use proc_macro::TokenStream;
28use quote::quote;
29use syn::{parse_macro_input, punctuated::Punctuated, Data, DataStruct, DeriveInput, Meta, Token};
30
31use crate::syn_ext::RequireStrLit;
32
33#[proc_macro_derive(NetworkBehaviour, attributes(behaviour))]
36pub fn hello_macro_derive(input: TokenStream) -> TokenStream {
37 let ast = parse_macro_input!(input as DeriveInput);
38 build(&ast).unwrap_or_else(|e| e.to_compile_error().into())
39}
40
41fn build(ast: &DeriveInput) -> syn::Result<TokenStream> {
43 match ast.data {
44 Data::Struct(ref s) => build_struct(ast, s),
45 Data::Enum(_) => Err(syn::Error::new_spanned(
46 ast,
47 "Cannot derive `NetworkBehaviour` on enums",
48 )),
49 Data::Union(_) => Err(syn::Error::new_spanned(
50 ast,
51 "Cannot derive `NetworkBehaviour` on union",
52 )),
53 }
54}
55
56fn build_struct(ast: &DeriveInput, data_struct: &DataStruct) -> syn::Result<TokenStream> {
58 let name = &ast.ident;
59 let (_, ty_generics, where_clause) = ast.generics.split_for_impl();
60 let BehaviourAttributes {
61 prelude_path,
62 user_specified_out_event,
63 } = parse_attributes(ast)?;
64
65 let multiaddr = quote! { #prelude_path::Multiaddr };
66 let trait_to_impl = quote! { #prelude_path::NetworkBehaviour };
67 let either_ident = quote! { #prelude_path::Either };
68 let network_behaviour_action = quote! { #prelude_path::ToSwarm };
69 let connection_handler = quote! { #prelude_path::ConnectionHandler };
70 let proto_select_ident = quote! { #prelude_path::ConnectionHandlerSelect };
71 let peer_id = quote! { #prelude_path::PeerId };
72 let connection_id = quote! { #prelude_path::ConnectionId };
73 let from_swarm = quote! { #prelude_path::FromSwarm };
74 let t_handler = quote! { #prelude_path::THandler };
75 let t_handler_in_event = quote! { #prelude_path::THandlerInEvent };
76 let t_handler_out_event = quote! { #prelude_path::THandlerOutEvent };
77 let endpoint = quote! { #prelude_path::Endpoint };
78 let connection_denied = quote! { #prelude_path::ConnectionDenied };
79 let port_use = quote! { #prelude_path::PortUse };
80
81 let impl_generics = {
83 let tp = ast.generics.type_params();
84 let lf = ast.generics.lifetimes();
85 let cst = ast.generics.const_params();
86 quote! {<#(#lf,)* #(#tp,)* #(#cst,)*>}
87 };
88
89 let (out_event_name, out_event_definition, out_event_from_clauses) = {
90 match user_specified_out_event {
94 Some(name) => {
96 let definition = None;
97 let from_clauses = data_struct
98 .fields
99 .iter()
100 .map(|field| {
101 let ty = &field.ty;
102 quote! {#name: From< <#ty as #trait_to_impl>::ToSwarm >}
103 })
104 .collect::<Vec<_>>();
105 (name, definition, from_clauses)
106 }
107 None => {
109 let enum_name_str = ast.ident.to_string() + "Event";
110 let enum_name: syn::Type =
111 syn::parse_str(&enum_name_str).expect("ident + `Event` is a valid type");
112 let definition = {
113 let fields = data_struct.fields.iter().map(|field| {
114 let variant: syn::Variant = syn::parse_str(
115 &field
116 .ident
117 .clone()
118 .expect("Fields of NetworkBehaviour implementation to be named.")
119 .to_string()
120 .to_upper_camel_case(),
121 )
122 .expect("uppercased field name to be a valid enum variant");
123 let ty = &field.ty;
124 (variant, ty)
125 });
126
127 let enum_variants = fields
128 .clone()
129 .map(|(variant, ty)| quote! {#variant(<#ty as #trait_to_impl>::ToSwarm)});
130
131 let visibility = &ast.vis;
132
133 let additional = fields
134 .clone()
135 .map(|(_variant, tp)| quote! { #tp : #trait_to_impl })
136 .collect::<Vec<_>>();
137
138 let additional_debug = fields
139 .clone()
140 .map(|(_variant, ty)| quote! { <#ty as #trait_to_impl>::ToSwarm : ::core::fmt::Debug })
141 .collect::<Vec<_>>();
142
143 let where_clause = {
144 if let Some(where_clause) = where_clause {
145 if where_clause.predicates.trailing_punct() {
146 Some(quote! {#where_clause #(#additional),* })
147 } else {
148 Some(quote! {#where_clause, #(#additional),*})
149 }
150 } else if additional.is_empty() {
151 None
152 } else {
153 Some(quote! {where #(#additional),*})
154 }
155 };
156
157 let where_clause_debug = where_clause
158 .as_ref()
159 .map(|where_clause| quote! {#where_clause, #(#additional_debug),*});
160
161 let match_variants = fields.map(|(variant, _ty)| variant);
162 let msg = format!("`NetworkBehaviour::ToSwarm` produced by {name}.");
163
164 Some(quote! {
165 #[doc = #msg]
166 #visibility enum #enum_name #impl_generics
167 #where_clause
168 {
169 #(#enum_variants),*
170 }
171
172 impl #impl_generics ::core::fmt::Debug for #enum_name #ty_generics #where_clause_debug {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
174 match &self {
175 #(#enum_name::#match_variants(event) => {
176 write!(f, "{}: {:?}", #enum_name_str, event)
177 }),*
178 }
179 }
180 }
181 })
182 };
183 let from_clauses = vec![];
184 (enum_name, definition, from_clauses)
185 }
186 }
187 };
188
189 let where_clause = {
191 let additional = data_struct
192 .fields
193 .iter()
194 .map(|field| {
195 let ty = &field.ty;
196 quote! {#ty: #trait_to_impl}
197 })
198 .chain(out_event_from_clauses)
199 .collect::<Vec<_>>();
200
201 if let Some(where_clause) = where_clause {
202 if where_clause.predicates.trailing_punct() {
203 Some(quote! {#where_clause #(#additional),* })
204 } else {
205 Some(quote! {#where_clause, #(#additional),*})
206 }
207 } else {
208 Some(quote! {where #(#additional),*})
209 }
210 };
211
212 let on_swarm_event_stmts = {
214 data_struct
215 .fields
216 .iter()
217 .enumerate()
218 .map(|(field_n, field)| match field.ident {
219 Some(ref i) => quote! {
220 self.#i.on_swarm_event(event);
221 },
222 None => quote! {
223 self.#field_n.on_swarm_event(event);
224 },
225 })
226 };
227
228 let on_node_event_stmts =
233 data_struct
234 .fields
235 .iter()
236 .enumerate()
237 .enumerate()
238 .map(|(enum_n, (field_n, field))| {
239 let mut elem = if enum_n != 0 {
240 quote! { #either_ident::Right(ev) }
241 } else {
242 quote! { ev }
243 };
244
245 for _ in 0..data_struct.fields.len() - 1 - enum_n {
246 elem = quote! { #either_ident::Left(#elem) };
247 }
248
249 Some(match field.ident {
250 Some(ref i) => quote! { #elem => {
251 #trait_to_impl::on_connection_handler_event(&mut self.#i, peer_id, connection_id, ev) }},
252 None => quote! { #elem => {
253 #trait_to_impl::on_connection_handler_event(&mut self.#field_n, peer_id, connection_id, ev) }},
254 })
255 });
256
257 let connection_handler_ty = {
259 let mut ph_ty = None;
260 for field in data_struct.fields.iter() {
261 let ty = &field.ty;
262 let field_info = quote! { #t_handler<#ty> };
263 match ph_ty {
264 Some(ev) => ph_ty = Some(quote! { #proto_select_ident<#ev, #field_info> }),
265 ref mut ev @ None => *ev = Some(field_info),
266 }
267 }
268 ph_ty.unwrap_or(quote! {()}) };
271
272 let handle_pending_inbound_connection_stmts =
274 data_struct
275 .fields
276 .iter()
277 .enumerate()
278 .map(|(field_n, field)| {
279 match field.ident {
280 Some(ref i) => quote! {
281 #trait_to_impl::handle_pending_inbound_connection(&mut self.#i, connection_id, local_addr, remote_addr)?;
282 },
283 None => quote! {
284 #trait_to_impl::handle_pending_inbound_connection(&mut self.#field_n, connection_id, local_addr, remote_addr)?;
285 }
286 }
287 });
288
289 let handle_established_inbound_connection = {
291 let mut out_handler = None;
292
293 for (field_n, field) in data_struct.fields.iter().enumerate() {
294 let field_name = match field.ident {
295 Some(ref i) => quote! { self.#i },
296 None => quote! { self.#field_n },
297 };
298
299 let builder = quote! {
300 #field_name.handle_established_inbound_connection(connection_id, peer, local_addr, remote_addr)?
301 };
302
303 match out_handler {
304 Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
305 ref mut h @ None => *h = Some(builder),
306 }
307 }
308
309 out_handler.unwrap_or(quote! {()}) };
311
312 let handle_pending_outbound_connection = {
314 let extend_stmts =
315 data_struct
316 .fields
317 .iter()
318 .enumerate()
319 .map(|(field_n, field)| {
320 match field.ident {
321 Some(ref i) => quote! {
322 combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#i, connection_id, maybe_peer, addresses, effective_role)?);
323 },
324 None => quote! {
325 combined_addresses.extend(#trait_to_impl::handle_pending_outbound_connection(&mut self.#field_n, connection_id, maybe_peer, addresses, effective_role)?);
326 }
327 }
328 });
329
330 quote! {
331 let mut combined_addresses = vec![];
332
333 #(#extend_stmts)*
334
335 Ok(combined_addresses)
336 }
337 };
338
339 let handle_established_outbound_connection = {
341 let mut out_handler = None;
342
343 for (field_n, field) in data_struct.fields.iter().enumerate() {
344 let field_name = match field.ident {
345 Some(ref i) => quote! { self.#i },
346 None => quote! { self.#field_n },
347 };
348
349 let builder = quote! {
350 #field_name.handle_established_outbound_connection(connection_id, peer, addr, role_override, port_use)?
351 };
352
353 match out_handler {
354 Some(h) => out_handler = Some(quote! { #connection_handler::select(#h, #builder) }),
355 ref mut h @ None => *h = Some(builder),
356 }
357 }
358
359 out_handler.unwrap_or(quote! {()}) };
361
362 let poll_stmts = data_struct
366 .fields
367 .iter()
368 .enumerate()
369 .map(|(field_n, field)| {
370 let field = field
371 .ident
372 .clone()
373 .expect("Fields of NetworkBehaviour implementation to be named.");
374
375 let mut wrapped_event = if field_n != 0 {
376 quote! { #either_ident::Right(event) }
377 } else {
378 quote! { event }
379 };
380 for _ in 0..data_struct.fields.len() - 1 - field_n {
381 wrapped_event = quote! { #either_ident::Left(#wrapped_event) };
382 }
383
384 let map_out_event = if out_event_definition.is_some() {
389 let event_variant: syn::Variant =
390 syn::parse_str(&field.to_string().to_upper_camel_case())
391 .expect("uppercased field name to be a valid enum variant name");
392 quote! { #out_event_name::#event_variant }
393 } else {
394 quote! { |e| e.into() }
395 };
396
397 let map_in_event = quote! { |event| #wrapped_event };
398
399 quote! {
400 match #trait_to_impl::poll(&mut self.#field, cx) {
401 std::task::Poll::Ready(e) => return std::task::Poll::Ready(e.map_out(#map_out_event).map_in(#map_in_event)),
402 std::task::Poll::Pending => {},
403 }
404 }
405 });
406
407 let out_event_reference = if out_event_definition.is_some() {
408 quote! { #out_event_name #ty_generics }
409 } else {
410 quote! { #out_event_name }
411 };
412
413 let final_quote = quote! {
415 #out_event_definition
416
417 impl #impl_generics #trait_to_impl for #name #ty_generics
418 #where_clause
419 {
420 type ConnectionHandler = #connection_handler_ty;
421 type ToSwarm = #out_event_reference;
422
423 #[allow(clippy::needless_question_mark)]
424 fn handle_pending_inbound_connection(
425 &mut self,
426 connection_id: #connection_id,
427 local_addr: &#multiaddr,
428 remote_addr: &#multiaddr,
429 ) -> std::result::Result<(), #connection_denied> {
430 #(#handle_pending_inbound_connection_stmts)*
431
432 Ok(())
433 }
434
435 #[allow(clippy::needless_question_mark)]
436 fn handle_established_inbound_connection(
437 &mut self,
438 connection_id: #connection_id,
439 peer: #peer_id,
440 local_addr: &#multiaddr,
441 remote_addr: &#multiaddr,
442 ) -> std::result::Result<#t_handler<Self>, #connection_denied> {
443 Ok(#handle_established_inbound_connection)
444 }
445
446 #[allow(clippy::needless_question_mark)]
447 fn handle_pending_outbound_connection(
448 &mut self,
449 connection_id: #connection_id,
450 maybe_peer: Option<#peer_id>,
451 addresses: &[#multiaddr],
452 effective_role: #endpoint,
453 ) -> std::result::Result<::std::vec::Vec<#multiaddr>, #connection_denied> {
454 #handle_pending_outbound_connection
455 }
456
457 #[allow(clippy::needless_question_mark)]
458 fn handle_established_outbound_connection(
459 &mut self,
460 connection_id: #connection_id,
461 peer: #peer_id,
462 addr: &#multiaddr,
463 role_override: #endpoint,
464 port_use: #port_use,
465 ) -> std::result::Result<#t_handler<Self>, #connection_denied> {
466 Ok(#handle_established_outbound_connection)
467 }
468
469 fn on_connection_handler_event(
470 &mut self,
471 peer_id: #peer_id,
472 connection_id: #connection_id,
473 event: #t_handler_out_event<Self>
474 ) {
475 match event {
476 #(#on_node_event_stmts),*
477 }
478 }
479
480 fn poll(&mut self, cx: &mut std::task::Context) -> std::task::Poll<#network_behaviour_action<Self::ToSwarm, #t_handler_in_event<Self>>> {
481 #(#poll_stmts)*
482 std::task::Poll::Pending
483 }
484
485 fn on_swarm_event(&mut self, event: #from_swarm) {
486 #(#on_swarm_event_stmts)*
487 }
488 }
489 };
490
491 Ok(final_quote.into())
492}
493
/// Options collected from the `#[behaviour(...)]` attributes of the deriving struct.
struct BehaviourAttributes {
    /// Path under which the generated code looks up the items it references
    /// (`NetworkBehaviour`, `Either`, `PeerId`, ...); set by the `prelude` key.
    prelude_path: syn::Path,
    /// Event type supplied through the `to_swarm` / `out_event` key, if any.
    user_specified_out_event: Option<syn::Type>,
}
498
499fn parse_attributes(ast: &DeriveInput) -> syn::Result<BehaviourAttributes> {
501 let mut attributes = BehaviourAttributes {
502 prelude_path: syn::parse_quote! { ::libp2p::swarm::derive_prelude },
503 user_specified_out_event: None,
504 };
505
506 for attr in ast
507 .attrs
508 .iter()
509 .filter(|attr| attr.path().is_ident("behaviour"))
510 {
511 let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
512
513 for meta in nested {
514 if meta.path().is_ident("prelude") {
515 let value = meta.require_name_value()?.value.require_str_lit()?;
516
517 attributes.prelude_path = syn::parse_str(&value)?;
518
519 continue;
520 }
521
522 if meta.path().is_ident("to_swarm") || meta.path().is_ident("out_event") {
523 let value = meta.require_name_value()?.value.require_str_lit()?;
524
525 attributes.user_specified_out_event = Some(syn::parse_str(&value)?);
526
527 continue;
528 }
529 }
530 }
531
532 Ok(attributes)
533}