dstar_gateway_server/client_pool/
pool.rs1use std::collections::{HashMap, HashSet};
9use std::marker::PhantomData;
10use std::net::SocketAddr;
11use std::time::Instant;
12
13use tokio::sync::Mutex;
14
15use dstar_gateway_core::session::client::Protocol;
16use dstar_gateway_core::types::Module;
17
18use crate::reflector::AccessPolicy;
19
20use super::handle::ClientHandle;
21
/// Send-failure count at which [`ClientPool::mark_unhealthy`] starts
/// reporting [`UnhealthyOutcome::ShouldEvict`].
///
/// The comparison is `count >= DEFAULT_UNHEALTHY_THRESHOLD`, so the
/// failure that brings the counter to this value is the first one to
/// request eviction.
pub const DEFAULT_UNHEALTHY_THRESHOLD: u32 = 5;
27
/// Verdict returned by [`ClientPool::mark_unhealthy`] after it has
/// recorded one more send failure for a peer.
///
/// The pool only reports the verdict; it does not remove the peer itself.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnhealthyOutcome {
    /// The failure counter is still below [`DEFAULT_UNHEALTHY_THRESHOLD`].
    StillHealthy {
        /// Counter value after the increment, or `0` when the peer is not
        /// in the pool.
        failure_count: u32,
    },
    /// The failure counter has reached [`DEFAULT_UNHEALTHY_THRESHOLD`].
    ShouldEvict {
        /// Counter value after the increment.
        failure_count: u32,
    },
}
50
/// Async registry of connected clients for protocol `P`, keyed by the
/// peer's socket address.
///
/// State is split across two independently locked maps: the handles
/// themselves and a module-to-peers reverse index.
#[derive(Debug)]
pub struct ClientPool<P: Protocol> {
    /// Per-peer state, keyed by socket address.
    clients: Mutex<HashMap<SocketAddr, ClientHandle<P>>>,
    /// Reverse index from a module to the peers attached to it; updated by
    /// `insert`, `remove` and `set_module`. Sets are removed once they
    /// become empty.
    by_module: Mutex<HashMap<Module, HashSet<SocketAddr>>>,
    /// Marker tying the pool to `P` without storing a value of that type.
    _protocol: PhantomData<fn() -> P>,
}
63
/// `Default` yields the same empty pool as [`ClientPool::new`].
impl<P: Protocol> Default for ClientPool<P> {
    /// Delegates to [`ClientPool::new`].
    fn default() -> Self {
        Self::new()
    }
}
69
70impl<P: Protocol> ClientPool<P> {
71 #[must_use]
73 pub fn new() -> Self {
74 Self {
75 clients: Mutex::new(HashMap::new()),
76 by_module: Mutex::new(HashMap::new()),
77 _protocol: PhantomData,
78 }
79 }
80
81 pub async fn insert(&self, peer: SocketAddr, handle: ClientHandle<P>) {
94 let module = handle.module;
95 let mut clients = self.clients.lock().await;
96 drop(clients.insert(peer, handle));
99 drop(clients);
100 if let Some(module) = module {
101 let mut index = self.by_module.lock().await;
102 let _ = index.entry(module).or_default().insert(peer);
103 }
104 }
105
106 pub async fn remove(&self, peer: &SocketAddr) -> Option<ClientHandle<P>> {
115 let mut clients = self.clients.lock().await;
116 let handle = clients.remove(peer)?;
117 drop(clients);
118 if let Some(module) = handle.module {
119 let mut index = self.by_module.lock().await;
120 if let Some(set) = index.get_mut(&module) {
121 let _ = set.remove(peer);
122 if set.is_empty() {
123 drop(index.remove(&module));
124 }
125 }
126 }
127 Some(handle)
128 }
129
130 pub async fn set_module(&self, peer: &SocketAddr, module: Module) {
140 let mut clients = self.clients.lock().await;
141 let Some(handle) = clients.get_mut(peer) else {
142 return;
143 };
144 let previous_module = handle.module;
145 handle.module = Some(module);
146 drop(clients);
147 let mut index = self.by_module.lock().await;
148 if let Some(prev) = previous_module
149 && prev != module
150 && let Some(set) = index.get_mut(&prev)
151 {
152 let _ = set.remove(peer);
153 if set.is_empty() {
154 drop(index.remove(&prev));
155 }
156 }
157 let _ = index.entry(module).or_default().insert(*peer);
158 }
159
160 pub async fn members_of_module(&self, module: Module) -> Vec<SocketAddr> {
168 let index = self.by_module.lock().await;
169 index
170 .get(&module)
171 .map(|set| set.iter().copied().collect())
172 .unwrap_or_default()
173 }
174
175 pub async fn len(&self) -> usize {
182 self.clients.lock().await.len()
183 }
184
185 pub async fn is_empty(&self) -> bool {
191 self.clients.lock().await.is_empty()
192 }
193
194 pub async fn contains(&self, peer: &SocketAddr) -> bool {
200 self.clients.lock().await.contains_key(peer)
201 }
202
203 pub async fn mark_unhealthy(&self, peer: &SocketAddr) -> UnhealthyOutcome {
223 let mut clients = self.clients.lock().await;
224 let count = match clients.get_mut(peer) {
225 Some(handle) => {
226 handle.send_failure_count = handle.send_failure_count.saturating_add(1);
227 handle.send_failure_count
228 }
229 None => 0,
230 };
231 drop(clients);
232 if count >= DEFAULT_UNHEALTHY_THRESHOLD {
233 UnhealthyOutcome::ShouldEvict {
234 failure_count: count,
235 }
236 } else {
237 UnhealthyOutcome::StillHealthy {
238 failure_count: count,
239 }
240 }
241 }
242
243 pub async fn record_last_heard(&self, peer: &SocketAddr, now: Instant) {
249 let mut clients = self.clients.lock().await;
250 if let Some(handle) = clients.get_mut(peer) {
251 handle.last_heard = now;
252 }
253 }
254
255 pub async fn module_of(&self, peer: &SocketAddr) -> Option<Module> {
261 let clients = self.clients.lock().await;
262 clients.get(peer).and_then(|handle| handle.module)
263 }
264
265 pub async fn access_of(&self, peer: &SocketAddr) -> Option<AccessPolicy> {
275 let clients = self.clients.lock().await;
276 clients.get(peer).map(|handle| handle.access)
277 }
278
279 pub async fn try_consume_tx_token(&self, peer: &SocketAddr, now: Instant) -> bool {
297 let mut clients = self.clients.lock().await;
298 match clients.get_mut(peer) {
299 Some(handle) => handle.tx_budget.try_consume(now, 1),
300 None => false,
301 }
302 }
303
304 pub async fn with_handle_mut<F, R>(&self, peer: &SocketAddr, f: F) -> Option<R>
319 where
320 F: FnOnce(&mut ClientHandle<P>) -> R,
321 {
322 let mut clients = self.clients.lock().await;
323 clients.get_mut(peer).map(f)
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::{
        ClientHandle, ClientPool, DEFAULT_UNHEALTHY_THRESHOLD, Instant, Module, SocketAddr,
        UnhealthyOutcome,
    };
    use crate::reflector::AccessPolicy;
    use dstar_gateway_core::ServerSessionCore;
    use dstar_gateway_core::session::client::DExtra;
    use dstar_gateway_core::types::ProtocolKind;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address on `port`; distinct ports stand in for distinct peers.
    fn peer(port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), port)
    }

    /// Builds a read-write `DExtra` handle for the peer on `port`.
    fn fresh_handle(port: u16) -> ClientHandle<DExtra> {
        let core = ServerSessionCore::new(ProtocolKind::DExtra, peer(port), Module::C);
        ClientHandle::new(core, AccessPolicy::ReadWrite, Instant::now())
    }

    #[tokio::test]
    async fn insert_and_contains() {
        let pool = ClientPool::<DExtra>::new();
        assert_eq!(pool.len().await, 0);
        assert!(!pool.contains(&peer(30001)).await);

        pool.insert(peer(30001), fresh_handle(30001)).await;
        assert_eq!(pool.len().await, 1);
        assert!(pool.contains(&peer(30001)).await);
    }

    #[tokio::test]
    async fn remove_returns_handle() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        let removed = pool.remove(&peer(30001)).await;
        assert!(removed.is_some());
        assert_eq!(pool.len().await, 0);
        // A second removal finds nothing left to return.
        let removed_again = pool.remove(&peer(30001)).await;
        assert!(removed_again.is_none());
    }

    #[tokio::test]
    async fn set_module_populates_reverse_index() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        pool.insert(peer(30002), fresh_handle(30002)).await;
        pool.set_module(&peer(30001), Module::C).await;
        pool.set_module(&peer(30002), Module::C).await;

        let members = pool.members_of_module(Module::C).await;
        assert_eq!(members.len(), 2);
        assert!(members.contains(&peer(30001)));
        assert!(members.contains(&peer(30002)));
    }

    #[tokio::test]
    async fn set_module_moves_peer_between_modules() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        pool.set_module(&peer(30001), Module::C).await;
        pool.set_module(&peer(30001), Module::D).await;

        assert!(pool.members_of_module(Module::C).await.is_empty());
        let d_members = pool.members_of_module(Module::D).await;
        assert_eq!(d_members, vec![peer(30001)]);
    }

    #[tokio::test]
    async fn members_of_empty_module_is_empty() {
        let pool = ClientPool::<DExtra>::new();
        assert!(pool.members_of_module(Module::Z).await.is_empty());
    }

    #[tokio::test]
    async fn mark_unhealthy_increments_counter() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        assert_eq!(
            pool.mark_unhealthy(&peer(30001)).await,
            UnhealthyOutcome::StillHealthy { failure_count: 1 }
        );
        assert_eq!(
            pool.mark_unhealthy(&peer(30001)).await,
            UnhealthyOutcome::StillHealthy { failure_count: 2 }
        );
        assert_eq!(
            pool.mark_unhealthy(&peer(30001)).await,
            UnhealthyOutcome::StillHealthy { failure_count: 3 }
        );
    }

    #[tokio::test]
    async fn mark_unhealthy_missing_peer_is_zero() {
        let pool = ClientPool::<DExtra>::new();
        assert_eq!(
            pool.mark_unhealthy(&peer(30001)).await,
            UnhealthyOutcome::StillHealthy { failure_count: 0 }
        );
    }

    #[tokio::test]
    async fn mark_unhealthy_threshold_triggers_eviction() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        // Every failure short of the threshold keeps the peer healthy.
        for expected in 1..DEFAULT_UNHEALTHY_THRESHOLD {
            assert_eq!(
                pool.mark_unhealthy(&peer(30001)).await,
                UnhealthyOutcome::StillHealthy {
                    failure_count: expected
                }
            );
        }
        // The failure that reaches the threshold requests eviction.
        assert_eq!(
            pool.mark_unhealthy(&peer(30001)).await,
            UnhealthyOutcome::ShouldEvict {
                failure_count: DEFAULT_UNHEALTHY_THRESHOLD
            }
        );
    }

    #[tokio::test]
    async fn record_last_heard_updates_timestamp() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        let later = Instant::now() + std::time::Duration::from_secs(5);
        pool.record_last_heard(&peer(30001), later).await;

        // Read the field back so the test fails if the write is lost.
        let recorded = pool
            .with_handle_mut(&peer(30001), |handle| handle.last_heard)
            .await;
        assert_eq!(recorded, Some(later));
    }

    #[tokio::test]
    async fn remove_clears_reverse_index() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        pool.set_module(&peer(30001), Module::C).await;
        drop(pool.remove(&peer(30001)).await);
        assert!(pool.members_of_module(Module::C).await.is_empty());
    }

    #[tokio::test]
    async fn module_of_returns_assigned_module() {
        let pool = ClientPool::<DExtra>::new();
        pool.insert(peer(30001), fresh_handle(30001)).await;
        assert!(pool.module_of(&peer(30001)).await.is_none());
        pool.set_module(&peer(30001), Module::C).await;
        assert_eq!(pool.module_of(&peer(30001)).await, Some(Module::C));
    }

    #[tokio::test]
    async fn module_of_missing_peer_is_none() {
        let pool = ClientPool::<DExtra>::new();
        assert!(pool.module_of(&peer(30001)).await.is_none());
    }

    #[tokio::test]
    async fn lookups_on_missing_peer_report_absence() {
        let pool = ClientPool::<DExtra>::new();
        assert!(pool.access_of(&peer(30001)).await.is_none());
        assert!(
            !pool
                .try_consume_tx_token(&peer(30001), Instant::now())
                .await
        );
        let visited = pool.with_handle_mut(&peer(30001), |_| ()).await;
        assert!(visited.is_none());
    }
}