Implement reconnection logic
Still needs some changes to handle * Retrying message send * Resubscribing current event subscriptions * Make non-blocking again
This commit is contained in:
parent
d370cb6432
commit
a1e77c9c35
|
@ -1,10 +1,8 @@
|
|||
use std::time;
|
||||
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
use connection::{
|
||||
Connection,
|
||||
SocketConnection,
|
||||
Manager as ConnectionManager,
|
||||
};
|
||||
use connection::Manager as ConnectionManager;
|
||||
use models::{
|
||||
OpCode,
|
||||
Command,
|
||||
|
@ -12,30 +10,24 @@ use models::{
|
|||
payload::Payload,
|
||||
commands::{SubscriptionArgs, Subscription},
|
||||
};
|
||||
|
||||
#[cfg(feature = "rich_presence")]
|
||||
use models::rich_presence::{SetActivityArgs, Activity};
|
||||
use error::{Result, Error};
|
||||
|
||||
|
||||
pub struct Client<T>
|
||||
where T: Connection + Send + Sync + 'static
|
||||
{
|
||||
connection: ConnectionManager<T>,
|
||||
pub struct Client {
|
||||
connection: ConnectionManager,
|
||||
}
|
||||
|
||||
impl<T> Client<T>
|
||||
where T: Connection + Send + Sync + 'static
|
||||
{
|
||||
pub fn with_connection(client_id: u64, connection: T) -> Result<Self> {
|
||||
impl Client {
|
||||
pub fn new(client_id: u64) -> Result<Self> {
|
||||
Ok(Self {
|
||||
connection: ConnectionManager::with_connection(client_id, connection)?
|
||||
connection: ConnectionManager::new(client_id)?
|
||||
})
|
||||
}
|
||||
|
||||
pub fn start(mut self) -> Result<Self> {
|
||||
self.connection.handshake()?;
|
||||
Ok(self)
|
||||
pub fn start(&mut self) {
|
||||
self.connection.start();
|
||||
}
|
||||
|
||||
pub fn execute<A, E>(&mut self, cmd: Command, args: A, evt: Option<Event>) -> Result<Payload<E>>
|
||||
|
@ -43,7 +35,7 @@ impl<T> Client<T>
|
|||
E: Serialize + DeserializeOwned + Send + Sync
|
||||
{
|
||||
self.connection.send(OpCode::Frame, Payload::with_nonce(cmd, Some(args), None, evt))?;
|
||||
let response: Payload<E> = self.connection.recv()?;
|
||||
let (_, response): (OpCode, Payload<E>) = self.connection.recv()?;
|
||||
|
||||
match response.evt {
|
||||
Some(Event::Error) => Err(Error::SubscriptionFailed),
|
||||
|
@ -70,10 +62,3 @@ impl<T> Client<T>
|
|||
self.execute(Command::Unsubscribe, f(SubscriptionArgs::new()), Some(evt))
|
||||
}
|
||||
}
|
||||
|
||||
impl Client<SocketConnection> {
|
||||
pub fn new(client_id: u64) -> Result<Self> {
|
||||
let socket = Connection::connect()?;
|
||||
Self::with_connection(client_id, socket)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,18 +18,20 @@ pub trait Connection: Sized {
|
|||
|
||||
fn connect() -> Result<Self>;
|
||||
|
||||
fn disconnect(&self) -> Result<()>;
|
||||
|
||||
fn socket_path(n: u8) -> PathBuf {
|
||||
Self::ipc_path().join(format!("discord-ipc-{}", n))
|
||||
}
|
||||
|
||||
fn send(&mut self, message: Message) -> Result<()> {
|
||||
debug!("{:?}", message);
|
||||
match message.encode() {
|
||||
Err(why) => error!("{:?}", why),
|
||||
Ok(bytes) => {
|
||||
self.socket().write_all(bytes.as_ref())?;
|
||||
}
|
||||
};
|
||||
debug!("-> {:?}", message);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -38,7 +40,7 @@ pub trait Connection: Sized {
|
|||
let n = self.socket().read(buf.as_mut_slice())?;
|
||||
buf.resize(n, 0);
|
||||
let message = Message::decode(&buf)?;
|
||||
debug!("{:?}", message);
|
||||
debug!("<- {:?}", message);
|
||||
Ok(message)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,53 +3,61 @@ use std::{
|
|||
sync::{
|
||||
Arc,
|
||||
Mutex,
|
||||
atomic::AtomicBool,
|
||||
atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering},
|
||||
mpsc::{sync_channel, Receiver, SyncSender},
|
||||
},
|
||||
time,
|
||||
io::ErrorKind
|
||||
};
|
||||
|
||||
use serde_json;
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
|
||||
use super::Connection;
|
||||
use super::{
|
||||
Connection as BaseConnection,
|
||||
SocketConnection,
|
||||
};
|
||||
use utils;
|
||||
use models::{Message, OpCode, ReadyEvent, payload::Payload};
|
||||
use error::Result;
|
||||
use models::{Message, OpCode};
|
||||
use error::{Result, Error};
|
||||
|
||||
type MessageQueue = (SyncSender<Message>, Receiver<Message>);
|
||||
type Connection = Arc<Mutex<Option<SocketConnection>>>;
|
||||
|
||||
pub struct Manager<T>
|
||||
where T: Connection + Send + Sync
|
||||
{
|
||||
client_id: u64,
|
||||
static CONNECTED: AtomicBool = ATOMIC_BOOL_INIT;
|
||||
static STARTED: AtomicBool = ATOMIC_BOOL_INIT;
|
||||
static HANDSHAKED: AtomicBool = ATOMIC_BOOL_INIT;
|
||||
static HANDSHAKING: AtomicBool = ATOMIC_BOOL_INIT;
|
||||
|
||||
pub struct Manager {
|
||||
send_channel: SyncSender<Message>,
|
||||
recv_channel: Receiver<Message>,
|
||||
_version: u32,
|
||||
_connected: Arc<AtomicBool>,
|
||||
_receiver: JoinHandle<()>,
|
||||
_sender: JoinHandle<()>,
|
||||
_connection: Arc<Mutex<T>>,
|
||||
_checker: JoinHandle<()>,
|
||||
_connection: Connection,
|
||||
}
|
||||
|
||||
impl<T> Manager<T>
|
||||
where T: Connection + Send + Sync + 'static,
|
||||
{
|
||||
pub fn with_connection(client_id: u64, connection: T) -> Result<Self> {
|
||||
impl Manager {
|
||||
pub fn new(client_id: u64) -> Result<Self> {
|
||||
let send_queue: MessageQueue = sync_channel(20);
|
||||
let recv_queue: MessageQueue = sync_channel(20);
|
||||
let conn = Arc::new(Mutex::new(connection));
|
||||
let conn = Arc::new(Mutex::new(None));
|
||||
let send_channel = send_queue.0;
|
||||
let recv_channel = recv_queue.1;
|
||||
|
||||
Ok(Self {
|
||||
client_id,
|
||||
send_channel: send_queue.0,
|
||||
recv_channel: recv_queue.1,
|
||||
_version: 1,
|
||||
_connected: Arc::new(AtomicBool::new(false)),
|
||||
let manager = Self {
|
||||
send_channel,
|
||||
recv_channel,
|
||||
_sender: Self::sender_loop(conn.clone(), (recv_queue.0.clone(), send_queue.1)),
|
||||
_receiver: Self::receiver_loop(conn.clone(), recv_queue.0.clone()),
|
||||
_checker: Self::connection_checker(conn.clone(), client_id),
|
||||
_connection: conn,
|
||||
})
|
||||
};
|
||||
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
pub fn start(&mut self) {
|
||||
STARTED.store(true, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
pub fn send<S>(&mut self, opcode: OpCode, payload: S) -> Result<()>
|
||||
|
@ -60,51 +68,123 @@ impl<T> Manager<T>
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn recv<S>(&mut self) -> Result<S>
|
||||
pub fn recv<S>(&mut self) -> Result<(OpCode, S)>
|
||||
where S: DeserializeOwned + Send + Sync
|
||||
{
|
||||
let message = self.recv_channel.recv().unwrap();
|
||||
let payload = serde_json::from_str(&message.payload).unwrap();
|
||||
Ok(payload)
|
||||
let message = self.recv_channel.recv_timeout(time::Duration::from_secs(10))?;
|
||||
let payload = serde_json::from_str(&message.payload)?;
|
||||
Ok((message.opcode, payload))
|
||||
}
|
||||
|
||||
pub fn handshake(&mut self) -> Result<()> {
|
||||
let hs = json![{
|
||||
"client_id": self.client_id.to_string(),
|
||||
"v": 1,
|
||||
"nonce": utils::nonce()
|
||||
}];
|
||||
self.send(OpCode::Handshake, hs)?;
|
||||
let _: Payload<ReadyEvent> = self.recv()?;
|
||||
fn connect(connection: Connection) -> Result<()> {
|
||||
if !CONNECTED.load(Ordering::SeqCst) {
|
||||
if let Ok(mut conn_lock) = connection.lock() {
|
||||
if conn_lock.is_some() {
|
||||
if let Some(ref mut conn) = *conn_lock {
|
||||
if let Ok(opcode) = Self::ping(conn) {
|
||||
if opcode == OpCode::Pong {
|
||||
CONNECTED.store(true, Ordering::SeqCst);
|
||||
debug!("Reconnected")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*conn_lock = Some(SocketConnection::connect()?);
|
||||
CONNECTED.store(true, Ordering::SeqCst);
|
||||
debug!("Connected")
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn sender_loop(connection: Arc<Mutex<T>>, queue: MessageQueue) -> JoinHandle<()> {
|
||||
fn disconnect(connection: Connection) {
|
||||
if let Ok(mut conn_lock) = connection.lock() {
|
||||
if let Some(ref mut conn) = *conn_lock {
|
||||
if conn.disconnect().is_ok() {
|
||||
CONNECTED.store(false, Ordering::SeqCst);
|
||||
HANDSHAKED.store(false, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
|
||||
*conn_lock = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn handshake(connection: Connection, client_id: u64) -> Result<()> {
|
||||
if CONNECTED.load(Ordering::SeqCst) && !HANDSHAKED.load(Ordering::SeqCst) && !HANDSHAKING.load(Ordering::SeqCst) {
|
||||
let hs = json![{
|
||||
"client_id": client_id.to_string(),
|
||||
"v": 1,
|
||||
"nonce": utils::nonce()
|
||||
}];
|
||||
|
||||
if let Ok(mut conn_guard) = connection.lock() {
|
||||
if let Some(ref mut conn) = *conn_guard {
|
||||
conn.send(Message::new(OpCode::Handshake, hs))?;
|
||||
let _res = conn.recv()?;
|
||||
HANDSHAKED.store(true, Ordering::SeqCst);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ping(connection: &mut SocketConnection) -> Result<OpCode> {
|
||||
let message = Message::new(OpCode::Ping, json![{}]);
|
||||
connection.send(message)?;
|
||||
let opcode = connection.recv()?.opcode;
|
||||
debug!("{:?}", opcode);
|
||||
Ok(opcode)
|
||||
}
|
||||
|
||||
fn send_and_receive(connection: &Connection, queue: &MessageQueue) -> Result<()> {
|
||||
if let Ok(msg) = queue.1.recv() {
|
||||
if let Ok(mut conn_guard) = connection.lock() {
|
||||
if let Some(ref mut conn) = *conn_guard {
|
||||
conn.send(msg)?;
|
||||
let res = conn.recv()?;
|
||||
queue.0.send(res).unwrap();
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn connection_checker(connection: Connection, client_id: u64) -> JoinHandle<()> {
|
||||
thread::spawn(move || {
|
||||
println!("starting sender loop...");
|
||||
debug!("Starting connection checker loop...");
|
||||
|
||||
loop {
|
||||
if let Ok(msg) = queue.1.recv() {
|
||||
if let Ok(mut guard) = connection.lock() {
|
||||
guard.send(msg).unwrap();
|
||||
if let Ok(res) = guard.recv() {
|
||||
queue.0.send(res).unwrap();
|
||||
}
|
||||
}
|
||||
let _ = Self::connect(connection.clone());
|
||||
match Self::handshake(connection.clone(), client_id) {
|
||||
Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => {
|
||||
debug!("{:?}", err);
|
||||
},
|
||||
Err(err) => debug!("{:?}", err),
|
||||
Ok(_) => ()
|
||||
};
|
||||
|
||||
thread::sleep(time::Duration::from_millis(500));
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn receiver_loop(connection: Arc<Mutex<T>>, queue: SyncSender<Message>) -> JoinHandle<()> {
|
||||
fn sender_loop(connection: Connection, queue: MessageQueue) -> JoinHandle<()> {
|
||||
thread::spawn(move || {
|
||||
println!("starting receiver loop...");
|
||||
debug!("Starting sender loop...");
|
||||
loop {
|
||||
if let Ok(mut guard) = connection.lock() {
|
||||
if let Ok(msg) = guard.recv() {
|
||||
queue.send(msg).unwrap();
|
||||
if STARTED.load(Ordering::SeqCst) && CONNECTED.load(Ordering::SeqCst) && HANDSHAKED.load(Ordering::SeqCst) {
|
||||
match Self::send_and_receive(&connection, &queue) {
|
||||
Err(Error::IoError(ref err)) if err.kind() == ErrorKind::WouldBlock => (),
|
||||
Err(Error::IoError(err)) => {
|
||||
Self::disconnect(connection.clone());
|
||||
// error!("Disconnected: {}", err);
|
||||
},
|
||||
Err(why) => error!("{}", why),
|
||||
Ok(_) => ()
|
||||
}
|
||||
};
|
||||
}
|
||||
thread::sleep(time::Duration::from_millis(500));
|
||||
}
|
||||
})
|
||||
|
|
|
@ -3,6 +3,7 @@ use std::{
|
|||
path::PathBuf,
|
||||
env,
|
||||
os::unix::net::UnixStream,
|
||||
net::Shutdown,
|
||||
};
|
||||
|
||||
use super::base::Connection;
|
||||
|
@ -19,12 +20,17 @@ impl Connection for UnixConnection {
|
|||
fn connect() -> Result<Self> {
|
||||
let connection_name = Self::socket_path(0);
|
||||
let socket = UnixStream::connect(connection_name)?;
|
||||
socket.set_nonblocking(true)?;
|
||||
// socket.set_nonblocking(true)?;
|
||||
socket.set_write_timeout(Some(time::Duration::from_secs(30)))?;
|
||||
socket.set_read_timeout(Some(time::Duration::from_secs(30)))?;
|
||||
Ok(Self { socket })
|
||||
}
|
||||
|
||||
fn disconnect(&self) -> Result<()> {
|
||||
self.socket.shutdown(Shutdown::Both)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ipc_path() -> PathBuf {
|
||||
let tmp = env::var("XDG_RUNTIME_DIR")
|
||||
.or_else(|_| env::var("TMPDIR"))
|
||||
|
|
26
src/error.rs
26
src/error.rs
|
@ -2,17 +2,21 @@ use std::{
|
|||
error::Error as StdError,
|
||||
io::Error as IoError,
|
||||
result::Result as StdResult,
|
||||
sync::mpsc::RecvTimeoutError as ChannelTimeout,
|
||||
fmt::{
|
||||
self,
|
||||
Display,
|
||||
Formatter
|
||||
}
|
||||
},
|
||||
};
|
||||
use serde_json::Error as JsonError;
|
||||
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Io(IoError),
|
||||
IoError(IoError),
|
||||
JsonError(JsonError),
|
||||
Timeout(ChannelTimeout),
|
||||
Conversion,
|
||||
SubscriptionFailed,
|
||||
}
|
||||
|
@ -28,14 +32,28 @@ impl StdError for Error {
|
|||
match *self {
|
||||
Error::Conversion => "Failed to convert values",
|
||||
Error::SubscriptionFailed => "Failed to subscribe to event",
|
||||
Error::Io(ref err) => err.description()
|
||||
Error::IoError(ref err) => err.description(),
|
||||
Error::JsonError(ref err) => err.description(),
|
||||
Error::Timeout(ref err) => err.description(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IoError> for Error {
|
||||
fn from(err: IoError) -> Self {
|
||||
Error::Io(err)
|
||||
Error::IoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<JsonError> for Error {
|
||||
fn from(err: JsonError) -> Self {
|
||||
Error::JsonError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ChannelTimeout> for Error {
|
||||
fn from(err: ChannelTimeout) -> Self {
|
||||
Error::Timeout(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue