diff --git a/Cargo.toml b/Cargo.toml index 5700a2b..b673642 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ serde_json = "^1.0" byteorder = "^1.0" log = "~0.4" bytes = "^0.4" +parking_lot = "^0.7" +crossbeam-channel = "^0.3" [target.'cfg(windows)'.dependencies] named_pipe = "0.3.0" diff --git a/examples/discord_presence.rs b/examples/discord_presence.rs index 1552eda..04bb2eb 100644 --- a/examples/discord_presence.rs +++ b/examples/discord_presence.rs @@ -8,8 +8,7 @@ use discord_rpc_client::Client as DiscordRPC; fn main() { TermLogger::init(LevelFilter::Debug, Config::default()).unwrap(); - let mut drpc = DiscordRPC::new(425407036495495169) - .expect("Failed to create client"); + let mut drpc = DiscordRPC::new(425407036495495169); drpc.start(); diff --git a/examples/discord_presence_subscribe.rs b/examples/discord_presence_subscribe.rs index f72d1fc..b120033 100644 --- a/examples/discord_presence_subscribe.rs +++ b/examples/discord_presence_subscribe.rs @@ -11,8 +11,7 @@ use discord_rpc_client::{ fn main() { TermLogger::init(LevelFilter::Debug, Config::default()).unwrap(); - let mut drpc = DiscordRPC::new(425407036495495169) - .expect("Failed to create client"); + let mut drpc = DiscordRPC::new(425407036495495169); drpc.start(); diff --git a/src/client.rs b/src/client.rs index 3d7dcc9..a3593bd 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,6 +8,7 @@ use models::{ Command, Event, payload::Payload, + message::Message, commands::{SubscriptionArgs, Subscription}, }; #[cfg(feature = "rich_presence")] @@ -21,26 +22,27 @@ use error::{Result, Error}; pub struct Client { - connection: ConnectionManager, + connection_manager: ConnectionManager, } impl Client { - pub fn new(client_id: u64) -> Result { - Ok(Self { - connection: ConnectionManager::new(client_id)? - }) + pub fn new(client_id: u64) -> Self { + let connection_manager = ConnectionManager::new(client_id); + Self { connection_manager } } pub fn start(&mut self) { - self.connection.start(); + self.connection_manager.start(); } - pub fn execute(&mut self, cmd: Command, args: A, evt: Option) -> Result> + fn execute(&mut self, cmd: Command, args: A, evt: Option) -> Result> where A: Serialize + Send + Sync, E: Serialize + DeserializeOwned + Send + Sync { - self.connection.send(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt))?; - let (_, response): (OpCode, Payload) = self.connection.recv()?; + let message = Message::new(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt)); + self.connection_manager.send(message)?; + let Message { payload, .. } = self.connection_manager.recv()?; + let response: Payload = serde_json::from_str(&payload)?; match response.evt { Some(Event::Error) => Err(Error::SubscriptionFailed), diff --git a/src/connection/manager.rs b/src/connection/manager.rs index 8024b1b..d7c2132 100644 --- a/src/connection/manager.rs +++ b/src/connection/manager.rs @@ -1,191 +1,132 @@ use std::{ - thread::{self, JoinHandle}, + thread, sync::{ Arc, - Mutex, - atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering}, - mpsc::{sync_channel, Receiver, SyncSender}, }, time, io::ErrorKind }; - -use serde_json; -use serde::{Serialize, de::DeserializeOwned}; +use crossbeam_channel::{unbounded, Receiver, Sender}; +use parking_lot::Mutex; use super::{ - Connection as BaseConnection, + Connection, SocketConnection, }; -use utils; -use models::{Message, OpCode}; +use models::Message; use error::{Result, Error}; -type MessageQueue = (SyncSender, Receiver); -type Connection = Arc>>; -static CONNECTED: AtomicBool = ATOMIC_BOOL_INIT; -static STARTED: AtomicBool = ATOMIC_BOOL_INIT; -static HANDSHAKED: AtomicBool = ATOMIC_BOOL_INIT; +type Tx = Sender; +type Rx = Receiver; +#[derive(Clone)] pub struct Manager { - send_channel: SyncSender, - recv_channel: Receiver, - _sender: JoinHandle<()>, - _checker: JoinHandle<()>, - _connection: Connection, + connection: Arc>>, + client_id: u64, + outbound: (Rx, Tx), + inbound: (Rx, Tx), + handshake_completed: bool, } impl Manager { - pub fn new(client_id: u64) -> Result { - let send_queue: MessageQueue = sync_channel(20); - let recv_queue: MessageQueue = sync_channel(20); - let conn = Arc::new(Mutex::new(None)); - let send_channel = send_queue.0; - let recv_channel = recv_queue.1; + pub fn new(client_id: u64) -> Self { + let connection = Arc::new(None); + let (sender_o, receiver_o) = unbounded(); + let (sender_i, receiver_i) = unbounded(); - let manager = Self { - send_channel, - recv_channel, - _sender: Self::sender_loop(conn.clone(), (recv_queue.0.clone(), send_queue.1)), - _checker: Self::connection_checker(conn.clone(), client_id), - _connection: conn, - }; - - Ok(manager) + Self { + connection, + client_id, + handshake_completed: false, + inbound: (receiver_i, sender_i), + outbound: (receiver_o, sender_o), + } } pub fn start(&mut self) { - STARTED.store(true, Ordering::SeqCst); - } - - pub fn send(&mut self, opcode: OpCode, payload: S) -> Result<()> - where S: Serialize + Sync + Send - { - let message = Message::new(opcode, payload); - self.send_channel.send(message).unwrap(); - Ok(()) - } - - pub fn recv(&mut self) -> Result<(OpCode, S)> - where S: DeserializeOwned + Send + Sync - { - let message = self.recv_channel.recv_timeout(time::Duration::from_secs(10))?; - let payload = serde_json::from_str(&message.payload)?; - Ok((message.opcode, payload)) - } - - fn connect(connection: Connection) -> Result<()> { - if !CONNECTED.load(Ordering::SeqCst) { - if let Ok(mut conn_lock) = connection.lock() { - if conn_lock.is_some() { - if let Some(ref mut conn) = *conn_lock { - if let Ok(opcode) = Self::ping(conn) { - if opcode == OpCode::Pong { - CONNECTED.store(true, Ordering::SeqCst); - debug!("Reconnected") - } - } - } - } else { - *conn_lock = Some(SocketConnection::connect()?); - CONNECTED.store(true, Ordering::SeqCst); - debug!("Connected") - } - } - } - Ok(()) - } - - fn disconnect(connection: Connection) { - if let Ok(mut conn_lock) = connection.lock() { - if let Some(ref mut conn) = *conn_lock { - if conn.disconnect().is_ok() { - CONNECTED.store(false, Ordering::SeqCst); - HANDSHAKED.store(false, Ordering::SeqCst); - } - } - - *conn_lock = None; - } - } - - fn handshake(connection: Connection, client_id: u64) -> Result<()> { - if CONNECTED.load(Ordering::SeqCst) && !HANDSHAKED.load(Ordering::SeqCst) { - let hs = json![{ - "client_id": client_id.to_string(), - "v": 1, - "nonce": utils::nonce() - }]; - - if let Ok(mut conn_guard) = connection.lock() { - if let Some(ref mut conn) = *conn_guard { - conn.send(Message::new(OpCode::Handshake, hs))?; - let _res = conn.recv()?; - HANDSHAKED.store(true, Ordering::SeqCst); - } - } - } - Ok(()) - } - - fn ping(connection: &mut SocketConnection) -> Result { - let message = Message::new(OpCode::Ping, json![{}]); - connection.send(message)?; - let opcode = connection.recv()?.opcode; - debug!("{:?}", opcode); - Ok(opcode) - } - - fn send_and_receive(connection: &Connection, queue: &MessageQueue) -> Result<()> { - if let Ok(msg) = queue.1.recv() { - if let Ok(mut conn_guard) = connection.lock() { - if let Some(ref mut conn) = *conn_guard { - conn.send(msg)?; - let res = conn.recv()?; - queue.0.send(res).unwrap(); - } - } - }; - Ok(()) - } - - fn connection_checker(connection: Connection, client_id: u64) -> JoinHandle<()> { + let manager_inner = self.clone(); thread::spawn(move || { - debug!("Starting connection checker loop..."); - - loop { - let _ = Self::connect(connection.clone()); - match Self::handshake(connection.clone(), client_id) { - Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => { - debug!("{:?}", err); - }, - Err(err) => debug!("{:?}", err), - Ok(_) => () - }; - - thread::sleep(time::Duration::from_millis(500)); - } - }) + send_and_receive_loop(manager_inner); + }); } - fn sender_loop(connection: Connection, queue: MessageQueue) -> JoinHandle<()> { - thread::spawn(move || { - debug!("Starting sender loop..."); - loop { - if STARTED.load(Ordering::SeqCst) && CONNECTED.load(Ordering::SeqCst) && HANDSHAKED.load(Ordering::SeqCst) { - match Self::send_and_receive(&connection, &queue) { - Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (), - Err(Error::IoError(_err)) => { - Self::disconnect(connection.clone()); - // error!("Disconnected: {}", _err); - }, - Err(why) => error!("{}", why), - Ok(_) => () - } + pub fn send(&self, message: Message) -> Result<()> { + self.outbound.1.send(message).unwrap(); + Ok(()) + } + + pub fn recv(&self) -> Result { + let message = self.inbound.0.recv().unwrap(); + Ok(message) + } + + fn connect(&mut self) -> Result<()> { + if self.connection.is_some() { + return Ok(()); + } + + debug!("Connecting"); + + let mut new_connection = SocketConnection::connect()?; + + debug!("Performing handshake"); + new_connection.handshake(self.client_id)?; + debug!("Handshake completed"); + + self.connection = Arc::new(Some(Mutex::new(new_connection))); + + debug!("Connected"); + + Ok(()) + } + + fn disconnect(&mut self) { + self.handshake_completed = false; + self.connection = Arc::new(None); + } + +} + + +fn send_and_receive_loop(mut manager: Manager) { + debug!("Starting sender loop"); + + let mut inbound = manager.inbound.1.clone(); + let outbound = manager.outbound.0.clone(); + + loop { + let connection = manager.connection.clone(); + + match *connection { + Some(ref conn) => { + let mut connection = conn.lock(); + match send_and_receive(&mut *connection, &mut inbound, &outbound) { + Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (), + Err(Error::IoError(_)) | Err(Error::ConnectionClosed) => manager.disconnect(), + Err(why) => error!("error: {}", why), + _ => (), + } + }, + None => { + match manager.connect() { + Err(why) => error!("Failed to connect: {:?}", why), + _ => manager.handshake_completed = true, } - thread::sleep(time::Duration::from_millis(500)); } - }) + } + + thread::sleep(time::Duration::from_millis(500)); } } + +fn send_and_receive(connection: &mut SocketConnection, inbound: &mut Tx, outbound: &Rx) -> Result<()> { + while let Ok(msg) = outbound.try_recv() { + connection.send(msg).expect("Failed to send outgoing data"); + } + + let msg = connection.recv()?; + inbound.send(msg).expect("Failed to send received data"); + + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index ae69233..f2d3a9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,8 @@ extern crate serde_json; extern crate byteorder; extern crate uuid; extern crate bytes; +extern crate parking_lot; +extern crate crossbeam_channel; #[cfg(windows)] extern crate named_pipe;