352 lines
13 KiB
Rust
352 lines
13 KiB
Rust
use serde::{de::DeserializeOwned, Serialize};
|
|
#[allow(unused)]
|
|
use serde_json::Value;
|
|
use tokio::select;
|
|
use tokio::sync::watch;
|
|
use tokio::sync::Mutex;
|
|
|
|
use crate::connection::{Connection, SocketConnection};
|
|
use crate::error::{Error, Result};
|
|
#[cfg(feature = "rich_presence")]
|
|
use crate::models::rich_presence::{
|
|
Activity, CloseActivityRequestArgs, SendActivityJoinInviteArgs, SetActivityArgs,
|
|
};
|
|
use crate::models::{
|
|
commands::{Subscription, SubscriptionArgs},
|
|
message::Message,
|
|
payload::Payload,
|
|
Command, Event, OpCode,
|
|
};
|
|
|
|
/// Snapshots the state behind `$expr` (anything with a `borrow()` method; in this file the
/// client's `watch::Receiver`) and evaluates to its `hollow()` form.
///
/// The borrow guard lives only inside this macro's block and is dropped before the value is
/// returned. When the macro is used as a `match` scrutinee, the borrow is therefore already
/// released when the arms run. Writing `$expr.borrow().hollow()` inline would instead keep the
/// guard alive as a scrutinee temporary for the whole `match`.
macro_rules! hollow {
    ($expr:expr) => {{
        let snapshot = $expr.borrow();
        let hollowed = snapshot.hollow();
        drop(snapshot);
        hollowed
    }};
}
|
|
|
|
/// Lifecycle of the client's connection.
///
/// `T` is the payload carried while connected. `ConnectionState<()>`, obtained through
/// [`ConnectionState::hollow`], is a cheap `Copy` summary of the state without that payload.
///
/// `Clone` and `Copy` are derived, which yields exactly the bounds the former hand-written
/// impls had (`T: Clone` and `T: Copy` respectively).
#[derive(Debug, Clone, Copy)]
enum ConnectionState<T> {
    /// No connection and no connection attempt in progress.
    Disconnected,
    /// A connection attempt has been started and has not been resolved yet.
    Connecting,
    /// Connected; carries the connection payload.
    Connected(T),
    /// A disconnect has been requested and is being carried out.
    Disconnecting,
}

