use serde::{de::DeserializeOwned, Serialize}; #[allow(unused)] use serde_json::Value; use tokio::select; use tokio::sync::watch; use tokio::sync::Mutex; use crate::connection::{Connection, SocketConnection}; use crate::error::{Error, Result}; #[cfg(feature = "rich_presence")] use crate::models::rich_presence::{ Activity, CloseActivityRequestArgs, SendActivityJoinInviteArgs, SetActivityArgs, }; use crate::models::{ commands::{Subscription, SubscriptionArgs}, message::Message, payload::Payload, Command, Event, OpCode, }; macro_rules! hollow { ($expr:expr) => {{ let ref_ = $expr.borrow(); ref_.hollow() }}; } #[derive(Debug)] enum ConnectionState { Disconnected, Connecting, Connected(T), Disconnecting, } impl Clone for ConnectionState { fn clone(&self) -> Self { match self { Self::Disconnected => Self::Disconnected, Self::Connecting => Self::Connecting, Self::Connected(arg0) => Self::Connected(arg0.clone()), Self::Disconnecting => Self::Disconnecting, } } } impl Copy for ConnectionState {} impl ConnectionState { pub fn hollow(&self) -> ConnectionState<()> { match self { ConnectionState::Disconnected => ConnectionState::Disconnected, ConnectionState::Connecting => ConnectionState::Connecting, ConnectionState::Connected(_) => ConnectionState::Connected(()), ConnectionState::Disconnecting => ConnectionState::Disconnecting, } } } macro_rules! yield_while { ($receive:expr, $pat:pat) => {{ let mut new_state: _; loop { new_state = $receive; match new_state { $pat => tokio::task::yield_now().await, _ => break, } } new_state }}; } type FullConnectionState = ConnectionState<(u64, Mutex)>; #[derive(Debug)] pub struct Client { state_sender: watch::Sender, state_receiver: watch::Receiver, update: Mutex<()>, } impl Default for Client { fn default() -> Self { let (state_sender, state_receiver) = watch::channel(ConnectionState::Disconnected); Self { state_sender, state_receiver, update: Mutex::new(()), } } } impl Client { /// Returns the client id used by the current connection, or [`None`] if the client is not [`ConnectionState::Connected`]. pub fn client_id(&self) -> Option { match *self.state_receiver.borrow() { ConnectionState::Connected((client_id, _)) => Some(client_id), _ => None, } } #[instrument(level = "debug")] async fn connect_and_handshake(client_id: u64) -> Result { debug!("Connecting"); let mut new_connection = SocketConnection::connect().await?; debug!("Performing handshake"); new_connection.handshake(client_id).await?; debug!("Handshake completed"); Ok(new_connection) } #[instrument(level = "debug")] async fn connect0(&self, client_id: u64, conn: Result) -> Result<()> { let _state_guard = self.update.lock().await; match hollow!(self.state_receiver) { state @ ConnectionState::Disconnected => panic!( "Illegal state during connection process {:?} -> {:?}", ConnectionState::<()>::Connecting, state ), ConnectionState::Connecting => match conn { Ok(conn) => { self.state_sender .send(ConnectionState::Connected((client_id, Mutex::new(conn)))) .expect("the receiver cannot be dropped without the sender!"); debug!("Connected"); Ok(()) } Err(e) => { self.state_sender .send(ConnectionState::Disconnected) .expect("the receiver cannot be dropped without the sender!"); debug!("Failed to connect and disconnected"); Err(e) } }, ConnectionState::Connected(_) => panic!("Illegal concurrent connection!"), ConnectionState::Disconnecting => { match conn { Ok(conn) => { if let Err(e) = conn.disconnect().await { error!("failed to disconnect properly: {}", e); } } Err(e) => { error!("failed connection: {}", e); } } self.state_sender .send(ConnectionState::Disconnected) .expect("the receiver cannot be dropped without the sender!"); Err(Error::ConnectionClosed) } } } #[instrument(level = "info")] pub async fn connect(&self, client_id: u64) -> Result<()> { match hollow!(self.state_receiver) { ConnectionState::Connected(_) => Ok(()), _ => { let state_guard = self.update.lock().await; match hollow!(self.state_receiver) { ConnectionState::Connected(_) => Ok(()), ConnectionState::Disconnecting => Err(Error::ConnectionClosed), ConnectionState::Connecting => { match yield_while!( hollow!(self.state_receiver), ConnectionState::Connecting ) { ConnectionState::Connected(_) => Ok(()), ConnectionState::Disconnecting => Err(Error::ConnectionClosed), ConnectionState::Disconnected => Err(Error::ConnectionClosed), ConnectionState::Connecting => unreachable!(), } } ConnectionState::Disconnected => { self.state_sender .send(ConnectionState::Connecting) .expect("the receiver cannot be dropped without the sender!"); drop(state_guard); select! { conn = Self::connect_and_handshake(client_id) => { self.connect0(client_id, conn).await } // _ = tokio::task::yield_now() if self.state_receiver.borrow().is_disconnecting() => { // self.state_sender.send(ConnectionState::Disconnected).expect("the receiver cannot be dropped without the sender!"); // Err(Error::ConnectionClosed) // } } } } } } } /// If currently connected, the function will close the connection. /// If currently connecting, the function will wait for the connection to be established and will immediately close it. /// If currently disconnecting, the function will wait for the connection to be closed. #[instrument(level = "info")] pub async fn disconnect(&self) { let _state_guard = self.update.lock().await; trace!("aquired state guard for disconnect"); match hollow!(self.state_receiver) { ConnectionState::Disconnected => {} ref state @ ConnectionState::Disconnecting => { trace!("Waiting while in disconnecting state(b)"); match yield_while!(hollow!(self.state_receiver), ConnectionState::Disconnecting) { ConnectionState::Disconnected => {} ConnectionState::Disconnecting => unreachable!(), new_state => panic!("Illegal state change {:?} -> {:?}", state, new_state), } } ConnectionState::Connecting => { self.state_sender .send(ConnectionState::Disconnecting) .expect("the receiver cannot be dropped without the sender!"); } state @ ConnectionState::Connected(()) => { trace!("Sending disconnecting state"); let s = self .state_sender .send_replace(ConnectionState::Disconnecting); trace!("Sent disconnecting state"); match s { ConnectionState::Connected(conn) => { match conn.1.into_inner().disconnect().await { Err(e) => { error!("failed to disconnect properly: {}", e); } _ => self .state_sender .send(ConnectionState::Disconnected) .expect("the receiver cannot be dropped without the sender!"), } } new_state @ ConnectionState::Connecting => { panic!("Illegal state change {:?} -> {:?}", state, new_state); } state @ ConnectionState::Disconnecting => { trace!("Waiting while in disconnecting state(b)"); match yield_while!( hollow!(self.state_receiver), ConnectionState::Disconnecting ) { ConnectionState::Disconnected => {} ConnectionState::Disconnecting => unreachable!(), new_state => { panic!("Illegal state change {:?} -> {:?}", state, new_state) } } } ConnectionState::Disconnected => {} } } } } #[instrument(level = "info")] async fn execute(&self, cmd: Command, args: A, evt: Option) -> Result> where A: std::fmt::Debug + Serialize + Send + Sync, E: std::fmt::Debug + Serialize + DeserializeOwned + Send + Sync, { let message = Message::new( OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt), ); let result = match &*self.state_receiver.borrow() { ConnectionState::Connected((_, conn)) => { try { let mut conn = conn.lock().await; conn.send(message).await?; conn.recv().await? } } _ => Err(Error::ConnectionClosed), }; let Message { payload, .. } = match result { Ok(msg) => Ok(msg), Err(e @ Error::ConnectionClosed | e @ Error::ConnectionClosedWhileSending(_)) => { debug!("disconnecting because connection is closed."); self.disconnect().await; Err(e) } Err(e) => Err(e), }?; let response: Payload = serde_json::from_str(&payload)?; match response.evt { Some(Event::Error) => Err(Error::SubscriptionFailed), _ => Ok(response), } } #[cfg(feature = "rich_presence")] pub async fn set_activity(&self, activity: Activity) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::new(activity), None) .await } #[cfg(feature = "rich_presence")] pub async fn clear_activity(&self) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::default(), None) .await } // NOTE: Not sure what the actual response values of // SEND_ACTIVITY_JOIN_INVITE and CLOSE_ACTIVITY_REQUEST are, // they are not documented. #[cfg(feature = "rich_presence")] pub async fn send_activity_join_invite(&self, user_id: u64) -> Result> { self.execute( Command::SendActivityJoinInvite, SendActivityJoinInviteArgs::new(user_id), None, ) .await } #[cfg(feature = "rich_presence")] pub async fn close_activity_request(&self, user_id: u64) -> Result> { self.execute( Command::CloseActivityRequest, CloseActivityRequestArgs::new(user_id), None, ) .await } pub async fn subscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Subscribe, f(SubscriptionArgs::new()), Some(evt)) .await } pub async fn unsubscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt)) .await } }