diff --git a/Cargo.toml b/Cargo.toml index 2e184f8..4eb6441 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,33 +1,37 @@ [package] -authors = ["Patrick Auernig ", "Michael Pfaff "] -name = "discord-rpc-client" +name = "discord-rpc-client" +version = "0.5.0" description = "A Rust client for Discord RPC." -keywords = ["discord", "rpc", "ipc"] -license = "MIT" -readme = "README.md" -repository = "https://gitlab.com/valeth/discord-rpc-client.rs.git" -version = "0.4.0" - -[dependencies] -serde = "^1.0" -serde_derive = "^1.0" -serde_json = "^1.0" -byteorder = "^1.0" -log = "~0.4" -bytes = "^0.4" -parking_lot = "^0.7" -crossbeam-channel = "^0.3" - -[target.'cfg(windows)'.dependencies] -named_pipe = "0.3.0" - -[dependencies.uuid] -version = "^0.6.2" -features = ["v4"] - -[dev-dependencies] -simplelog = "~0.5" +authors = [ + "Patrick Auernig ", + "Michael Pfaff ", +] +keywords = ["discord", "rpc", "ipc"] +license = "MIT" +readme = "README.md" +repository = "https://gitlab.com/valeth/discord-rpc-client.rs.git" +edition = "2021" [features] default = ["rich_presence"] rich_presence = [] +tokio-parking_lot = ["tokio/parking_lot"] + +[dependencies] +async-trait = "0.1.52" +tokio = { version = "1.16", features = ["io-util", "net", "sync", "macros", "rt"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +byteorder = "1.0" +log = "0.4" +bytes = "1.1.0" +uuid = { version = "0.8", features = ["v4"] } + +[dev-dependencies] +simplelog = "0.11" +tokio = { version = "1.16", features = [ + "time", + "rt-multi-thread", + "macros", + "parking_lot", +] } diff --git a/examples/discord_presence.rs b/examples/discord_presence.rs index 4d11c26..31ba4cd 100644 --- a/examples/discord_presence.rs +++ b/examples/discord_presence.rs @@ -5,12 +5,19 @@ use discord_rpc_client::{models::Activity, Client as DiscordRPC}; use simplelog::*; use std::io; -fn main() -> discord_rpc_client::Result<()> { - TermLogger::init(LevelFilter::Debug, Config::default()).unwrap(); +#[tokio::main] +async fn main() -> discord_rpc_client::Result<()> { + TermLogger::init( + LevelFilter::Debug, + Config::default(), + TerminalMode::Mixed, + ColorChoice::Always, + ) + .unwrap(); - let mut drpc = DiscordRPC::new(425407036495495169); + let drpc = DiscordRPC::default(); - drpc.connect()?; + drpc.connect(425407036495495169).await?; loop { let mut buf = String::new(); @@ -19,16 +26,19 @@ fn main() -> discord_rpc_client::Result<()> { buf.pop(); if buf.is_empty() { - if let Err(why) = drpc.clear_activity() { + if let Err(why) = drpc.clear_activity().await { println!("Failed to clear presence: {}", why); } } else { - if let Err(why) = drpc.set_activity(Activity::new().state(buf).assets(|ass| { - ass.large_image("ferris_wat") - .large_text("wat.") - .small_image("rusting") - .small_text("rusting...") - })) { + if let Err(why) = drpc + .set_activity(Activity::new().state(buf).assets(|ass| { + ass.large_image("ferris_wat") + .large_text("wat.") + .small_image("rusting") + .small_text("rusting...") + })) + .await + { println!("Failed to set presence: {}", why); } } diff --git a/examples/discord_presence_subscribe.rs b/examples/discord_presence_subscribe.rs index b3742b8..9aa93af 100644 --- a/examples/discord_presence_subscribe.rs +++ b/examples/discord_presence_subscribe.rs @@ -5,26 +5,37 @@ use discord_rpc_client::{models::Event, Client as DiscordRPC}; use simplelog::*; use std::{thread, time}; -fn main() -> discord_rpc_client::Result<()> { - TermLogger::init(LevelFilter::Debug, Config::default()).unwrap(); +#[tokio::main] +async fn main() -> discord_rpc_client::Result<()> { + TermLogger::init( + LevelFilter::Debug, + Config::default(), + TerminalMode::Mixed, + ColorChoice::Always, + ) + .unwrap(); - let mut drpc = DiscordRPC::new(425407036495495169); + let drpc = DiscordRPC::default(); - drpc.connect()?; + drpc.connect(425407036495495169).await?; drpc.subscribe(Event::ActivityJoin, |j| j.secret("123456")) + .await .expect("Failed to subscribe to event"); drpc.subscribe(Event::ActivitySpectate, |s| s.secret("123456")) + .await .expect("Failed to subscribe to event"); drpc.subscribe(Event::ActivityJoinRequest, |s| s) + .await .expect("Failed to subscribe to event"); drpc.unsubscribe(Event::ActivityJoinRequest, |j| j) + .await .expect("Failed to unsubscribe from event"); loop { - thread::sleep(time::Duration::from_millis(500)); + tokio::time::sleep(time::Duration::from_millis(500)).await; } } diff --git a/src/client.rs b/src/client.rs index 72e45d7..1f124ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,45 +1,271 @@ use serde::{de::DeserializeOwned, Serialize}; #[allow(unused)] use serde_json::Value; +use tokio::select; +use tokio::sync::watch::Ref; +use tokio::sync::{watch, Mutex}; -use connection::Manager as ConnectionManager; -use error::{Error, Result}; +use crate::connection::{Connection, SocketConnection}; +use crate::error::{Error, Result}; #[cfg(feature = "rich_presence")] -use models::rich_presence::{ +use crate::models::rich_presence::{ Activity, CloseActivityRequestArgs, SendActivityJoinInviteArgs, SetActivityArgs, }; -use models::{ +use crate::models::{ commands::{Subscription, SubscriptionArgs}, message::Message, payload::Payload, Command, Event, OpCode, }; -#[derive(Clone, Debug)] -#[repr(transparent)] +macro_rules! hollow { + ($expr:expr) => {{ + let ref_ = $expr.borrow(); + ref_.hollow() + }}; +} + +#[derive(Debug)] +enum ConnectionState { + Disconnected, + Connecting, + Connected(T), + Disconnecting, +} + +impl Clone for ConnectionState { + fn clone(&self) -> Self { + match self { + Self::Disconnected => Self::Disconnected, + Self::Connecting => Self::Connecting, + Self::Connected(arg0) => Self::Connected(arg0.clone()), + Self::Disconnecting => Self::Disconnecting, + } + } +} + +impl Copy for ConnectionState {} + +impl ConnectionState { + pub fn is_disconnected(&self) -> bool { + match self { + ConnectionState::Disconnected => true, + _ => false, + } + } + + pub fn is_connecting(&self) -> bool { + match self { + ConnectionState::Connecting => true, + _ => false, + } + } + + pub fn is_connected(&self) -> bool { + match self { + ConnectionState::Connected(_) => true, + _ => false, + } + } + + pub fn is_disconnecting(&self) -> bool { + match self { + ConnectionState::Disconnecting => true, + _ => false, + } + } +} + +impl ConnectionState { + pub fn hollow(&self) -> ConnectionState<()> { + match self { + ConnectionState::Disconnected => ConnectionState::Disconnected, + ConnectionState::Connecting => ConnectionState::Connecting, + ConnectionState::Connected(_) => ConnectionState::Connected(()), + ConnectionState::Disconnecting => ConnectionState::Disconnecting, + } + } +} + +macro_rules! yield_while { + ($receive:expr, $pat:pat) => {{ + let mut new_state: _; + loop { + new_state = $receive; + match new_state { + $pat => tokio::task::yield_now().await, + _ => break, + } + } + + new_state + }}; +} + +#[derive(Debug)] pub struct Client { - connection_manager: ConnectionManager, + state: ( + watch::Sender)>>, + watch::Receiver)>>, + ), + update: Mutex<()>, +} + +impl Default for Client { + fn default() -> Self { + Self { + state: watch::channel(ConnectionState::Disconnected), + update: Mutex::new(()), + } + } } impl Client { - pub fn new(client_id: u64) -> Self { - let connection_manager = ConnectionManager::new(client_id); - Self { connection_manager } + /// Returns the client id used by the current connection, or [`None`] if the client is not [`ConnectionState::Connected`]. + pub fn client_id(&self) -> Option { + match *self.state.1.borrow() { + ConnectionState::Connected((client_id, _)) => Some(client_id), + _ => None, + } } - pub fn client_id(&self) -> u64 { - self.connection_manager.client_id() + async fn connect_and_handshake(client_id: u64) -> Result { + debug!("Connecting"); + + let mut new_connection = SocketConnection::connect().await?; + + debug!("Performing handshake"); + new_connection.handshake(client_id).await?; + debug!("Handshake completed"); + + Ok(new_connection) } - pub fn connect(&mut self) -> Result<()> { - self.connection_manager.connect() + async fn connect0(&self, client_id: u64, conn: Result) -> Result<()> { + let _state_guard = self.update.lock().await; + match hollow!(self.state.1) { + state @ ConnectionState::Disconnected => panic!( + "Illegal state during connection process {:?} -> {:?}", + ConnectionState::<()>::Connecting, + state + ), + ConnectionState::Connecting => { + self.state + .0 + .send(ConnectionState::Connected((client_id, Mutex::new(conn?)))) + .expect("the receiver cannot be dropped without the sender!"); + debug!("Connected"); + Ok(()) + } + ConnectionState::Connected(_) => panic!("Illegal concurrent connection!"), + ConnectionState::Disconnecting => { + match conn { + Ok(conn) => match conn.disconnect().await { + Err(e) => { + error!("failed to disconnect properly: {}", e); + } + _ => {} + }, + Err(e) => { + error!("failed connection: {}", e); + } + } + self.state + .0 + .send(ConnectionState::Disconnected) + .expect("the receiver cannot be dropped without the sender!"); + Err(Error::ConnectionClosed) + } + } } - pub fn disconnect(&mut self) { - self.connection_manager.disconnect() + pub async fn connect(&self, client_id: u64) -> Result<()> { + match hollow!(self.state.1) { + ConnectionState::Connected(_) => Ok(()), + _ => { + let state_guard = self.update.lock().await; + match hollow!(self.state.1) { + ConnectionState::Connected(_) => Ok(()), + ConnectionState::Disconnecting => Err(Error::ConnectionClosed), + ConnectionState::Connecting => { + match yield_while!(hollow!(self.state.1), ConnectionState::Connecting) { + ConnectionState::Connected(_) => Ok(()), + ConnectionState::Disconnecting => Err(Error::ConnectionClosed), + ConnectionState::Disconnected => Err(Error::ConnectionClosed), + ConnectionState::Connecting => unreachable!(), + } + } + ConnectionState::Disconnected => { + self.state + .0 + .send(ConnectionState::Connecting) + .expect("the receiver cannot be dropped without the sender!"); + + drop(state_guard); + select! { + conn = Self::connect_and_handshake(client_id) => { + self.connect0(client_id, conn).await + } + // _ = tokio::task::yield_now() if self.state.1.borrow().is_disconnecting() => { + // self.state.0.send(ConnectionState::Disconnected).expect("the receiver cannot be dropped without the sender!"); + // Err(Error::ConnectionClosed) + // } + } + } + } + } + } } - fn execute(&mut self, cmd: Command, args: A, evt: Option) -> Result> + /// If currently connected, the function will close the connection. + /// If currently connecting, the function will wait for the connection to be established and will immediately close it. + /// If currently disconnecting, the function will wait for the connection to be closed. + pub async fn disconnect(&self) { + let _state_guard = self.update.lock().await; + match hollow!(self.state.1) { + ConnectionState::Disconnected => {} + ref state @ ConnectionState::Disconnecting => { + match yield_while!(hollow!(self.state.1), ConnectionState::Disconnecting) { + ConnectionState::Disconnected => {} + ConnectionState::Disconnecting => unreachable!(), + new_state => panic!("Illegal state change {:?} -> {:?}", state, new_state), + } + } + ConnectionState::Connecting => { + self.state + .0 + .send(ConnectionState::Disconnecting) + .expect("the receiver cannot be dropped without the sender!"); + } + state @ ConnectionState::Connected(_) => { + match self.state.0.send_replace(ConnectionState::Disconnecting) { + ConnectionState::Connected(conn) => { + match conn.1.into_inner().disconnect().await { + Err(e) => { + error!("failed to disconnect properly: {}", e); + } + _ => {} + } + } + new_state @ ConnectionState::Connecting => { + panic!("Illegal state change {:?} -> {:?}", state, new_state); + } + state @ ConnectionState::Disconnecting => { + match yield_while!(hollow!(self.state.1), ConnectionState::Disconnecting) { + ConnectionState::Disconnected => {} + ConnectionState::Disconnecting => unreachable!(), + new_state => { + panic!("Illegal state change {:?} -> {:?}", state, new_state) + } + } + } + ConnectionState::Disconnected => {} + } + } + } + } + + async fn execute(&self, cmd: Command, args: A, evt: Option) -> Result> where A: Serialize + Send + Sync, E: Serialize + DeserializeOwned + Send + Sync, @@ -48,58 +274,69 @@ impl Client { OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt), ); - self.connection_manager.send(message)?; - let Message { payload, .. } = self.connection_manager.recv()?; - let response: Payload = serde_json::from_str(&payload)?; - - match response.evt { - Some(Event::Error) => Err(Error::SubscriptionFailed), - _ => Ok(response), + match *self.state.1.borrow() { + ConnectionState::Connected((_, ref conn)) => { + let mut conn = conn.lock().await; + conn.send(message).await?; + let Message { payload, .. } = conn.recv().await?; + let response: Payload = serde_json::from_str(&payload)?; + match response.evt { + Some(Event::Error) => Err(Error::SubscriptionFailed), + _ => Ok(response), + } + } + _ => Err(Error::ConnectionClosed), } } #[cfg(feature = "rich_presence")] - pub fn set_activity(&mut self, activity: Activity) -> Result> { + pub async fn set_activity(&self, activity: Activity) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::new(activity), None) + .await } #[cfg(feature = "rich_presence")] - pub fn clear_activity(&mut self) -> Result> { + pub async fn clear_activity(&self) -> Result> { self.execute(Command::SetActivity, SetActivityArgs::default(), None) + .await } // NOTE: Not sure what the actual response values of // SEND_ACTIVITY_JOIN_INVITE and CLOSE_ACTIVITY_REQUEST are, // they are not documented. #[cfg(feature = "rich_presence")] - pub fn send_activity_join_invite(&mut self, user_id: u64) -> Result> { + pub async fn send_activity_join_invite(&self, user_id: u64) -> Result> { self.execute( Command::SendActivityJoinInvite, SendActivityJoinInviteArgs::new(user_id), None, ) + .await } #[cfg(feature = "rich_presence")] - pub fn close_activity_request(&mut self, user_id: u64) -> Result> { + pub async fn close_activity_request(&self, user_id: u64) -> Result> { self.execute( Command::CloseActivityRequest, CloseActivityRequestArgs::new(user_id), None, ) + .await } - pub fn subscribe(&mut self, evt: Event, f: F) -> Result> + pub async fn subscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Subscribe, f(SubscriptionArgs::new()), Some(evt)) + .await } - pub fn unsubscribe(&mut self, evt: Event, f: F) -> Result> + pub async fn unsubscribe(&self, evt: Event, f: F) -> Result> where F: FnOnce(SubscriptionArgs) -> SubscriptionArgs, { self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt)) + .await } } diff --git a/src/connection/base.rs b/src/connection/base.rs index c09dfd5..4d775fe 100644 --- a/src/connection/base.rs +++ b/src/connection/base.rs @@ -1,79 +1,66 @@ use std::{ - io::{ErrorKind, Read, Write}, marker::Sized, path::PathBuf, - thread, time, }; use bytes::BytesMut; +use tokio::io::{AsyncWrite, AsyncRead, AsyncWriteExt, AsyncReadExt}; -use error::{Error, Result}; -use models::message::{Message, OpCode}; -use utils; +use crate::error::{Error, Result}; +use crate::models::message::{Message, OpCode}; +use crate::utils; -/// Wait for a non-blocking connection until it's complete. -macro_rules! try_until_done { - [ $e:expr ] => { - loop { - match $e { - Ok(_) => break, - Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (), - Err(why) => return Err(why), - } - - thread::sleep(time::Duration::from_millis(500)); - } - } -} - -pub trait Connection: Sized { - type Socket: Write + Read; +#[async_trait::async_trait] +pub trait Connection: Sized + Send { + type Socket: AsyncWrite + AsyncRead + Unpin + Send; fn socket(&mut self) -> &mut Self::Socket; fn ipc_path() -> PathBuf; - fn connect() -> Result; + async fn connect() -> Result; + + async fn disconnect(self) -> Result<()>; fn socket_path(n: u8) -> PathBuf { Self::ipc_path().join(format!("discord-ipc-{}", n)) } - fn handshake(&mut self, client_id: u64) -> Result<()> { + async fn handshake(&mut self, client_id: u64) -> Result<()> { let hs = json![{ "client_id": client_id.to_string(), "v": 1, "nonce": utils::nonce() }]; - try_until_done!(self.send(Message::new(OpCode::Handshake, hs.clone()))); - try_until_done!(self.recv()); + self.send(Message::new(OpCode::Handshake, hs.clone())).await?; + self.recv().await?; Ok(()) } - fn ping(&mut self) -> Result { + async fn ping(&mut self) -> Result { let message = Message::new(OpCode::Ping, json![{}]); - self.send(message)?; - let response = self.recv()?; + self.send(message).await?; + let response = self.recv().await?; Ok(response.opcode) } - fn send(&mut self, message: Message) -> Result<()> { + async fn send(&mut self, message: Message) -> Result<()> { match message.encode() { Err(why) => error!("{:?}", why), Ok(bytes) => { - self.socket().write_all(bytes.as_ref())?; + self.socket().write_all(bytes.as_ref()).await?; } }; debug!("-> {:?}", message); Ok(()) } - fn recv(&mut self) -> Result { + async fn recv(&mut self) -> Result { let mut buf = BytesMut::new(); buf.resize(1024, 0); - let n = self.socket().read(&mut buf)?; + let n = self.socket().read(&mut buf).await?; debug!("Received {} bytes", n); if n == 0 { diff --git a/src/connection/manager.rs b/src/connection/manager.rs deleted file mode 100644 index 445161e..0000000 --- a/src/connection/manager.rs +++ /dev/null @@ -1,166 +0,0 @@ -use crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError}; -use parking_lot::{RwLock, RwLockUpgradableReadGuard}; -use std::{io::ErrorKind, sync::Arc, thread, time}; - -use super::{Connection, SocketConnection}; -use error::{Error, Result}; -use models::Message; - -type Tx = Sender; -type Rx = Receiver; - -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum ConnectionState { - Disconnected, - Connecting, - Connected, - Disconnecting, -} - -#[derive(Clone, Debug)] -pub struct Manager { - state: Arc>, - client_id: u64, - outbound: (Rx, Tx), - inbound: (Rx, Tx), -} - -impl Manager { - pub fn new(client_id: u64) -> Self { - let (sender_o, receiver_o) = unbounded(); - let (sender_i, receiver_i) = unbounded(); - - Self { - state: Arc::new(RwLock::new(ConnectionState::Disconnected)), - client_id, - inbound: (receiver_i, sender_i), - outbound: (receiver_o, sender_o), - } - } - - pub fn client_id(&self) -> u64 { - self.client_id - } - - pub fn send(&self, message: Message) -> Result<()> { - self.outbound.1.send(message)?; - Ok(()) - } - - // TODO: timeout - pub fn recv(&self) -> Result { - while *self.state.read() == ConnectionState::Connected { - match self.inbound.0.try_recv() { - Ok(message) => return Ok(message), - Err(TryRecvError::Empty) => {} - Err(TryRecvError::Disconnected) => break - } - } - - Err(Error::ConnectionClosed) - } - - pub fn connect(&mut self) -> Result<()> { - // check with a read lock first - let state = self.state.upgradable_read(); - match *state { - ConnectionState::Disconnected => { - // no need to double-check after this because the read lock is upgraded instead of replace with a write lock (no possibility of mutation before the upgrade). - let mut state = RwLockUpgradableReadGuard::upgrade(state); - *state = ConnectionState::Connecting; - - // we are in the connecting state and no longer need a lock to prevent the creation of many threads. - drop(state); - - debug!("Connecting"); - - let mut new_connection = SocketConnection::connect()?; - - debug!("Performing handshake"); - new_connection.handshake(self.client_id)?; - debug!("Handshake completed"); - - let mut state = self.state.write(); - match *state { - ConnectionState::Disconnected => panic!("Illegal state: {:?}", *state), - ConnectionState::Connecting => { - *state = ConnectionState::Connected; - - drop(state); - - debug!("Connected"); - - let state_arc = self.state.clone(); - let inbound = self.inbound.1.clone(); - let outbound = self.outbound.0.clone(); - thread::spawn(move || { - send_and_receive_loop(state_arc, inbound, outbound, new_connection); - }); - - Ok(()) - } - ConnectionState::Connected => panic!("Illegal state: {:?}", *state), - ConnectionState::Disconnecting => { - *state = ConnectionState::Disconnected; - - Err(Error::ConnectionClosed) - } - } - } - ConnectionState::Connecting => Ok(()), - ConnectionState::Connected => Ok(()), - ConnectionState::Disconnecting => Err(Error::Busy), - } - } - - pub fn disconnect(&mut self) { - let state = self.state.upgradable_read(); - if *state != ConnectionState::Disconnected { - let mut state = RwLockUpgradableReadGuard::upgrade(state); - *state = ConnectionState::Disconnecting; - } - } -} - -fn send_and_receive_loop( - state: Arc>, - mut inbound: Sender, - outbound: Receiver, - mut conn: SocketConnection, -) { - debug!("Starting sender loop"); - - loop { - match send_and_receive(&mut conn, &mut inbound, &outbound) { - Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => {} - Err(Error::IoError(_)) | Err(Error::ConnectionClosed) => { - let mut state = state.write(); - *state = ConnectionState::Disconnected; - break; - } - Err(e) => error!("send_and_receive error: {}", e), - _ => {} - } - - thread::sleep(time::Duration::from_millis(500)); - } -} - -fn send_and_receive( - connection: &mut SocketConnection, - inbound: &mut Tx, - outbound: &Rx, -) -> Result<()> { - loop { - match outbound.try_recv() { - Ok(msg) => connection.send(msg)?, - Err(TryRecvError::Empty) => break, - Err(TryRecvError::Disconnected) => return Err(Error::ConnectionClosed), - } - } - - let msg = connection.recv()?; - inbound.send(msg)?; - - Ok(()) -} diff --git a/src/connection/mod.rs b/src/connection/mod.rs index 31a66dc..7ad438f 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -1,12 +1,10 @@ mod base; -mod manager; #[cfg(unix)] mod unix; #[cfg(windows)] mod windows; pub use self::base::Connection; -pub use self::manager::Manager; #[cfg(unix)] pub use self::unix::UnixConnection as SocketConnection; #[cfg(windows)] diff --git a/src/connection/unix.rs b/src/connection/unix.rs index b45daa6..e357368 100644 --- a/src/connection/unix.rs +++ b/src/connection/unix.rs @@ -1,23 +1,21 @@ -use std::{env, net::Shutdown, os::unix::net::UnixStream, path::PathBuf, time}; +use std::{env, path::PathBuf}; + +use tokio::{io::AsyncWriteExt, net::UnixStream}; use super::base::Connection; -use error::Result; +use crate::error::Result; #[derive(Debug)] pub struct UnixConnection { socket: UnixStream, } +#[async_trait::async_trait] impl Connection for UnixConnection { type Socket = UnixStream; - fn connect() -> Result { - let connection_name = Self::socket_path(0); - let socket = UnixStream::connect(connection_name)?; - socket.set_nonblocking(true)?; - socket.set_write_timeout(Some(time::Duration::from_secs(30)))?; - socket.set_read_timeout(Some(time::Duration::from_secs(30)))?; - Ok(Self { socket }) + fn socket(&mut self) -> &mut Self::Socket { + &mut self.socket } fn ipc_path() -> PathBuf { @@ -40,15 +38,14 @@ impl Connection for UnixConnection { .unwrap_or_else(|| PathBuf::from("/tmp")) } - fn socket(&mut self) -> &mut Self::Socket { - &mut self.socket + async fn connect() -> Result { + let connection_name = Self::socket_path(0); + let socket = UnixStream::connect(connection_name).await?; + Ok(Self { socket }) } -} -impl Drop for UnixConnection { - fn drop(&mut self) { - self.socket - .shutdown(Shutdown::Both) - .expect("Failed to properly shut down socket"); + async fn disconnect(mut self) -> Result<()> { + self.socket.shutdown().await?; + Ok(()) } } diff --git a/src/connection/windows.rs b/src/connection/windows.rs index 43e966e..7ba10ad 100644 --- a/src/connection/windows.rs +++ b/src/connection/windows.rs @@ -1,31 +1,33 @@ use std::{path::PathBuf, time}; -use super::base::Connection; -use error::Result; +use tokio::net::windows::named_pipe::{ClientOptions, NamedPipeClient}; -use named_pipe::PipeClient; +use super::base::Connection; +use crate::error::Result; #[derive(Debug)] pub struct WindowsConnection { - socket: PipeClient, + socket: NamedPipeClient, } +#[async_trait::async_trait] impl Connection for WindowsConnection { - type Socket = PipeClient; + type Socket = NamedPipeClient; - fn connect() -> Result { - let connection_name = Self::socket_path(0); - let mut socket = PipeClient::connect(connection_name)?; - socket.set_write_timeout(Some(time::Duration::from_secs(30))); - socket.set_read_timeout(Some(time::Duration::from_secs(30))); - Ok(Self { socket }) + fn socket(&mut self) -> &mut Self::Socket { + &mut self.socket } fn ipc_path() -> PathBuf { PathBuf::from(r"\\.\pipe\") } - fn socket(&mut self) -> &mut Self::Socket { - &mut self.socket + async fn connect() -> Result { + let connection_name = Self::socket_path(0); + Ok(Self { socket: ClientOptions::new().open(connection_name)? }) + } + + async fn disconnect(mut self) -> Result<()> { + Ok(()) } } diff --git a/src/error.rs b/src/error.rs index a154304..f095041 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,5 +1,5 @@ -use crossbeam_channel::SendError; use serde_json::Error as JsonError; +use tokio::sync::mpsc::error::SendError; use std::{ error::Error as StdError, fmt::{self, Display, Formatter}, @@ -63,4 +63,4 @@ impl From> for Error { } } -pub type Result = StdResult; +pub type Result = StdResult; diff --git a/src/lib.rs b/src/lib.rs index 3375196..c6be4bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,9 @@ #[macro_use] extern crate log; #[macro_use] -extern crate serde_derive; extern crate serde; #[macro_use] extern crate serde_json; -extern crate byteorder; -extern crate bytes; -extern crate crossbeam_channel; -#[cfg(windows)] -extern crate named_pipe; -extern crate parking_lot; -extern crate uuid; #[macro_use] mod macros; diff --git a/src/models/message.rs b/src/models/message.rs index 28736be..1176d0c 100644 --- a/src/models/message.rs +++ b/src/models/message.rs @@ -4,7 +4,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use serde::Serialize; use serde_json; -use error::{Error, Result}; +use crate::error::{Error, Result}; #[derive(Debug, Copy, Clone, PartialEq)] pub enum OpCode { @@ -15,9 +15,10 @@ pub enum OpCode { Pong, } -// FIXME: Use TryFrom trait when stable -impl OpCode { - fn try_from(int: u32) -> Result { +impl TryFrom for OpCode { + type Error = Error; + + fn try_from(int: u32) -> Result { match int { 0 => Ok(OpCode::Handshake), 1 => Ok(OpCode::Frame), diff --git a/src/models/payload.rs b/src/models/payload.rs index 23cc0c1..a3c31b2 100644 --- a/src/models/payload.rs +++ b/src/models/payload.rs @@ -4,7 +4,7 @@ use serde::{de::DeserializeOwned, Serialize}; use serde_json; use super::{Command, Event, Message}; -use utils; +use crate::utils; #[derive(Debug, PartialEq, Deserialize, Serialize)] pub struct Payload diff --git a/src/models/rich_presence.rs b/src/models/rich_presence.rs index 5bc4d92..1e68a2d 100644 --- a/src/models/rich_presence.rs +++ b/src/models/rich_presence.rs @@ -3,7 +3,7 @@ use std::default::Default; use super::shared::PartialUser; -use utils; +use crate::utils; #[derive(Debug, PartialEq, Deserialize, Serialize)] pub struct SetActivityArgs {