Implement reconnection logic

Still needs some changes to handle
* Retrying message send
* Resubscribing current event subscriptions
* Make non-blocking again
This commit is contained in:
Patrick Auernig 2018-04-07 13:22:50 +02:00
parent d370cb6432
commit a1e77c9c35
5 changed files with 176 additions and 85 deletions

View File

@ -1,10 +1,8 @@
use std::time;
use serde::{Serialize, de::DeserializeOwned};
use connection::{
Connection,
SocketConnection,
Manager as ConnectionManager,
};
use connection::Manager as ConnectionManager;
use models::{
OpCode,
Command,
@ -12,30 +10,24 @@ use models::{
payload::Payload,
commands::{SubscriptionArgs, Subscription},
};
#[cfg(feature = "rich_presence")]
use models::rich_presence::{SetActivityArgs, Activity};
use error::{Result, Error};
pub struct Client<T>
where T: Connection + Send + Sync + 'static
{
connection: ConnectionManager<T>,
pub struct Client {
connection: ConnectionManager,
}
impl<T> Client<T>
where T: Connection + Send + Sync + 'static
{
pub fn with_connection(client_id: u64, connection: T) -> Result<Self> {
impl Client {
pub fn new(client_id: u64) -> Result<Self> {
Ok(Self {
connection: ConnectionManager::with_connection(client_id, connection)?
connection: ConnectionManager::new(client_id)?
})
}
pub fn start(mut self) -> Result<Self> {
self.connection.handshake()?;
Ok(self)
pub fn start(&mut self) {
self.connection.start();
}
pub fn execute<A, E>(&mut self, cmd: Command, args: A, evt: Option<Event>) -> Result<Payload<E>>
@ -43,7 +35,7 @@ impl<T> Client<T>
E: Serialize + DeserializeOwned + Send + Sync
{
self.connection.send(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt))?;
let response: Payload<E> = self.connection.recv()?;
let (_, response): (OpCode, Payload<E>) = self.connection.recv()?;
match response.evt {
Some(Event::Error) => Err(Error::SubscriptionFailed),
@ -70,10 +62,3 @@ impl<T> Client<T>
self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt))
}
}
impl Client<SocketConnection> {
pub fn new(client_id: u64) -> Result<Self> {
let socket = Connection::connect()?;
Self::with_connection(client_id, socket)
}
}

View File

@ -18,18 +18,20 @@ pub trait Connection: Sized {
fn connect() -> Result<Self>;
fn disconnect(&self) -> Result<()>;
fn socket_path(n: u8) -> PathBuf {
Self::ipc_path().join(format!("discord-ipc-{}", n))
}
fn send(&mut self, message: Message) -> Result<()> {
debug!("{:?}", message);
match message.encode() {
Err(why) => error!("{:?}", why),
Ok(bytes) => {
self.socket().write_all(bytes.as_ref())?;
}
};
debug!("-> {:?}", message);
Ok(())
}
@ -38,7 +40,7 @@ pub trait Connection: Sized {
let n = self.socket().read(buf.as_mut_slice())?;
buf.resize(n, 0);
let message = Message::decode(&buf)?;
debug!("{:?}", message);
debug!("<- {:?}", message);
Ok(message)
}
}

View File

@ -3,53 +3,61 @@ use std::{
sync::{
Arc,
Mutex,
atomic::AtomicBool,
atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering},
mpsc::{sync_channel, Receiver, SyncSender},
},
time,
io::ErrorKind
};
use serde_json;
use serde::{Serialize, de::DeserializeOwned};
use super::Connection;
use super::{
Connection as BaseConnection,
SocketConnection,
};
use utils;
use models::{Message, OpCode, ReadyEvent, payload::Payload};
use error::Result;
use models::{Message, OpCode};
use error::{Result, Error};
type MessageQueue = (SyncSender<Message>, Receiver<Message>);
type Connection = Arc<Mutex<Option<SocketConnection>>>;
pub struct Manager<T>
where T: Connection + Send + Sync
{
client_id: u64,
static CONNECTED: AtomicBool = ATOMIC_BOOL_INIT;
static STARTED: AtomicBool = ATOMIC_BOOL_INIT;
static HANDSHAKED: AtomicBool = ATOMIC_BOOL_INIT;
static HANDSHAKING: AtomicBool = ATOMIC_BOOL_INIT;
pub struct Manager {
send_channel: SyncSender<Message>,
recv_channel: Receiver<Message>,
_version: u32,
_connected: Arc<AtomicBool>,
_receiver: JoinHandle<()>,
_sender: JoinHandle<()>,
_connection: Arc<Mutex<T>>,
_checker: JoinHandle<()>,
_connection: Connection,
}
impl<T> Manager<T>
where T: Connection + Send + Sync + 'static,
{
pub fn with_connection(client_id: u64, connection: T) -> Result<Self> {
impl Manager {
pub fn new(client_id: u64) -> Result<Self> {
let send_queue: MessageQueue = sync_channel(20);
let recv_queue: MessageQueue = sync_channel(20);
let conn = Arc::new(Mutex::new(connection));
let conn = Arc::new(Mutex::new(None));
let send_channel = send_queue.0;
let recv_channel = recv_queue.1;
Ok(Self {
client_id,
send_channel: send_queue.0,
recv_channel: recv_queue.1,
_version: 1,
_connected: Arc::new(AtomicBool::new(false)),
let manager = Self {
send_channel,
recv_channel,
_sender: Self::sender_loop(conn.clone(), (recv_queue.0.clone(), send_queue.1)),
_receiver: Self::receiver_loop(conn.clone(), recv_queue.0.clone()),
_checker: Self::connection_checker(conn.clone(), client_id),
_connection: conn,
})
};
Ok(manager)
}
pub fn start(&mut self) {
STARTED.store(true, Ordering::SeqCst);
}
pub fn send<S>(&mut self, opcode: OpCode, payload: S) -> Result<()>
@ -60,51 +68,123 @@ impl<T> Manager<T>
Ok(())
}
pub fn recv<S>(&mut self) -> Result<S>
pub fn recv<S>(&mut self) -> Result<(OpCode, S)>
where S: DeserializeOwned + Send + Sync
{
let message = self.recv_channel.recv().unwrap();
let payload = serde_json::from_str(&message.payload).unwrap();
Ok(payload)
let message = self.recv_channel.recv_timeout(time::Duration::from_secs(10))?;
let payload = serde_json::from_str(&message.payload)?;
Ok((message.opcode, payload))
}
pub fn handshake(&mut self) -> Result<()> {
let hs = json![{
"client_id": self.client_id.to_string(),
"v": 1,
"nonce": utils::nonce()
}];
self.send(OpCode::Handshake, hs)?;
let _: Payload<ReadyEvent> = self.recv()?;
fn connect(connection: Connection) -> Result<()> {
if !CONNECTED.load(Ordering::SeqCst) {
if let Ok(mut conn_lock) = connection.lock() {
if conn_lock.is_some() {
if let Some(ref mut conn) = *conn_lock {
if let Ok(opcode) = Self::ping(conn) {
if opcode == OpCode::Pong {
CONNECTED.store(true, Ordering::SeqCst);
debug!("Reconnected")
}
}
}
} else {
*conn_lock = Some(SocketConnection::connect()?);
CONNECTED.store(true, Ordering::SeqCst);
debug!("Connected")
}
}
}
Ok(())
}
fn sender_loop(connection: Arc<Mutex<T>>, queue: MessageQueue) -> JoinHandle<()> {
fn disconnect(connection: Connection) {
if let Ok(mut conn_lock) = connection.lock() {
if let Some(ref mut conn) = *conn_lock {
if conn.disconnect().is_ok() {
CONNECTED.store(false, Ordering::SeqCst);
HANDSHAKED.store(false, Ordering::SeqCst);
}
}
*conn_lock = None;
}
}
fn handshake(connection: Connection, client_id: u64) -> Result<()> {
if CONNECTED.load(Ordering::SeqCst) && !HANDSHAKED.load(Ordering::SeqCst) && !HANDSHAKING.load(Ordering::SeqCst) {
let hs = json![{
"client_id": client_id.to_string(),
"v": 1,
"nonce": utils::nonce()
}];
if let Ok(mut conn_guard) = connection.lock() {
if let Some(ref mut conn) = *conn_guard {
conn.send(Message::new(OpCode::Handshake, hs))?;
let _res = conn.recv()?;
HANDSHAKED.store(true, Ordering::SeqCst);
}
}
}
Ok(())
}
fn ping(connection: &mut SocketConnection) -> Result<OpCode> {
let message = Message::new(OpCode::Ping, json![{}]);
connection.send(message)?;
let opcode = connection.recv()?.opcode;
debug!("{:?}", opcode);
Ok(opcode)
}
fn send_and_receive(connection: &Connection, queue: &MessageQueue) -> Result<()> {
if let Ok(msg) = queue.1.recv() {
if let Ok(mut conn_guard) = connection.lock() {
if let Some(ref mut conn) = *conn_guard {
conn.send(msg)?;
let res = conn.recv()?;
queue.0.send(res).unwrap();
}
}
};
Ok(())
}
fn connection_checker(connection: Connection, client_id: u64) -> JoinHandle<()> {
thread::spawn(move || {
println!("starting sender loop...");
debug!("Starting connection checker loop...");
loop {
if let Ok(msg) = queue.1.recv() {
if let Ok(mut guard) = connection.lock() {
guard.send(msg).unwrap();
if let Ok(res) = guard.recv() {
queue.0.send(res).unwrap();
}
}
let _ = Self::connect(connection.clone());
match Self::handshake(connection.clone(), client_id) {
Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => {
debug!("{:?}", err);
},
Err(err) => debug!("{:?}", err),
Ok(_) => ()
};
thread::sleep(time::Duration::from_millis(500));
}
})
}
fn receiver_loop(connection: Arc<Mutex<T>>, queue: SyncSender<Message>) -> JoinHandle<()> {
fn sender_loop(connection: Connection, queue: MessageQueue) -> JoinHandle<()> {
thread::spawn(move || {
println!("starting receiver loop...");
debug!("Starting sender loop...");
loop {
if let Ok(mut guard) = connection.lock() {
if let Ok(msg) = guard.recv() {
queue.send(msg).unwrap();
if STARTED.load(Ordering::SeqCst) && CONNECTED.load(Ordering::SeqCst) && HANDSHAKED.load(Ordering::SeqCst) {
match Self::send_and_receive(&connection, &queue) {
Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (),
Err(Error::IoError(err)) => {
Self::disconnect(connection.clone());
// error!("Disconnected: {}", err);
},
Err(why) => error!("{}", why),
Ok(_) => ()
}
};
}
thread::sleep(time::Duration::from_millis(500));
}
})

View File

@ -3,6 +3,7 @@ use std::{
path::PathBuf,
env,
os::unix::net::UnixStream,
net::Shutdown,
};
use super::base::Connection;
@ -19,12 +20,17 @@ impl Connection for UnixConnection {
fn connect() -> Result<Self> {
let connection_name = Self::socket_path(0);
let socket = UnixStream::connect(connection_name)?;
socket.set_nonblocking(true)?;
// socket.set_nonblocking(true)?;
socket.set_write_timeout(Some(time::Duration::from_secs(30)))?;
socket.set_read_timeout(Some(time::Duration::from_secs(30)))?;
Ok(Self { socket })
}
fn disconnect(&self) -> Result<()> {
self.socket.shutdown(Shutdown::Both)?;
Ok(())
}
fn ipc_path() -> PathBuf {
let tmp = env::var("XDG_RUNTIME_DIR")
.or_else(|_| env::var("TMPDIR"))

View File

@ -2,17 +2,21 @@ use std::{
error::Error as StdError,
io::Error as IoError,
result::Result as StdResult,
sync::mpsc::RecvTimeoutError as ChannelTimeout,
fmt::{
self,
Display,
Formatter
}
},
};
use serde_json::Error as JsonError;
#[derive(Debug)]
pub enum Error {
Io(IoError),
IoError(IoError),
JsonError(JsonError),
Timeout(ChannelTimeout),
Conversion,
SubscriptionFailed,
}
@ -28,14 +32,28 @@ impl StdError for Error {
match *self {
Error::Conversion => "Failed to convert values",
Error::SubscriptionFailed => "Failed to subscribe to event",
Error::Io(ref err) => err.description()
Error::IoError(ref err) => err.description(),
Error::JsonError(ref err) => err.description(),
Error::Timeout(ref err) => err.description(),
}
}
}
impl From<IoError> for Error {
fn from(err: IoError) -> Self {
Error::Io(err)
Error::IoError(err)
}
}
impl From<JsonError> for Error {
fn from(err: JsonError) -> Self {
Error::JsonError(err)
}
}
impl From<ChannelTimeout> for Error {
fn from(err: ChannelTimeout) -> Self {
Error::Timeout(err)
}
}