1use std::{
22 collections::{HashMap, HashSet},
23 io,
24 io::ErrorKind,
25 net::SocketAddr,
26 sync::Arc,
27 task::{Context, Poll},
28};
29
30use async_trait::async_trait;
31use futures::{
32 channel::oneshot,
33 future::{BoxFuture, FutureExt, OptionFuture},
34 stream::FuturesUnordered,
35 StreamExt,
36};
37use stun::{
38 attributes::ATTR_USERNAME,
39 message::{is_message as is_stun_message, Message as STUNMessage},
40};
41use thiserror::Error;
42use tokio::{io::ReadBuf, net::UdpSocket};
43use webrtc::{
44 ice::udp_mux::{UDPMux, UDPMuxConn, UDPMuxConnParams, UDPMuxWriter},
45 util::{Conn, Error},
46};
47
48use crate::tokio::req_res_chan;
49
/// Size in bytes of the stack buffer used to receive a single UDP datagram.
const RECEIVE_MTU: usize = 8 * 1024;
51
/// A remote address that sent a STUN message the mux could not route to an
/// existing connection, together with the ufrag extracted from that message.
#[derive(Debug)]
pub(crate) struct NewAddr {
    /// Source address of the datagram.
    pub(crate) addr: SocketAddr,
    /// Ufrag parsed from the STUN `USERNAME` attribute of the datagram.
    pub(crate) ufrag: String,
}
58
59#[derive(Debug)]
61pub(crate) enum UDPMuxEvent {
62 Error(std::io::Error),
64 NewAddr(NewAddr),
66}
67
68pub(crate) struct UDPMuxNewAddr {
73 udp_sock: UdpSocket,
74
75 listen_addr: SocketAddr,
76
77 conns: HashMap<String, UDPMuxConn>,
79
80 address_map: HashMap<SocketAddr, UDPMuxConn>,
82
83 new_addrs: HashSet<SocketAddr>,
85
86 is_closed: bool,
88
89 send_buffer: Option<(Vec<u8>, SocketAddr, oneshot::Sender<Result<usize, Error>>)>,
90
91 close_futures: FuturesUnordered<BoxFuture<'static, ()>>,
92 write_future: OptionFuture<BoxFuture<'static, ()>>,
93
94 close_command: req_res_chan::Receiver<(), Result<(), Error>>,
95 get_conn_command: req_res_chan::Receiver<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
96 remove_conn_command: req_res_chan::Receiver<String, ()>,
97 registration_command: req_res_chan::Receiver<(UDPMuxConn, SocketAddr), ()>,
98 send_command: req_res_chan::Receiver<(Vec<u8>, SocketAddr), Result<usize, Error>>,
99
100 udp_mux_handle: Arc<UdpMuxHandle>,
101 udp_mux_writer_handle: Arc<UdpMuxWriterHandle>,
102}
103
104impl UDPMuxNewAddr {
105 pub(crate) fn listen_on(addr: SocketAddr) -> Result<Self, io::Error> {
106 let std_sock = std::net::UdpSocket::bind(addr)?;
107 std_sock.set_nonblocking(true)?;
108
109 let tokio_socket = UdpSocket::from_std(std_sock)?;
110 let listen_addr = tokio_socket.local_addr()?;
111
112 let (udp_mux_handle, close_command, get_conn_command, remove_conn_command) =
113 UdpMuxHandle::new();
114 let (udp_mux_writer_handle, registration_command, send_command) = UdpMuxWriterHandle::new();
115
116 Ok(Self {
117 udp_sock: tokio_socket,
118 listen_addr,
119 conns: HashMap::default(),
120 address_map: HashMap::default(),
121 new_addrs: HashSet::default(),
122 is_closed: false,
123 send_buffer: None,
124 close_futures: FuturesUnordered::default(),
125 write_future: OptionFuture::default(),
126 close_command,
127 get_conn_command,
128 remove_conn_command,
129 registration_command,
130 send_command,
131 udp_mux_handle: Arc::new(udp_mux_handle),
132 udp_mux_writer_handle: Arc::new(udp_mux_writer_handle),
133 })
134 }
135
136 pub(crate) fn listen_addr(&self) -> SocketAddr {
137 self.listen_addr
138 }
139
140 pub(crate) fn udp_mux_handle(&self) -> Arc<UdpMuxHandle> {
141 self.udp_mux_handle.clone()
142 }
143
144 fn create_muxed_conn(&self, ufrag: &str) -> Result<UDPMuxConn, Error> {
146 let local_addr = self.udp_sock.local_addr()?;
147
148 let params = UDPMuxConnParams {
149 local_addr,
150 key: ufrag.into(),
151 udp_mux: Arc::downgrade(
152 &(self.udp_mux_writer_handle.clone() as Arc<dyn UDPMuxWriter + Send + Sync>),
153 ),
154 };
155
156 Ok(UDPMuxConn::new(params))
157 }
158
159 fn conn_from_stun_message(
162 &self,
163 buffer: &[u8],
164 addr: &SocketAddr,
165 ) -> Option<Result<UDPMuxConn, ConnQueryError>> {
166 match ufrag_from_stun_message(buffer, true) {
167 Ok(ufrag) => {
168 if let Some(conn) = self.conns.get(&ufrag) {
169 let associated_addrs = conn.get_addresses();
170 if associated_addrs.is_empty() || associated_addrs.contains(addr) {
172 return Some(Ok(conn.clone()));
173 } else {
174 return Some(Err(ConnQueryError::UfragAlreadyTaken { associated_addrs }));
175 }
176 }
177 None
178 }
179 Err(e) => {
180 tracing::debug!(address=%addr, "{}", e);
181 None
182 }
183 }
184 }
185
186 pub(crate) fn poll(&mut self, cx: &mut Context) -> Poll<UDPMuxEvent> {
189 let mut recv_buf = [0u8; RECEIVE_MTU];
190
191 loop {
192 match self.send_buffer.take() {
194 None => {
195 if let Poll::Ready(Some(((buf, target), response))) =
196 self.send_command.poll_next_unpin(cx)
197 {
198 self.send_buffer = Some((buf, target, response));
199 continue;
200 }
201 }
202 Some((buf, target, response)) => {
203 match self.udp_sock.poll_send_to(cx, &buf, target) {
204 Poll::Ready(result) => {
205 let _ = response.send(result.map_err(|e| Error::Io(e.into())));
206 continue;
207 }
208 Poll::Pending => {
209 self.send_buffer = Some((buf, target, response));
210 }
211 }
212 }
213 }
214
215 if let Poll::Ready(Some(((conn, addr), response))) =
217 self.registration_command.poll_next_unpin(cx)
218 {
219 let key = conn.key();
220
221 self.address_map
222 .entry(addr)
223 .and_modify(|e| {
224 if e.key() != key {
225 e.remove_address(&addr);
226 *e = conn.clone();
227 }
228 })
229 .or_insert_with(|| conn.clone());
230
231 self.new_addrs.remove(&addr);
233
234 let _ = response.send(());
235
236 continue;
237 }
238
239 if let Poll::Ready(Some((ufrag, response))) = self.get_conn_command.poll_next_unpin(cx)
241 {
242 if self.is_closed {
243 let _ = response.send(Err(Error::ErrUseClosedNetworkConn));
244 continue;
245 }
246
247 if let Some(conn) = self.conns.get(&ufrag).cloned() {
248 let _ = response.send(Ok(Arc::new(conn)));
249 continue;
250 }
251
252 let muxed_conn = match self.create_muxed_conn(&ufrag) {
253 Ok(conn) => conn,
254 Err(e) => {
255 let _ = response.send(Err(e));
256 continue;
257 }
258 };
259 let mut close_rx = muxed_conn.close_rx();
260
261 self.close_futures.push({
262 let ufrag = ufrag.clone();
263 let udp_mux_handle = self.udp_mux_handle.clone();
264
265 Box::pin(async move {
266 let _ = close_rx.changed().await;
267 udp_mux_handle.remove_conn_by_ufrag(&ufrag).await;
268 })
269 });
270
271 self.conns.insert(ufrag, muxed_conn.clone());
272
273 let _ = response.send(Ok(Arc::new(muxed_conn) as Arc<dyn Conn + Send + Sync>));
274
275 continue;
276 }
277
278 if let Poll::Ready(Some(((), response))) = self.close_command.poll_next_unpin(cx) {
280 if self.is_closed {
281 let _ = response.send(Err(Error::ErrAlreadyClosed));
282 continue;
283 }
284
285 for (_, conn) in self.conns.drain() {
286 conn.close();
287 }
288
289 self.address_map.clear();
292
293 self.new_addrs.clear();
296
297 let _ = response.send(Ok(()));
298
299 self.is_closed = true;
300
301 continue;
302 }
303
304 if let Poll::Ready(Some((ufrag, response))) =
306 self.remove_conn_command.poll_next_unpin(cx)
307 {
308 if let Some(removed_conn) = self.conns.remove(&ufrag) {
312 for address in removed_conn.get_addresses() {
313 self.address_map.remove(&address);
314 }
315 }
316
317 let _ = response.send(());
318
319 continue;
320 }
321
322 let _ = self.close_futures.poll_next_unpin(cx);
324
325 match self.write_future.poll_unpin(cx) {
327 Poll::Ready(Some(())) => {
328 self.write_future = OptionFuture::default();
329 continue;
330 }
331 Poll::Ready(None) => {
332 let mut read = ReadBuf::new(&mut recv_buf);
334
335 match self.udp_sock.poll_recv_from(cx, &mut read) {
336 Poll::Ready(Ok(addr)) => {
337 let conn = self.address_map.get(&addr);
339
340 let conn = match conn {
341 None if is_stun_message(read.filled()) => {
345 match self.conn_from_stun_message(read.filled(), &addr) {
346 Some(Ok(s)) => Some(s),
347 Some(Err(e)) => {
348 tracing::debug!(address=%&addr, "Error when querying existing connections: {}", e);
349 continue;
350 }
351 None => None,
352 }
353 }
354 Some(s) => Some(s.to_owned()),
355 _ => None,
356 };
357
358 match conn {
359 None => {
360 if !self.new_addrs.contains(&addr) {
361 match ufrag_from_stun_message(read.filled(), false) {
362 Ok(ufrag) => {
363 tracing::trace!(
364 address=%&addr,
365 %ufrag,
366 "Notifying about new address from ufrag",
367 );
368 self.new_addrs.insert(addr);
369 return Poll::Ready(UDPMuxEvent::NewAddr(
370 NewAddr { addr, ufrag },
371 ));
372 }
373 Err(e) => {
374 tracing::debug!(
375 address=%&addr,
376 "Unknown address (non STUN packet: {})",
377 e
378 );
379 }
380 }
381 }
382 }
383 Some(conn) => {
384 let mut packet = vec![0u8; read.filled().len()];
385 packet.copy_from_slice(read.filled());
386 self.write_future = OptionFuture::from(Some(
387 async move {
388 if let Err(err) = conn.write_packet(&packet, addr).await
389 {
390 tracing::error!(
391 address=%addr,
392 "Failed to write packet: {}",
393 err,
394 );
395 }
396 }
397 .boxed(),
398 ));
399 }
400 }
401
402 continue;
403 }
404 Poll::Pending => {}
405 Poll::Ready(Err(err)) if err.kind() == ErrorKind::TimedOut => {}
406 Poll::Ready(Err(err)) if err.kind() == ErrorKind::ConnectionReset => {
407 tracing::debug!("ConnectionReset by remote client {err:?}")
408 }
409 Poll::Ready(Err(err)) => {
410 tracing::error!("Could not read udp packet: {}", err);
411 return Poll::Ready(UDPMuxEvent::Error(err));
412 }
413 }
414 }
415 Poll::Pending => {}
416 }
417
418 return Poll::Pending;
419 }
420 }
421}
422
423pub(crate) struct UdpMuxHandle {
426 close_sender: req_res_chan::Sender<(), Result<(), Error>>,
427 get_conn_sender: req_res_chan::Sender<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
428 remove_sender: req_res_chan::Sender<String, ()>,
429}
430
431impl UdpMuxHandle {
432 pub(crate) fn new() -> (
434 Self,
435 req_res_chan::Receiver<(), Result<(), Error>>,
436 req_res_chan::Receiver<String, Result<Arc<dyn Conn + Send + Sync>, Error>>,
437 req_res_chan::Receiver<String, ()>,
438 ) {
439 let (sender1, receiver1) = req_res_chan::new(1);
440 let (sender2, receiver2) = req_res_chan::new(1);
441 let (sender3, receiver3) = req_res_chan::new(1);
442
443 let this = Self {
444 close_sender: sender1,
445 get_conn_sender: sender2,
446 remove_sender: sender3,
447 };
448
449 (this, receiver1, receiver2, receiver3)
450 }
451}
452
453#[async_trait]
454impl UDPMux for UdpMuxHandle {
455 async fn close(&self) -> Result<(), Error> {
456 self.close_sender
457 .send(())
458 .await
459 .map_err(|e| Error::Io(e.into()))??;
460
461 Ok(())
462 }
463
464 async fn get_conn(self: Arc<Self>, ufrag: &str) -> Result<Arc<dyn Conn + Send + Sync>, Error> {
465 let conn = self
466 .get_conn_sender
467 .send(ufrag.to_owned())
468 .await
469 .map_err(|e| Error::Io(e.into()))??;
470
471 Ok(conn)
472 }
473
474 async fn remove_conn_by_ufrag(&self, ufrag: &str) {
475 if let Err(e) = self.remove_sender.send(ufrag.to_owned()).await {
476 tracing::debug!("Failed to send message through channel: {:?}", e);
477 }
478 }
479}
480
481pub(crate) struct UdpMuxWriterHandle {
484 registration_channel: req_res_chan::Sender<(UDPMuxConn, SocketAddr), ()>,
485 send_channel: req_res_chan::Sender<(Vec<u8>, SocketAddr), Result<usize, Error>>,
486}
487
488impl UdpMuxWriterHandle {
489 fn new() -> (
491 Self,
492 req_res_chan::Receiver<(UDPMuxConn, SocketAddr), ()>,
493 req_res_chan::Receiver<(Vec<u8>, SocketAddr), Result<usize, Error>>,
494 ) {
495 let (sender1, receiver1) = req_res_chan::new(1);
496 let (sender2, receiver2) = req_res_chan::new(1);
497
498 let this = Self {
499 registration_channel: sender1,
500 send_channel: sender2,
501 };
502
503 (this, receiver1, receiver2)
504 }
505}
506
507#[async_trait]
508impl UDPMuxWriter for UdpMuxWriterHandle {
509 async fn register_conn_for_address(&self, conn: &UDPMuxConn, addr: SocketAddr) {
510 match self
511 .registration_channel
512 .send((conn.to_owned(), addr))
513 .await
514 {
515 Ok(()) => {}
516 Err(e) => {
517 tracing::debug!("Failed to send message through channel: {:?}", e);
518 return;
519 }
520 }
521
522 tracing::debug!(address=%addr, connection=%conn.key(), "Registered address for connection");
523 }
524
525 async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> Result<usize, Error> {
526 let bytes_written = self
527 .send_channel
528 .send((buf.to_owned(), target.to_owned()))
529 .await
530 .map_err(|e| Error::Io(e.into()))??;
531
532 Ok(bytes_written)
533 }
534}
535
536fn ufrag_from_stun_message(buffer: &[u8], local_ufrag: bool) -> Result<String, Error> {
539 let (result, message) = {
540 let mut m = STUNMessage::new();
541
542 (m.unmarshal_binary(buffer), m)
543 };
544
545 if let Err(err) = result {
546 Err(Error::Other(format!("failed to handle decode ICE: {err}")))
547 } else {
548 let (attr, found) = message.attributes.get(ATTR_USERNAME);
549 if !found {
550 return Err(Error::Other("no username attribute in STUN message".into()));
551 }
552
553 match String::from_utf8(attr.value) {
554 Err(err) => Err(Error::Other(format!(
557 "failed to decode USERNAME from STUN message as UTF-8: {err}"
558 ))),
559 Ok(s) => {
560 let res = if local_ufrag {
562 s.split(':').next()
563 } else {
564 s.split(':').next_back()
565 };
566 match res {
567 Some(s) => Ok(s.to_owned()),
568 None => Err(Error::Other("can't get ufrag from username".into())),
569 }
570 }
571 }
572 }
573}
574
575#[derive(Error, Debug)]
576enum ConnQueryError {
577 #[error("ufrag is already taken (associated_addrs={associated_addrs:?})")]
578 UfragAlreadyTaken { associated_addrs: Vec<SocketAddr> },
579}