diff --git a/src/client.rs b/src/client.rs index 71f6d99..c948975 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,8 @@ +use std::time; + use serde::{Serialize, de::DeserializeOwned}; -use connection::{ - Connection, - SocketConnection, - Manager as ConnectionManager, -}; +use connection::Manager as ConnectionManager; use models::{ OpCode, Command, @@ -12,30 +10,24 @@ use models::{ payload::Payload, commands::{SubscriptionArgs, Subscription}, }; - #[cfg(feature = "rich_presence")] use models::rich_presence::{SetActivityArgs, Activity}; use error::{Result, Error}; -pub struct Client - where T: Connection + Send + Sync + 'static -{ - connection: ConnectionManager, +pub struct Client { + connection: ConnectionManager, } -impl Client - where T: Connection + Send + Sync + 'static -{ - pub fn with_connection(client_id: u64, connection: T) -> Result { +impl Client { + pub fn new(client_id: u64) -> Result { Ok(Self { - connection: ConnectionManager::with_connection(client_id, connection)? + connection: ConnectionManager::new(client_id)? }) } - pub fn start(mut self) -> Result { - self.connection.handshake()?; - Ok(self) + pub fn start(&mut self) { + self.connection.start(); } pub fn execute(&mut self, cmd: Command, args: A, evt: Option) -> Result> @@ -43,7 +35,7 @@ impl Client E: Serialize + DeserializeOwned + Send + Sync { self.connection.send(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt))?; - let response: Payload = self.connection.recv()?; + let (_, response): (OpCode, Payload) = self.connection.recv()?; match response.evt { Some(Event::Error) => Err(Error::SubscriptionFailed), @@ -70,10 +62,3 @@ impl Client self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt)) } } - -impl Client { - pub fn new(client_id: u64) -> Result { - let socket = Connection::connect()?; - Self::with_connection(client_id, socket) - } -} diff --git a/src/connection/base.rs b/src/connection/base.rs index c6cc2f4..595347e 100644 --- a/src/connection/base.rs +++ b/src/connection/base.rs @@ -18,18 +18,20 @@ pub trait Connection: Sized { fn connect() -> Result; + fn disconnect(&self) -> Result<()>; + fn socket_path(n: u8) -> PathBuf { Self::ipc_path().join(format!("discord-ipc-{}", n)) } fn send(&mut self, message: Message) -> Result<()> { - debug!("{:?}", message); match message.encode() { Err(why) => error!("{:?}", why), Ok(bytes) => { self.socket().write_all(bytes.as_ref())?; } }; + debug!("-> {:?}", message); Ok(()) } @@ -38,7 +40,7 @@ pub trait Connection: Sized { let n = self.socket().read(buf.as_mut_slice())?; buf.resize(n, 0); let message = Message::decode(&buf)?; - debug!("{:?}", message); + debug!("<- {:?}", message); Ok(message) } } diff --git a/src/connection/manager.rs b/src/connection/manager.rs index 37e7ba9..25425c0 100644 --- a/src/connection/manager.rs +++ b/src/connection/manager.rs @@ -3,53 +3,61 @@ use std::{ sync::{ Arc, Mutex, - atomic::AtomicBool, + atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering}, mpsc::{sync_channel, Receiver, SyncSender}, }, time, + io::ErrorKind }; use serde_json; use serde::{Serialize, de::DeserializeOwned}; -use super::Connection; +use super::{ + Connection as BaseConnection, + SocketConnection, +}; use utils; -use models::{Message, OpCode, ReadyEvent, payload::Payload}; -use error::Result; +use models::{Message, OpCode}; +use error::{Result, Error}; type MessageQueue = (SyncSender, Receiver); +type Connection = Arc>>; -pub struct Manager - where T: Connection + Send + Sync -{ - client_id: u64, +static CONNECTED: AtomicBool = ATOMIC_BOOL_INIT; +static STARTED: AtomicBool = ATOMIC_BOOL_INIT; +static HANDSHAKED: AtomicBool = ATOMIC_BOOL_INIT; +static HANDSHAKING: AtomicBool = ATOMIC_BOOL_INIT; + +pub struct Manager { send_channel: SyncSender, recv_channel: Receiver, - _version: u32, - _connected: Arc, - _receiver: JoinHandle<()>, _sender: JoinHandle<()>, - _connection: Arc>, + _checker: JoinHandle<()>, + _connection: Connection, } -impl Manager - where T: Connection + Send + Sync + 'static, -{ - pub fn with_connection(client_id: u64, connection: T) -> Result { +impl Manager { + pub fn new(client_id: u64) -> Result { let send_queue: MessageQueue = sync_channel(20); let recv_queue: MessageQueue = sync_channel(20); - let conn = Arc::new(Mutex::new(connection)); + let conn = Arc::new(Mutex::new(None)); + let send_channel = send_queue.0; + let recv_channel = recv_queue.1; - Ok(Self { - client_id, - send_channel: send_queue.0, - recv_channel: recv_queue.1, - _version: 1, - _connected: Arc::new(AtomicBool::new(false)), + let manager = Self { + send_channel, + recv_channel, _sender: Self::sender_loop(conn.clone(), (recv_queue.0.clone(), send_queue.1)), - _receiver: Self::receiver_loop(conn.clone(), recv_queue.0.clone()), + _checker: Self::connection_checker(conn.clone(), client_id), _connection: conn, - }) + }; + + Ok(manager) + } + + pub fn start(&mut self) { + STARTED.store(true, Ordering::SeqCst); } pub fn send(&mut self, opcode: OpCode, payload: S) -> Result<()> @@ -60,51 +68,123 @@ impl Manager Ok(()) } - pub fn recv(&mut self) -> Result + pub fn recv(&mut self) -> Result<(OpCode, S)> where S: DeserializeOwned + Send + Sync { - let message = self.recv_channel.recv().unwrap(); - let payload = serde_json::from_str(&message.payload).unwrap(); - Ok(payload) + let message = self.recv_channel.recv_timeout(time::Duration::from_secs(10))?; + let payload = serde_json::from_str(&message.payload)?; + Ok((message.opcode, payload)) } - pub fn handshake(&mut self) -> Result<()> { - let hs = json![{ - "client_id": self.client_id.to_string(), - "v": 1, - "nonce": utils::nonce() - }]; - self.send(OpCode::Handshake, hs)?; - let _: Payload = self.recv()?; + fn connect(connection: Connection) -> Result<()> { + if !CONNECTED.load(Ordering::SeqCst) { + if let Ok(mut conn_lock) = connection.lock() { + if conn_lock.is_some() { + if let Some(ref mut conn) = *conn_lock { + if let Ok(opcode) = Self::ping(conn) { + if opcode == OpCode::Pong { + CONNECTED.store(true, Ordering::SeqCst); + debug!("Reconnected") + } + } + } + } else { + *conn_lock = Some(SocketConnection::connect()?); + CONNECTED.store(true, Ordering::SeqCst); + debug!("Connected") + } + } + } Ok(()) } - fn sender_loop(connection: Arc>, queue: MessageQueue) -> JoinHandle<()> { + fn disconnect(connection: Connection) { + if let Ok(mut conn_lock) = connection.lock() { + if let Some(ref mut conn) = *conn_lock { + if conn.disconnect().is_ok() { + CONNECTED.store(false, Ordering::SeqCst); + HANDSHAKED.store(false, Ordering::SeqCst); + } + } + + *conn_lock = None; + } + } + + fn handshake(connection: Connection, client_id: u64) -> Result<()> { + if CONNECTED.load(Ordering::SeqCst) && !HANDSHAKED.load(Ordering::SeqCst) && !HANDSHAKING.load(Ordering::SeqCst) { + let hs = json![{ + "client_id": client_id.to_string(), + "v": 1, + "nonce": utils::nonce() + }]; + + if let Ok(mut conn_guard) = connection.lock() { + if let Some(ref mut conn) = *conn_guard { + conn.send(Message::new(OpCode::Handshake, hs))?; + let _res = conn.recv()?; + HANDSHAKED.store(true, Ordering::SeqCst); + } + } + } + Ok(()) + } + + fn ping(connection: &mut SocketConnection) -> Result { + let message = Message::new(OpCode::Ping, json![{}]); + connection.send(message)?; + let opcode = connection.recv()?.opcode; + debug!("{:?}", opcode); + Ok(opcode) + } + + fn send_and_receive(connection: &Connection, queue: &MessageQueue) -> Result<()> { + if let Ok(msg) = queue.1.recv() { + if let Ok(mut conn_guard) = connection.lock() { + if let Some(ref mut conn) = *conn_guard { + conn.send(msg)?; + let res = conn.recv()?; + queue.0.send(res).unwrap(); + } + } + }; + Ok(()) + } + + fn connection_checker(connection: Connection, client_id: u64) -> JoinHandle<()> { thread::spawn(move || { - println!("starting sender loop..."); + debug!("Starting connection checker loop..."); + loop { - if let Ok(msg) = queue.1.recv() { - if let Ok(mut guard) = connection.lock() { - guard.send(msg).unwrap(); - if let Ok(res) = guard.recv() { - queue.0.send(res).unwrap(); - } - } + let _ = Self::connect(connection.clone()); + match Self::handshake(connection.clone(), client_id) { + Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => { + debug!("{:?}", err); + }, + Err(err) => debug!("{:?}", err), + Ok(_) => () }; + thread::sleep(time::Duration::from_millis(500)); } }) } - fn receiver_loop(connection: Arc>, queue: SyncSender) -> JoinHandle<()> { + fn sender_loop(connection: Connection, queue: MessageQueue) -> JoinHandle<()> { thread::spawn(move || { - println!("starting receiver loop..."); + debug!("Starting sender loop..."); loop { - if let Ok(mut guard) = connection.lock() { - if let Ok(msg) = guard.recv() { - queue.send(msg).unwrap(); + if STARTED.load(Ordering::SeqCst) && CONNECTED.load(Ordering::SeqCst) && HANDSHAKED.load(Ordering::SeqCst) { + match Self::send_and_receive(&connection, &queue) { + Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (), + Err(Error::IoError(err)) => { + Self::disconnect(connection.clone()); + // error!("Disconnected: {}", err); + }, + Err(why) => error!("{}", why), + Ok(_) => () } - }; + } thread::sleep(time::Duration::from_millis(500)); } }) diff --git a/src/connection/unix.rs b/src/connection/unix.rs index 9a900ff..121dc9c 100644 --- a/src/connection/unix.rs +++ b/src/connection/unix.rs @@ -3,6 +3,7 @@ use std::{ path::PathBuf, env, os::unix::net::UnixStream, + net::Shutdown, }; use super::base::Connection; @@ -19,12 +20,17 @@ impl Connection for UnixConnection { fn connect() -> Result { let connection_name = Self::socket_path(0); let socket = UnixStream::connect(connection_name)?; - socket.set_nonblocking(true)?; + // socket.set_nonblocking(true)?; socket.set_write_timeout(Some(time::Duration::from_secs(30)))?; socket.set_read_timeout(Some(time::Duration::from_secs(30)))?; Ok(Self { socket }) } + fn disconnect(&self) -> Result<()> { + self.socket.shutdown(Shutdown::Both)?; + Ok(()) + } + fn ipc_path() -> PathBuf { let tmp = env::var("XDG_RUNTIME_DIR") .or_else(|_| env::var("TMPDIR")) diff --git a/src/error.rs b/src/error.rs index 915288d..0ccfc93 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,17 +2,21 @@ use std::{ error::Error as StdError, io::Error as IoError, result::Result as StdResult, + sync::mpsc::RecvTimeoutError as ChannelTimeout, fmt::{ self, Display, Formatter - } + }, }; +use serde_json::Error as JsonError; #[derive(Debug)] pub enum Error { - Io(IoError), + IoError(IoError), + JsonError(JsonError), + Timeout(ChannelTimeout), Conversion, SubscriptionFailed, } @@ -28,14 +32,28 @@ impl StdError for Error { match *self { Error::Conversion => "Failed to convert values", Error::SubscriptionFailed => "Failed to subscribe to event", - Error::Io(ref err) => err.description() + Error::IoError(ref err) => err.description(), + Error::JsonError(ref err) => err.description(), + Error::Timeout(ref err) => err.description(), } } } impl From for Error { fn from(err: IoError) -> Self { - Error::Io(err) + Error::IoError(err) + } +} + +impl From for Error { + fn from(err: JsonError) -> Self { + Error::JsonError(err) + } +} + +impl From for Error { + fn from(err: ChannelTimeout) -> Self { + Error::Timeout(err) } }