impl<T> ConnectionState<T> {
    /// Returns the same variant with the `Connected` payload replaced by `()`.
    pub fn hollow(&self) -> ConnectionState<()> {
        match self {
            Self::Disconnected => ConnectionState::Disconnected,
            Self::Connecting => ConnectionState::Connecting,
            Self::Connected(_) => ConnectionState::Connected(()),
            Self::Disconnecting => ConnectionState::Disconnecting,
        }
    }
}
|
|
|
|
/// Re-evaluates `$receive` until its value no longer matches `$pat`, and evaluates to that
/// first non-matching value.
///
/// Between attempts it calls `tokio::task::yield_now().await`, so it must be expanded inside an
/// `async` context. It busy-polls: it is never notified of changes, so the task stays runnable
/// for as long as the value keeps matching `$pat`.
macro_rules! yield_while {
    ($receive:expr, $pat:pat) => {{
        loop {
            // Bind first so temporaries created by `$receive` are dropped before any yield.
            let candidate = $receive;
            match candidate {
                $pat => tokio::task::yield_now().await,
                _ => break candidate,
            }
        }
    }};
}
|
|
|
|
type FullConnectionState = ConnectionState<(u64, Mutex<SocketConnection>)>;
|
|
|
|
#[derive(Debug)]
|
|
pub struct Client {
|
|
state_sender: watch::Sender<FullConnectionState>,
|
|
state_receiver: watch::Receiver<FullConnectionState>,
|
|
update: Mutex<()>,
|
|
}
|
|
|
|
impl Default for Client {
|
|
fn default() -> Self {
|
|
let (state_sender, state_receiver) = watch::channel(ConnectionState::Disconnected);
|
|
Self {
|
|
state_sender,
|
|
state_receiver,
|
|
update: Mutex::new(()),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Client {
|
|
/// Returns the client id used by the current connection, or [`None`] if the client is not [`ConnectionState::Connected`].
|
|
pub fn client_id(&self) -> Option<u64> {
|
|
match *self.state_receiver.borrow() {
|
|
ConnectionState::Connected((client_id, _)) => Some(client_id),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "debug")]
|
|
async fn connect_and_handshake(client_id: u64) -> Result<SocketConnection> {
|
|
debug!("Connecting");
|
|
|
|
let mut new_connection = SocketConnection::connect().await?;
|
|
|
|
debug!("Performing handshake");
|
|
new_connection.handshake(client_id).await?;
|
|
debug!("Handshake completed");
|
|
|
|
Ok(new_connection)
|
|
}
|
|
|
|
#[instrument(level = "debug")]
|
|
async fn connect0(&self, client_id: u64, conn: Result<SocketConnection>) -> Result<()> {
|
|
let _state_guard = self.update.lock().await;
|
|
match hollow!(self.state_receiver) {
|
|
state @ ConnectionState::Disconnected => panic!(
|
|
"Illegal state during connection process {:?} -> {:?}",
|
|
ConnectionState::<()>::Connecting,
|
|
state
|
|
),
|
|
ConnectionState::Connecting => match conn {
|
|
Ok(conn) => {
|
|
self.state_sender
|
|
.send(ConnectionState::Connected((client_id, Mutex::new(conn))))
|
|
.expect("the receiver cannot be dropped without the sender!");
|
|
debug!("Connected");
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
self.state_sender
|
|
.send(ConnectionState::Disconnected)
|
|
.expect("the receiver cannot be dropped without the sender!");
|
|
debug!("Failed to connect and disconnected");
|
|
Err(e)
|
|
}
|
|
},
|
|
ConnectionState::Connected(_) => panic!("Illegal concurrent connection!"),
|
|
ConnectionState::Disconnecting => {
|
|
match conn {
|
|
Ok(conn) => {
|
|
if let Err(e) = conn.disconnect().await {
|
|
error!("failed to disconnect properly: {}", e);
|
|
}
|
|
}
|
|
Err(e) => {
|
|
error!("failed connection: {}", e);
|
|
}
|
|
}
|
|
self.state_sender
|
|
.send(ConnectionState::Disconnected)
|
|
.expect("the receiver cannot be dropped without the sender!");
|
|
Err(Error::ConnectionClosed)
|
|
}
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "info")]
|
|
pub async fn connect(&self, client_id: u64) -> Result<()> {
|
|
match hollow!(self.state_receiver) {
|
|
ConnectionState::Connected(_) => Ok(()),
|
|
_ => {
|
|
let state_guard = self.update.lock().await;
|
|
match hollow!(self.state_receiver) {
|
|
ConnectionState::Connected(_) => Ok(()),
|
|
ConnectionState::Disconnecting => Err(Error::ConnectionClosed),
|
|
ConnectionState::Connecting => {
|
|
match yield_while!(
|
|
hollow!(self.state_receiver),
|
|
ConnectionState::Connecting
|
|
) {
|
|
ConnectionState::Connected(_) => Ok(()),
|
|
ConnectionState::Disconnecting => Err(Error::ConnectionClosed),
|
|
ConnectionState::Disconnected => Err(Error::ConnectionClosed),
|
|
ConnectionState::Connecting => unreachable!(),
|
|
}
|
|
}
|
|
ConnectionState::Disconnected => {
|
|
self.state_sender
|
|
.send(ConnectionState::Connecting)
|
|
.expect("the receiver cannot be dropped without the sender!");
|
|
|
|
drop(state_guard);
|
|
select! {
|
|
conn = Self::connect_and_handshake(client_id) => {
|
|
self.connect0(client_id, conn).await
|
|
}
|
|
// _ = tokio::task::yield_now() if self.state_receiver.borrow().is_disconnecting() => {
|
|
// self.state_sender.send(ConnectionState::Disconnected).expect("the receiver cannot be dropped without the sender!");
|
|
// Err(Error::ConnectionClosed)
|
|
// }
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// If currently connected, the function will close the connection.
|
|
/// If currently connecting, the function will wait for the connection to be established and will immediately close it.
|
|
/// If currently disconnecting, the function will wait for the connection to be closed.
|
|
#[instrument(level = "info")]
|
|
pub async fn disconnect(&self) {
|
|
let _state_guard = self.update.lock().await;
|
|
trace!("aquired state guard for disconnect");
|
|
match hollow!(self.state_receiver) {
|
|
ConnectionState::Disconnected => {}
|
|
ref state @ ConnectionState::Disconnecting => {
|
|
trace!("Waiting while in disconnecting state(b)");
|
|
match yield_while!(hollow!(self.state_receiver), ConnectionState::Disconnecting) {
|
|
ConnectionState::Disconnected => {}
|
|
ConnectionState::Disconnecting => unreachable!(),
|
|
new_state => panic!("Illegal state change {:?} -> {:?}", state, new_state),
|
|
}
|
|
}
|
|
ConnectionState::Connecting => {
|
|
self.state_sender
|
|
.send(ConnectionState::Disconnecting)
|
|
.expect("the receiver cannot be dropped without the sender!");
|
|
}
|
|
state @ ConnectionState::Connected(()) => {
|
|
trace!("Sending disconnecting state");
|
|
let s = self
|
|
.state_sender
|
|
.send_replace(ConnectionState::Disconnecting);
|
|
trace!("Sent disconnecting state");
|
|
match s {
|
|
ConnectionState::Connected(conn) => {
|
|
match conn.1.into_inner().disconnect().await {
|
|
Err(e) => {
|
|
error!("failed to disconnect properly: {}", e);
|
|
}
|
|
_ => self
|
|
.state_sender
|
|
.send(ConnectionState::Disconnected)
|
|
.expect("the receiver cannot be dropped without the sender!"),
|
|
}
|
|
}
|
|
new_state @ ConnectionState::Connecting => {
|
|
panic!("Illegal state change {:?} -> {:?}", state, new_state);
|
|
}
|
|
state @ ConnectionState::Disconnecting => {
|
|
trace!("Waiting while in disconnecting state(b)");
|
|
match yield_while!(
|
|
hollow!(self.state_receiver),
|
|
ConnectionState::Disconnecting
|
|
) {
|
|
ConnectionState::Disconnected => {}
|
|
ConnectionState::Disconnecting => unreachable!(),
|
|
new_state => {
|
|
panic!("Illegal state change {:?} -> {:?}", state, new_state)
|
|
}
|
|
}
|
|
}
|
|
ConnectionState::Disconnected => {}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[instrument(level = "info")]
|
|
async fn execute<A, E>(&self, cmd: Command, args: A, evt: Option<Event>) -> Result<Payload<E>>
|
|
where
|
|
A: std::fmt::Debug + Serialize + Send + Sync,
|
|
E: std::fmt::Debug + Serialize + DeserializeOwned + Send + Sync,
|
|
{
|
|
let message = Message::new(
|
|
OpCode::Frame,
|
|
Payload::with_nonce(cmd, Some(args), None, evt),
|
|
);
|
|
let result = match &*self.state_receiver.borrow() {
|
|
ConnectionState::Connected((_, conn)) => {
|
|
try {
|
|
let mut conn = conn.lock().await;
|
|
conn.send(message).await?;
|
|
conn.recv().await?
|
|
}
|
|
}
|
|
_ => Err(Error::ConnectionClosed),
|
|
};
|
|
let Message { payload, .. } = match result {
|
|
Ok(msg) => Ok(msg),
|
|
Err(e @ Error::ConnectionClosed | e @ Error::ConnectionClosedWhileSending(_)) => {
|
|
debug!("disconnecting because connection is closed.");
|
|
self.disconnect().await;
|
|
Err(e)
|
|
}
|
|
Err(e) => Err(e),
|
|
}?;
|
|
let response: Payload<E> = serde_json::from_str(&payload)?;
|
|
match response.evt {
|
|
Some(Event::Error) => Err(Error::SubscriptionFailed),
|
|
_ => Ok(response),
|
|
}
|
|
}
|
|
|
|
#[cfg(feature = "rich_presence")]
|
|
pub async fn set_activity(&self, activity: Activity) -> Result<Payload<Activity>> {
|
|
self.execute(Command::SetActivity, SetActivityArgs::new(activity), None)
|
|
.await
|
|
}
|
|
|
|
#[cfg(feature = "rich_presence")]
|
|
pub async fn clear_activity(&self) -> Result<Payload<Activity>> {
|
|
self.execute(Command::SetActivity, SetActivityArgs::default(), None)
|
|
.await
|
|
}
|
|
|
|
// NOTE: Not sure what the actual response values of
|
|
// SEND_ACTIVITY_JOIN_INVITE and CLOSE_ACTIVITY_REQUEST are,
|
|
// they are not documented.
|
|
#[cfg(feature = "rich_presence")]
|
|
pub async fn send_activity_join_invite(&self, user_id: u64) -> Result<Payload<Value>> {
|
|
self.execute(
|
|
Command::SendActivityJoinInvite,
|
|
SendActivityJoinInviteArgs::new(user_id),
|
|
None,
|
|
)
|
|
.await
|
|
}
|
|
|
|
#[cfg(feature = "rich_presence")]
|
|
pub async fn close_activity_request(&self, user_id: u64) -> Result<Payload<Value>> {
|
|
self.execute(
|
|
Command::CloseActivityRequest,
|
|
CloseActivityRequestArgs::new(user_id),
|
|
None,
|
|
)
|
|
.await
|
|
}
|
|
|
|
pub async fn subscribe<F>(&self, evt: Event, f: F) -> Result<Payload<Subscription>>
|
|
where
|
|
F: FnOnce(SubscriptionArgs) -> SubscriptionArgs,
|
|
{
|
|
self.execute(Command::Subscribe, f(SubscriptionArgs::new()), Some(evt))
|
|
.await
|
|
}
|
|
|
|
pub async fn unsubscribe<F>(&self, evt: Event, f: F) -> Result<Payload<Subscription>>
|
|
where
|
|
F: FnOnce(SubscriptionArgs) -> SubscriptionArgs,
|
|
{
|
|
self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt))
|
|
.await
|
|
}
|
|
}
|