use serde::{de::DeserializeOwned, Serialize}; #[allow(unused)] use serde_json::Value; use tokio::select; use tokio::sync::watch::Ref; use tokio::sync::{watch, Mutex}; use crate::connection::{Connection, SocketConnection}; use crate::error::{Error, Result}; #[cfg(feature = "rich_presence")] use crate::models::rich_presence::{ Activity, CloseActivityRequestArgs, SendActivityJoinInviteArgs, SetActivityArgs, }; use crate::models::{ commands::{Subscription, SubscriptionArgs}, message::Message, payload::Payload, Command, Event, OpCode, }; macro_rules! hollow { ($expr:expr) => {{ let ref_ = $expr.borrow(); ref_.hollow() }}; } #[derive(Debug)] enum ConnectionState { Disconnected, Connecting, Connected(T), Disconnecting, } impl Clone for ConnectionState { fn clone(&self) -> Self { match self { Self::Disconnected => Self::Disconnected, Self::Connecting => Self::Connecting, Self::Connected(arg0) => Self::Connected(arg0.clone()), Self::Disconnecting => Self::Disconnecting, } } } impl Copy for ConnectionState {} impl ConnectionState { pub fn is_disconnected(&self) -> bool { match self { ConnectionState::Disconnected => true, _ => false, } } pub fn is_connecting(&self) -> bool { match self { ConnectionState::Connecting => true, _ => false, } } pub fn is_connected(&self) -> bool { match self { ConnectionState::Connected(_) => true, _ => false, } } pub fn is_disconnecting(&self) -> bool { match self { ConnectionState::Disconnecting => true, _ => false, } } } impl ConnectionState { pub fn hollow(&self) -> ConnectionState<()> { match self { ConnectionState::Disconnected => ConnectionState::Disconnected, ConnectionState::Connecting => ConnectionState::Connecting, ConnectionState::Connected(_) => ConnectionState::Connected(()), ConnectionState::Disconnecting => ConnectionState::Disconnecting, } } } macro_rules! yield_while { ($receive:expr, $pat:pat) => {{ let mut new_state: _; loop { new_state = $receive; match new_state { $pat => tokio::task::yield_now().await, _ => break, } } new_state }}; } #[derive(Debug)] pub struct Client { state: ( watch::Sender)>>, watch::Receiver)>>, ), update: Mutex<()>, } impl Default for Client { fn default() -> Self { Self { state: watch::channel(ConnectionState::Disconnected), update: Mutex::new(()), } } } impl Client { /// Returns the client id used by the current connection, or [`None`] if the client is not [`ConnectionState::Connected`]. pub fn client_id(&self) -> Option { match *self.state.1.borrow() { ConnectionState::Connected((client_id, _)) => Some(client_id), _ => None, } } async fn connect_and_handshake(client_id: u64) -> Result { debug!("Connecting"); let mut new_connection = SocketConnection::connect().await?; debug!("Performing handshake"); new_connection.handshake(client_id).await?; debug!("Handshake completed"); Ok(new_connection) } async fn connect0(&self, client_id: u64, conn: Result) -> Result<()> { let _state_guard = self.update.lock().await; match hollow!(self.state.1) { state @ ConnectionState::Disconnected => panic!( "Illegal state during connection process {:?} -> {:?}", ConnectionState::<()>::Connecting, state ), ConnectionState::Connecting => { self.state .0 .send(ConnectionState::Connected((client_id, Mutex::new(conn?)))) .expect("the receiver cannot be dropped without the sender!"); debug!("Connected"); Ok(()) } ConnectionState::Connected(_) => panic!("Illegal concurrent connection!"), ConnectionState::Disconnecting => { match conn { Ok(conn) => match conn.disconnect().await { Err(e) => { error!("failed to disconnect properly: {}", e); } _ => {} }, Err(e) => { error!("failed connection: {}", e); } } self.state .0 .send(ConnectionState::Disconnected) .expect("the receiver cannot be dropped without the sender!"); Err(Error::ConnectionClosed) } } } pub async fn connect(&self, client_id: u64) -> Result<()> { match hollow!(self.state.1) { ConnectionState::Connected(_) => Ok(()), _ => { let state_guard = self.update.lock().await; match hollow!(self.state.1) { ConnectionState::Connected(_) => Ok(()), ConnectionState::Disconnecting => Err(Error::ConnectionClosed), ConnectionState::Connecting => { match yield_while!(hollow!(self.state.1), ConnectionState::Connecting) { ConnectionState::Connected(_) => Ok(()), ConnectionState::Disconnecting => Err(Error::ConnectionClosed), ConnectionState::Disconnected => Err(Error::ConnectionClosed), ConnectionState::Connecting => unreachable!(), } } ConnectionState::Disconnected => { self.state .0 .send(ConnectionState::Connecting) .expect("the receiver cannot be dropped without the sender!"); drop(state_guard); select! { conn = Self::connect_and_handshake(client_id) => { self.connect0(client_id, conn).await } // _ = tokio::task::yield_now() if self.state.1.borrow().is_disconnecting() => { // self.state.0.send(ConnectionState::Disconnected).expect("the receiver cannot be dropped without the sender!"); // Err(Error::ConnectionClosed) // } } } } } } } /// If currently connected, the function will close the connection. /// If currently connecting, the function will wait for the connection to be established and will immediately close it. /// If currently disconnecting, the function will wait for the connection to be closed. pub async fn disconnect(&self) { let _state_guard = self.update.lock().await; match hollow!(self.state.1) { ConnectionState::Disconnected => {} ref state @ ConnectionState::Disconnecting => { match yield_while!(hollow!(self.state.1), ConnectionState::Disconnecting) { ConnectionState::Disconnected => {} ConnectionState::Disconnecting => unreachable!(), new_state => panic!("Illegal state change {:?} -> {:?}", state, new_state), } } ConnectionState::Connecting => { self.state .0 .send(ConnectionState::Disconnecting) .expect("the receiver cannot be dropped without the sender!"); } state @ ConnectionState::Connected(_) => { match self.state.0.send_replace(ConnectionState::Disconnecting) { ConnectionState::Connected(conn) => { match conn.1.into_inner().disconnect().await { Err(e) => { error!("failed to disconnect properly: {}", e); } _ => {} } } new_state @ ConnectionState::Connecting => { panic!("Illegal state change {:?} -> {:?}", state, new_state); } state @ ConnectionState::Disconnecting => { match yield_while!(hollow!(self.state.1), ConnectionState::Disconnecting) { ConnectionState::Disconnected => {} ConnectionState::Disconnecting => unreachable!(), new_state => { panic!("Illegal state change {:?} -> {:?}", state, new_state) } } } ConnectionState::Disconnected => {} } } } } async fn execute(&self, cmd: Command, args: A, evt: Option) -> Result> where A: Serialize + Send + Sync, E: Serialize + DeserializeOwned + Send + Sync, { let message = Message::new( OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt), ); match *self.state.1.borrow() { ConnectionState::Connected((_, ref conn)) => { let mut conn = conn.lock().await; conn.send(message).await?; let Message { payload, .. } = conn.recv().await?; let response: Payload = serde_json::from_str(&payload)?; match response.evt { Some(Event::Error) => Err(Error::SubscriptionFailed), _ => Ok(response), } } _ => Err(Error::ConnectionClosed), } } #[cfg(feature = "rich_presence")] pub async fn set_activity(&self, activity: Activity) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::new(activity), None) .await } #[cfg(feature = "rich_presence")] pub async fn clear_activity(&self) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::default(), None) .await } // NOTE: Not sure what the actual response values of // SEND_ACTIVITY_JOIN_INVITE and CLOSE_ACTIVITY_REQUEST are, // they are not documented. #[cfg(feature = "rich_presence")] pub async fn send_activity_join_invite(&self, user_id: u64) -> Result> { self.execute( Command::SendActivityJoinInvite, SendActivityJoinInviteArgs::new(user_id), None, ) .await } #[cfg(feature = "rich_presence")] pub async fn close_activity_request(&self, user_id: u64) -> Result> { self.execute( Command::CloseActivityRequest, CloseActivityRequestArgs::new(user_id), None, ) .await } pub async fn subscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Subscribe, f(SubscriptionArgs::new()), Some(evt)) .await } pub async fn unsubscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt)) .await } }