From d370cb6432ed6e3095a88f4de0475604ef0ed119 Mon Sep 17 00:00:00 2001 From: Patrick Auernig Date: Fri, 6 Apr 2018 21:51:01 +0200 Subject: [PATCH] Add connection manager --- src/client.rs | 37 ++++--------- src/connection/base.rs | 13 +---- src/connection/manager.rs | 112 ++++++++++++++++++++++++++++++++++++++ src/connection/mod.rs | 2 + src/connection/unix.rs | 1 + 5 files changed, 129 insertions(+), 36 deletions(-) create mode 100644 src/connection/manager.rs diff --git a/src/client.rs b/src/client.rs index cea619c..71f6d99 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ use serde::{Serialize, de::DeserializeOwned}; use connection::{ Connection, SocketConnection, + Manager as ConnectionManager, }; use models::{ OpCode, @@ -15,35 +16,34 @@ use models::{ #[cfg(feature = "rich_presence")] use models::rich_presence::{SetActivityArgs, Activity}; use error::{Result, Error}; -use utils; pub struct Client - where T: Connection + where T: Connection + Send + Sync + 'static { - client_id: u64, - version: u32, - connection: T, + connection: ConnectionManager, } impl Client - where T: Connection + where T: Connection + Send + Sync + 'static { pub fn with_connection(client_id: u64, connection: T) -> Result { - Ok(Self { version: 1, client_id, connection }) + Ok(Self { + connection: ConnectionManager::with_connection(client_id, connection)? + }) } pub fn start(mut self) -> Result { - self.handshake()?; + self.connection.handshake()?; Ok(self) } pub fn execute(&mut self, cmd: Command, args: A, evt: Option) -> Result> - where A: Serialize, - E: Serialize + DeserializeOwned + where A: Serialize + Send + Sync, + E: Serialize + DeserializeOwned + Send + Sync { self.connection.send(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt))?; - let response: Payload = self.connection.recv()?.into(); + let response: Payload = self.connection.recv()?; match response.evt { Some(Event::Error) => Err(Error::SubscriptionFailed), @@ -69,21 +69,6 @@ impl Client { self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt)) } - -// private - - fn handshake(&mut self) -> Result<()> { - let client_id = self.client_id; - let version = self.version; - let hs = json![{ - "client_id": client_id.to_string(), - "v": version, - "nonce": utils::nonce() - }]; - self.connection.send(OpCode::Handshake, hs)?; - self.connection.recv()?; - Ok(()) - } } impl Client { diff --git a/src/connection/base.rs b/src/connection/base.rs index 8643731..c6cc2f4 100644 --- a/src/connection/base.rs +++ b/src/connection/base.rs @@ -4,15 +4,11 @@ use std::{ path::PathBuf, }; -use serde::Serialize; - -use models::message::{Message, OpCode}; +use models::message::Message; use error::Result; -pub trait Connection - where Self: Sized -{ +pub trait Connection: Sized { type Socket: Write + Read; @@ -26,10 +22,7 @@ pub trait Connection Self::ipc_path().join(format!("discord-ipc-{}", n)) } - fn send(&mut self, opcode: OpCode, payload: T) -> Result<()> - where T: Serialize - { - let message = Message::new(opcode, payload); + fn send(&mut self, message: Message) -> Result<()> { debug!("{:?}", message); match message.encode() { Err(why) => error!("{:?}", why), diff --git a/src/connection/manager.rs b/src/connection/manager.rs new file mode 100644 index 0000000..37e7ba9 --- /dev/null +++ b/src/connection/manager.rs @@ -0,0 +1,112 @@ +use std::{ + thread::{self, JoinHandle}, + sync::{ + Arc, + Mutex, + atomic::AtomicBool, + mpsc::{sync_channel, Receiver, SyncSender}, + }, + time, +}; + +use serde_json; +use serde::{Serialize, de::DeserializeOwned}; + +use super::Connection; +use utils; +use models::{Message, OpCode, ReadyEvent, payload::Payload}; +use error::Result; + +type MessageQueue = (SyncSender, Receiver); + +pub struct Manager + where T: Connection + Send + Sync +{ + client_id: u64, + send_channel: SyncSender, + recv_channel: Receiver, + _version: u32, + _connected: Arc, + _receiver: JoinHandle<()>, + _sender: JoinHandle<()>, + _connection: Arc>, +} + +impl Manager + where T: Connection + Send + Sync + 'static, +{ + pub fn with_connection(client_id: u64, connection: T) -> Result { + let send_queue: MessageQueue = sync_channel(20); + let recv_queue: MessageQueue = sync_channel(20); + let conn = Arc::new(Mutex::new(connection)); + + Ok(Self { + client_id, + send_channel: send_queue.0, + recv_channel: recv_queue.1, + _version: 1, + _connected: Arc::new(AtomicBool::new(false)), + _sender: Self::sender_loop(conn.clone(), (recv_queue.0.clone(), send_queue.1)), + _receiver: Self::receiver_loop(conn.clone(), recv_queue.0.clone()), + _connection: conn, + }) + } + + pub fn send(&mut self, opcode: OpCode, payload: S) -> Result<()> + where S: Serialize + Sync + Send + { + let message = Message::new(opcode, payload); + self.send_channel.send(message).unwrap(); + Ok(()) + } + + pub fn recv(&mut self) -> Result + where S: DeserializeOwned + Send + Sync + { + let message = self.recv_channel.recv().unwrap(); + let payload = serde_json::from_str(&message.payload).unwrap(); + Ok(payload) + } + + pub fn handshake(&mut self) -> Result<()> { + let hs = json![{ + "client_id": self.client_id.to_string(), + "v": 1, + "nonce": utils::nonce() + }]; + self.send(OpCode::Handshake, hs)?; + let _: Payload = self.recv()?; + Ok(()) + } + + fn sender_loop(connection: Arc>, queue: MessageQueue) -> JoinHandle<()> { + thread::spawn(move || { + println!("starting sender loop..."); + loop { + if let Ok(msg) = queue.1.recv() { + if let Ok(mut guard) = connection.lock() { + guard.send(msg).unwrap(); + if let Ok(res) = guard.recv() { + queue.0.send(res).unwrap(); + } + } + }; + thread::sleep(time::Duration::from_millis(500)); + } + }) + } + + fn receiver_loop(connection: Arc>, queue: SyncSender) -> JoinHandle<()> { + thread::spawn(move || { + println!("starting receiver loop..."); + loop { + if let Ok(mut guard) = connection.lock() { + if let Ok(msg) = guard.recv() { + queue.send(msg).unwrap(); + } + }; + thread::sleep(time::Duration::from_millis(500)); + } + }) + } +} diff --git a/src/connection/mod.rs b/src/connection/mod.rs index f76e59b..dbdad2f 100644 --- a/src/connection/mod.rs +++ b/src/connection/mod.rs @@ -1,10 +1,12 @@ mod base; +mod manager; #[cfg(unix)] mod unix; #[cfg(windows)] mod windows; pub use self::base::Connection as Connection; +pub use self::manager::Manager; #[cfg(unix)] pub use self::unix::UnixConnection as SocketConnection; #[cfg(windows)] diff --git a/src/connection/unix.rs b/src/connection/unix.rs index a7753d2..9a900ff 100644 --- a/src/connection/unix.rs +++ b/src/connection/unix.rs @@ -19,6 +19,7 @@ impl Connection for UnixConnection { fn connect() -> Result { let connection_name = Self::socket_path(0); let socket = UnixStream::connect(connection_name)?; + socket.set_nonblocking(true)?; socket.set_write_timeout(Some(time::Duration::from_secs(30)))?; socket.set_read_timeout(Some(time::Duration::from_secs(30)))?; Ok(Self { socket })