1pub mod either;
42mod map_in;
43mod map_out;
44pub mod multi;
45mod one_shot;
46mod pending;
47mod select;
48
49use core::slice;
50use std::{
51 collections::{HashMap, HashSet},
52 error, fmt, io,
53 task::{Context, Poll},
54 time::Duration,
55};
56
57use libp2p_core::Multiaddr;
58pub use map_in::MapInEvent;
59pub use map_out::MapOutEvent;
60pub use one_shot::{OneShotHandler, OneShotHandlerConfig};
61pub use pending::PendingConnectionHandler;
62pub use select::ConnectionHandlerSelect;
63use smallvec::SmallVec;
64
65pub use crate::upgrade::{InboundUpgradeSend, OutboundUpgradeSend, SendWrapper, UpgradeInfoSend};
66use crate::{connection::AsStrHashEq, StreamProtocol};
67
/// Handler interface: it is fed events through [`Self::on_behaviour_event`] and
/// [`Self::on_connection_event`] and is polled through [`Self::poll`] for
/// [`ConnectionHandlerEvent`]s.
///
/// NOTE(review): who drives these methods, and in which order, is decided outside this
/// file; the notes below restate only what the signatures and default bodies show.
pub trait ConnectionHandler: Send + 'static {
    /// Event type accepted by [`Self::on_behaviour_event`].
    type FromBehaviour: fmt::Debug + Send + 'static;
    /// Event type emitted by the handler: the `TCustom` parameter of the
    /// [`ConnectionHandlerEvent`] returned from [`Self::poll`], and the item returned
    /// from [`Self::poll_close`].
    type ToBehaviour: fmt::Debug + Send + 'static;
    /// Upgrade type returned (wrapped in a [`SubstreamProtocol`]) by
    /// [`Self::listen_protocol`].
    type InboundProtocol: InboundUpgradeSend;
    /// Upgrade type carried by the [`ConnectionHandlerEvent`]s returned from
    /// [`Self::poll`].
    type OutboundProtocol: OutboundUpgradeSend;
    /// Type of the `info` value paired with the inbound upgrade; it is also the `info`
    /// type of the inbound [`ConnectionEvent`] payloads.
    type InboundOpenInfo: Send + 'static;
    /// Type of the `info` value paired with the outbound upgrade; it is also the `info`
    /// type of the outbound [`ConnectionEvent`] payloads.
    type OutboundOpenInfo: Send + 'static;

    /// Returns the [`SubstreamProtocol`] (upgrade, info and timeout) for the inbound side.
    fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo>;

    /// Keep-alive flag of the handler; the provided implementation returns `false`.
    ///
    /// NOTE(review): how the caller reacts to this value is not visible in this file.
    fn connection_keep_alive(&self) -> bool {
        false
    }

    /// Polls the handler for its next [`ConnectionHandlerEvent`].
    fn poll(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<
        ConnectionHandlerEvent<Self::OutboundProtocol, Self::OutboundOpenInfo, Self::ToBehaviour>,
    >;

    /// Polls the handler for a final `ToBehaviour` event.
    ///
    /// The provided implementation ignores the context and immediately returns
    /// `Poll::Ready(None)`.
    ///
    /// NOTE(review): presumably called while the handler is shutting down, until it
    /// yields `None` — confirm against the caller.
    fn poll_close(&mut self, _: &mut Context<'_>) -> Poll<Option<Self::ToBehaviour>> {
        Poll::Ready(None)
    }

    /// Wraps `self` in a [`MapInEvent`] built from `map`, a function from `&TNewIn` to an
    /// optional `&Self::FromBehaviour`.
    fn map_in_event<TNewIn, TMap>(self, map: TMap) -> MapInEvent<Self, TNewIn, TMap>
    where
        Self: Sized,
        TMap: Fn(&TNewIn) -> Option<&Self::FromBehaviour>,
    {
        MapInEvent::new(self, map)
    }

    /// Wraps `self` in a [`MapOutEvent`] built from `map`, a function from
    /// `Self::ToBehaviour` to `TNewOut`.
    fn map_out_event<TMap, TNewOut>(self, map: TMap) -> MapOutEvent<Self, TMap>
    where
        Self: Sized,
        TMap: FnMut(Self::ToBehaviour) -> TNewOut,
    {
        MapOutEvent::new(self, map)
    }

    /// Combines `self` and `other` into a [`ConnectionHandlerSelect`].
    fn select<TProto2>(self, other: TProto2) -> ConnectionHandlerSelect<Self, TProto2>
    where
        Self: Sized,
    {
        ConnectionHandlerSelect::new(self, other)
    }

    /// Hands a `FromBehaviour` event to the handler.
    fn on_behaviour_event(&mut self, _event: Self::FromBehaviour);

    /// Hands a [`ConnectionEvent`] to the handler.
    fn on_connection_event(
        &mut self,
        event: ConnectionEvent<
            Self::InboundProtocol,
            Self::OutboundProtocol,
            Self::InboundOpenInfo,
            Self::OutboundOpenInfo,
        >,
    );
}
223
/// Events delivered to [`ConnectionHandler::on_connection_event`].
///
/// `IP`/`OP` are the inbound/outbound upgrade types, `IOI`/`OOI` the matching `info`
/// types; `'a` is the lifetime of the data borrowed by the address and protocol variants.
#[non_exhaustive]
pub enum ConnectionEvent<'a, IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI = (), OOI = ()> {
    /// Inbound side: carries the output of the inbound upgrade and its `info` value.
    FullyNegotiatedInbound(FullyNegotiatedInbound<IP, IOI>),
    /// Outbound side: carries the output of the outbound upgrade and its `info` value.
    FullyNegotiatedOutbound(FullyNegotiatedOutbound<OP, OOI>),
    /// Carries a borrowed new address.
    AddressChange(AddressChange<'a>),
    /// Outbound side: carries a [`StreamUpgradeError`] and the `info` value.
    DialUpgradeError(DialUpgradeError<OOI, OP>),
    /// Inbound side: carries the inbound upgrade's error and the `info` value.
    ListenUpgradeError(ListenUpgradeError<IOI, IP>),
    /// Carries a [`ProtocolsChange`]; presumably for the locally supported protocols.
    LocalProtocolsChange(ProtocolsChange<'a>),
    /// Carries a [`ProtocolsChange`]; presumably for the protocols supported by the remote.
    RemoteProtocolsChange(ProtocolsChange<'a>),
}
243
244impl<IP, OP, IOI, OOI> fmt::Debug for ConnectionEvent<'_, IP, OP, IOI, OOI>
245where
246 IP: InboundUpgradeSend + fmt::Debug,
247 IP::Output: fmt::Debug,
248 IP::Error: fmt::Debug,
249 OP: OutboundUpgradeSend + fmt::Debug,
250 OP::Output: fmt::Debug,
251 OP::Error: fmt::Debug,
252 IOI: fmt::Debug,
253 OOI: fmt::Debug,
254{
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 match self {
257 ConnectionEvent::FullyNegotiatedInbound(v) => {
258 f.debug_tuple("FullyNegotiatedInbound").field(v).finish()
259 }
260 ConnectionEvent::FullyNegotiatedOutbound(v) => {
261 f.debug_tuple("FullyNegotiatedOutbound").field(v).finish()
262 }
263 ConnectionEvent::AddressChange(v) => f.debug_tuple("AddressChange").field(v).finish(),
264 ConnectionEvent::DialUpgradeError(v) => {
265 f.debug_tuple("DialUpgradeError").field(v).finish()
266 }
267 ConnectionEvent::ListenUpgradeError(v) => {
268 f.debug_tuple("ListenUpgradeError").field(v).finish()
269 }
270 ConnectionEvent::LocalProtocolsChange(v) => {
271 f.debug_tuple("LocalProtocolsChange").field(v).finish()
272 }
273 ConnectionEvent::RemoteProtocolsChange(v) => {
274 f.debug_tuple("RemoteProtocolsChange").field(v).finish()
275 }
276 }
277 }
278}
279
280impl<IP: InboundUpgradeSend, OP: OutboundUpgradeSend, IOI, OOI>
281 ConnectionEvent<'_, IP, OP, IOI, OOI>
282{
283 pub fn is_outbound(&self) -> bool {
285 match self {
286 ConnectionEvent::DialUpgradeError(_) | ConnectionEvent::FullyNegotiatedOutbound(_) => {
287 true
288 }
289 ConnectionEvent::FullyNegotiatedInbound(_)
290 | ConnectionEvent::AddressChange(_)
291 | ConnectionEvent::LocalProtocolsChange(_)
292 | ConnectionEvent::RemoteProtocolsChange(_)
293 | ConnectionEvent::ListenUpgradeError(_) => false,
294 }
295 }
296
297 pub fn is_inbound(&self) -> bool {
299 match self {
300 ConnectionEvent::FullyNegotiatedInbound(_) | ConnectionEvent::ListenUpgradeError(_) => {
301 true
302 }
303 ConnectionEvent::FullyNegotiatedOutbound(_)
304 | ConnectionEvent::AddressChange(_)
305 | ConnectionEvent::LocalProtocolsChange(_)
306 | ConnectionEvent::RemoteProtocolsChange(_)
307 | ConnectionEvent::DialUpgradeError(_) => false,
308 }
309 }
310}
311
/// Payload of [`ConnectionEvent::FullyNegotiatedInbound`].
#[derive(Debug)]
pub struct FullyNegotiatedInbound<IP: InboundUpgradeSend, IOI = ()> {
    /// Output of the inbound upgrade `IP`.
    pub protocol: IP::Output,
    /// The `info` value of type `IOI` that accompanies the inbound upgrade.
    pub info: IOI,
}
325
/// Payload of [`ConnectionEvent::FullyNegotiatedOutbound`].
#[derive(Debug)]
pub struct FullyNegotiatedOutbound<OP: OutboundUpgradeSend, OOI = ()> {
    /// Output of the outbound upgrade `OP`.
    pub protocol: OP::Output,
    /// The `info` value of type `OOI` that accompanies the outbound upgrade.
    pub info: OOI,
}
336
/// Payload of [`ConnectionEvent::AddressChange`].
#[derive(Debug)]
pub struct AddressChange<'a> {
    /// The new address, borrowed for `'a`.
    ///
    /// NOTE(review): presumably the remote's address on this connection — confirm.
    pub new_address: &'a Multiaddr,
}
343
/// A change to a set of protocols: some were added or some were removed.
///
/// Both variants wrap an iterator that borrows the changed [`StreamProtocol`]s for `'a`.
#[derive(Debug, Clone)]
pub enum ProtocolsChange<'a> {
    /// Protocols that were added.
    Added(ProtocolsAdded<'a>),
    /// Protocols that were removed.
    Removed(ProtocolsRemoved<'a>),
}
351
352impl<'a> ProtocolsChange<'a> {
353 pub(crate) fn from_initial_protocols<'b, T: AsRef<str> + 'b>(
355 new_protocols: impl IntoIterator<Item = &'b T>,
356 buffer: &'a mut Vec<StreamProtocol>,
357 ) -> Self {
358 buffer.clear();
359 buffer.extend(
360 new_protocols
361 .into_iter()
362 .filter_map(|i| StreamProtocol::try_from_owned(i.as_ref().to_owned()).ok()),
363 );
364
365 ProtocolsChange::Added(ProtocolsAdded {
366 protocols: buffer.iter(),
367 })
368 }
369
370 pub(crate) fn add(
374 existing_protocols: &HashSet<StreamProtocol>,
375 to_add: HashSet<StreamProtocol>,
376 buffer: &'a mut Vec<StreamProtocol>,
377 ) -> Option<Self> {
378 buffer.clear();
379 buffer.extend(
380 to_add
381 .into_iter()
382 .filter(|i| !existing_protocols.contains(i)),
383 );
384
385 if buffer.is_empty() {
386 return None;
387 }
388
389 Some(Self::Added(ProtocolsAdded {
390 protocols: buffer.iter(),
391 }))
392 }
393
394 pub(crate) fn remove(
400 existing_protocols: &mut HashSet<StreamProtocol>,
401 to_remove: HashSet<StreamProtocol>,
402 buffer: &'a mut Vec<StreamProtocol>,
403 ) -> Option<Self> {
404 buffer.clear();
405 buffer.extend(
406 to_remove
407 .into_iter()
408 .filter_map(|i| existing_protocols.take(&i)),
409 );
410
411 if buffer.is_empty() {
412 return None;
413 }
414
415 Some(Self::Removed(ProtocolsRemoved {
416 protocols: buffer.iter(),
417 }))
418 }
419
420 pub(crate) fn from_full_sets<T: AsRef<str>>(
423 existing_protocols: &mut HashMap<AsStrHashEq<T>, bool>,
424 new_protocols: impl IntoIterator<Item = T>,
425 buffer: &'a mut Vec<StreamProtocol>,
426 ) -> SmallVec<[Self; 2]> {
427 buffer.clear();
428
429 for v in existing_protocols.values_mut() {
431 *v = false;
432 }
433
434 let mut new_protocol_count = 0; for new_protocol in new_protocols {
436 existing_protocols
437 .entry(AsStrHashEq(new_protocol))
438 .and_modify(|v| *v = true) .or_insert_with_key(|k| {
440 buffer.extend(StreamProtocol::try_from_owned(k.0.as_ref().to_owned()).ok());
442 true
443 });
444 new_protocol_count += 1;
445 }
446
447 if new_protocol_count == existing_protocols.len() && buffer.is_empty() {
448 return SmallVec::new();
449 }
450
451 let num_new_protocols = buffer.len();
452 existing_protocols.retain(|p, &mut is_supported| {
456 if !is_supported {
457 buffer.extend(StreamProtocol::try_from_owned(p.0.as_ref().to_owned()).ok());
458 }
459
460 is_supported
461 });
462
463 let (added, removed) = buffer.split_at(num_new_protocols);
464 let mut changes = SmallVec::new();
465 if !added.is_empty() {
466 changes.push(ProtocolsChange::Added(ProtocolsAdded {
467 protocols: added.iter(),
468 }));
469 }
470 if !removed.is_empty() {
471 changes.push(ProtocolsChange::Removed(ProtocolsRemoved {
472 protocols: removed.iter(),
473 }));
474 }
475 changes
476 }
477}
478
/// Iterator over the protocols of a [`ProtocolsChange::Added`] change.
#[derive(Debug, Clone)]
pub struct ProtocolsAdded<'a> {
    /// Slice iterator over the added protocols.
    pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
}
484
/// Iterator over the protocols of a [`ProtocolsChange::Removed`] change.
#[derive(Debug, Clone)]
pub struct ProtocolsRemoved<'a> {
    /// Slice iterator over the removed protocols.
    pub(crate) protocols: slice::Iter<'a, StreamProtocol>,
}
490
491impl<'a> Iterator for ProtocolsAdded<'a> {
492 type Item = &'a StreamProtocol;
493 fn next(&mut self) -> Option<Self::Item> {
494 self.protocols.next()
495 }
496}
497
498impl<'a> Iterator for ProtocolsRemoved<'a> {
499 type Item = &'a StreamProtocol;
500 fn next(&mut self) -> Option<Self::Item> {
501 self.protocols.next()
502 }
503}
504
/// Payload of [`ConnectionEvent::DialUpgradeError`].
#[derive(Debug)]
pub struct DialUpgradeError<OOI, OP: OutboundUpgradeSend> {
    /// The `info` value of type `OOI` that accompanied the outbound upgrade.
    pub info: OOI,
    /// The failure, as a [`StreamUpgradeError`] over the outbound upgrade's error type.
    pub error: StreamUpgradeError<OP::Error>,
}
512
/// Payload of [`ConnectionEvent::ListenUpgradeError`].
#[derive(Debug)]
pub struct ListenUpgradeError<IOI, IP: InboundUpgradeSend> {
    /// The `info` value of type `IOI` that accompanied the inbound upgrade.
    pub info: IOI,
    /// The error produced by the inbound upgrade `IP`.
    pub error: IP::Error,
}
520
/// An upgrade bundled with an `info` value and a timeout.
///
/// NOTE(review): what exactly the timeout bounds (presumably the upgrade negotiation) is
/// decided by the code that consumes this type, outside this file.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct SubstreamProtocol<TUpgrade, TInfo = ()> {
    upgrade: TUpgrade,
    info: TInfo,
    timeout: Duration,
}

impl<TUpgrade, TInfo> SubstreamProtocol<TUpgrade, TInfo> {
    /// Bundles `upgrade` and `info` with the default timeout of 10 seconds.
    pub fn new(upgrade: TUpgrade, info: TInfo) -> Self {
        let timeout = Duration::from_secs(10);
        Self {
            upgrade,
            info,
            timeout,
        }
    }

    /// Replaces the upgrade with `f(upgrade)`; `info` and the timeout are carried over.
    pub fn map_upgrade<U, F>(self, f: F) -> SubstreamProtocol<U, TInfo>
    where
        F: FnOnce(TUpgrade) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade: f(upgrade),
            info,
            timeout,
        }
    }

    /// Replaces the info with `f(info)`; the upgrade and the timeout are carried over.
    pub fn map_info<U, F>(self, f: F) -> SubstreamProtocol<TUpgrade, U>
    where
        F: FnOnce(TInfo) -> U,
    {
        let Self {
            upgrade,
            info,
            timeout,
        } = self;
        SubstreamProtocol {
            upgrade,
            info: f(info),
            timeout,
        }
    }

    /// Returns the same protocol with its timeout set to `timeout`.
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Borrows the upgrade.
    pub fn upgrade(&self) -> &TUpgrade {
        let Self { upgrade, .. } = self;
        upgrade
    }

    /// Borrows the info value.
    pub fn info(&self) -> &TInfo {
        let Self { info, .. } = self;
        info
    }

    /// Borrows the timeout.
    pub fn timeout(&self) -> &Duration {
        let Self { timeout, .. } = self;
        timeout
    }

    /// Splits the value into `(upgrade, info)`, discarding the timeout.
    pub fn into_upgrade(self) -> (TUpgrade, TInfo) {
        let Self { upgrade, info, .. } = self;
        (upgrade, info)
    }
}
596
/// Event returned from [`ConnectionHandler::poll`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom> {
    /// Request for an outbound substream.
    OutboundSubstreamRequest {
        /// The upgrade, info value and timeout for the requested substream.
        protocol: SubstreamProtocol<TConnectionUpgrade, TOutboundOpenInfo>,
    },
    /// Reports protocols as added or removed; presumably those supported by the remote.
    ReportRemoteProtocols(ProtocolSupport),

    /// Carries a custom event; in [`ConnectionHandler::poll`] this is the handler's
    /// `ToBehaviour` type.
    NotifyBehaviour(TCustom),
}
612
/// Payload of [`ConnectionHandlerEvent::ReportRemoteProtocols`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolSupport {
    /// The given set of protocols was added.
    Added(HashSet<StreamProtocol>),
    /// The given set of protocols was removed.
    Removed(HashSet<StreamProtocol>),
}
620
621impl<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
623 ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, TCustom>
624{
625 pub fn map_outbound_open_info<F, I>(
628 self,
629 map: F,
630 ) -> ConnectionHandlerEvent<TConnectionUpgrade, I, TCustom>
631 where
632 F: FnOnce(TOutboundOpenInfo) -> I,
633 {
634 match self {
635 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
636 ConnectionHandlerEvent::OutboundSubstreamRequest {
637 protocol: protocol.map_info(map),
638 }
639 }
640 ConnectionHandlerEvent::NotifyBehaviour(val) => {
641 ConnectionHandlerEvent::NotifyBehaviour(val)
642 }
643 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
644 ConnectionHandlerEvent::ReportRemoteProtocols(support)
645 }
646 }
647 }
648
649 pub fn map_protocol<F, I>(self, map: F) -> ConnectionHandlerEvent<I, TOutboundOpenInfo, TCustom>
652 where
653 F: FnOnce(TConnectionUpgrade) -> I,
654 {
655 match self {
656 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
657 ConnectionHandlerEvent::OutboundSubstreamRequest {
658 protocol: protocol.map_upgrade(map),
659 }
660 }
661 ConnectionHandlerEvent::NotifyBehaviour(val) => {
662 ConnectionHandlerEvent::NotifyBehaviour(val)
663 }
664 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
665 ConnectionHandlerEvent::ReportRemoteProtocols(support)
666 }
667 }
668 }
669
670 pub fn map_custom<F, I>(
672 self,
673 map: F,
674 ) -> ConnectionHandlerEvent<TConnectionUpgrade, TOutboundOpenInfo, I>
675 where
676 F: FnOnce(TCustom) -> I,
677 {
678 match self {
679 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol } => {
680 ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }
681 }
682 ConnectionHandlerEvent::NotifyBehaviour(val) => {
683 ConnectionHandlerEvent::NotifyBehaviour(map(val))
684 }
685 ConnectionHandlerEvent::ReportRemoteProtocols(support) => {
686 ConnectionHandlerEvent::ReportRemoteProtocols(support)
687 }
688 }
689 }
690}
691
/// Error of a failed stream upgrade.
#[derive(Debug)]
pub enum StreamUpgradeError<TUpgrErr> {
    /// Displayed as "Timeout error while opening a substream".
    Timeout,
    /// Wraps the upgrade's own error type.
    Apply(TUpgrErr),
    /// Displayed as "no protocols could be agreed upon".
    NegotiationFailed,
    /// Wraps an I/O error.
    Io(io::Error),
}

impl<TUpgrErr> StreamUpgradeError<TUpgrErr> {
    /// Converts the error type of the `Apply` variant with `f`.
    ///
    /// The other variants are rebuilt unchanged and `f` is not called for them.
    pub fn map_upgrade_err<F, E>(self, f: F) -> StreamUpgradeError<E>
    where
        F: FnOnce(TUpgrErr) -> E,
    {
        match self {
            Self::Apply(inner) => StreamUpgradeError::Apply(f(inner)),
            Self::Io(io_err) => StreamUpgradeError::Io(io_err),
            Self::Timeout => StreamUpgradeError::Timeout,
            Self::NegotiationFailed => StreamUpgradeError::NegotiationFailed,
        }
    }
}
719
720impl<TUpgrErr> fmt::Display for StreamUpgradeError<TUpgrErr>
721where
722 TUpgrErr: error::Error + 'static,
723{
724 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
725 match self {
726 StreamUpgradeError::Timeout => {
727 write!(f, "Timeout error while opening a substream")
728 }
729 StreamUpgradeError::Apply(err) => {
730 write!(f, "Apply: ")?;
731 crate::print_error_chain(f, err)
732 }
733 StreamUpgradeError::NegotiationFailed => {
734 write!(f, "no protocols could be agreed upon")
735 }
736 StreamUpgradeError::Io(e) => {
737 write!(f, "IO error: ")?;
738 crate::print_error_chain(f, e)
739 }
740 }
741 }
742}
743
impl<TUpgrErr> error::Error for StreamUpgradeError<TUpgrErr>
where
    TUpgrErr: error::Error + 'static,
{
    /// Always `None`, including for the `Apply` and `Io` variants that wrap another error.
    ///
    /// NOTE(review): presumably deliberate — the `Display` impl of this type already
    /// passes the wrapped error to `crate::print_error_chain`, so also exposing it here
    /// could make reporters that walk `source()` print it twice. Confirm before changing.
    fn source(&self) -> Option<&(dyn error::Error + 'static)> {
        None
    }
}
752
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a protocol set from a whitespace-separated list of names; every name is
    /// prefixed with `/` before conversion. An empty string yields the empty set.
    fn protocol_set_of(s: &'static str) -> HashSet<StreamProtocol> {
        s.split_whitespace()
            .map(|p| StreamProtocol::try_from_owned(format!("/{p}")).unwrap())
            .collect()
    }

    /// Runs `ProtocolsChange::remove` and returns the set of protocols it reported as
    /// removed (empty when it returned `None`).
    fn test_remove(
        existing: &mut HashSet<StreamProtocol>,
        to_remove: HashSet<StreamProtocol>,
    ) -> HashSet<StreamProtocol> {
        ProtocolsChange::remove(existing, to_remove, &mut Vec::new())
            .into_iter()
            .flat_map(|c| match c {
                ProtocolsChange::Added(_) => panic!("unexpected added"),
                ProtocolsChange::Removed(r) => r.cloned(),
            })
            .collect::<HashSet<_>>()
    }

    #[test]
    fn test_protocol_remove_subset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("c"));
        assert_eq!(change, protocol_set_of("a b"));
    }

    #[test]
    fn test_protocol_remove_all() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_superset() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("a b c d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of("a b c"));
    }

    #[test]
    fn test_protocol_remove_none() {
        let mut existing = protocol_set_of("a b c");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of("a b c"));
        assert_eq!(change, protocol_set_of(""));
    }

    #[test]
    fn test_protocol_remove_none_from_empty() {
        let mut existing = protocol_set_of("");
        let to_remove = protocol_set_of("d");

        let change = test_remove(&mut existing, to_remove);

        assert_eq!(existing, protocol_set_of(""));
        assert_eq!(change, protocol_set_of(""));
    }

    /// Runs `ProtocolsChange::from_full_sets` with `existing` as the known set and `new`
    /// as the full new set.
    ///
    /// Returns `[removed, added]` — note the order: removed first.
    fn test_from_full_sets(
        existing: HashSet<StreamProtocol>,
        new: HashSet<StreamProtocol>,
    ) -> [HashSet<StreamProtocol>; 2] {
        let mut buffer = Vec::new();
        let mut existing = existing
            .iter()
            .map(|p| (AsStrHashEq(p.as_ref()), true))
            .collect::<HashMap<_, _>>();

        let changes = ProtocolsChange::from_full_sets(
            &mut existing,
            new.iter().map(AsRef::as_ref),
            &mut buffer,
        );

        let mut added_changes = HashSet::new();
        let mut removed_changes = HashSet::new();

        for change in changes {
            match change {
                ProtocolsChange::Added(a) => {
                    added_changes.extend(a.cloned());
                }
                ProtocolsChange::Removed(r) => {
                    removed_changes.extend(r.cloned());
                }
            }
        }

        [removed_changes, added_changes]
    }

    // Renamed from `test_from_full_stes_subset`: the misspelling kept this test out of
    // runs filtered on `from_full_sets`.
    #[test]
    fn test_from_full_sets_subset() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("a b");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of("c"));
    }

    #[test]
    fn test_from_full_sets_superset() {
        let existing = protocol_set_of("a b");
        let new = protocol_set_of("a b c");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("c"));
        assert_eq!(removed_changes, protocol_set_of(""));
    }

    #[test]
    fn test_from_full_sets_intersection() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("b c d");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d"));
        assert_eq!(removed_changes, protocol_set_of("a"));
    }

    #[test]
    fn test_from_full_sets_disjoint() {
        let existing = protocol_set_of("a b c");
        let new = protocol_set_of("d e f");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of("d e f"));
        assert_eq!(removed_changes, protocol_set_of("a b c"));
    }

    #[test]
    fn test_from_full_sets_empty() {
        let existing = protocol_set_of("");
        let new = protocol_set_of("");

        let [removed_changes, added_changes] = test_from_full_sets(existing, new);

        assert_eq!(added_changes, protocol_set_of(""));
        assert_eq!(removed_changes, protocol_set_of(""));
    }
}