// baubot_server/lib.rs
//! [BauServer] is a TCP server protocol for [BauBot]. Because fuck HTTP. Protect the connection
//! yourself, ideally through dockerized containers or fork a TLS wrapper I'll be happy to merge.
//!
//! Transactions should follow this pattern:
//! - [BauClient] opens a connection to [BauServer]
//! - [BauClient] sends a payload that is serializable into a [BauMessage]. [BauClient] itself
//! checks the payload to make sure it can be transmitted.
//! - [BauClient] rejects the request if it cannot be correctly serialized.
//! - [BauServer] constructs a [BauMessage] and sends that to [BauBot]
//! - [BauServer] rejects the request if it cannot be correctly de-serialized.
//! - [BauBot] broadcasts the [BauMessage] to the appropriate [BauMessage::recipients]
//! - (only if response requested) [BauBot] polls the [BauMessage::recipients] for a response
//! - (only if response requested) [BauBot] receives the [BauResponse] and pipes it back to the
//! [BauServer]
//! - (only if response requested) [BauServer] notifies the [BauClient] through a
//! [BauServerResponseSender] that we received a response.
//! - (only if response requested) [BauClient] polls the [BauServerResponseReceiver] and obtains
//! the [BauServerResponse].
//! - [net::TcpStream] is closed, signifying the end of the transaction.
use baubot_core::BauBot;
pub use prelude::types::*;
pub(crate) use prelude::*;
use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::Arc;
use tokio::io::AsyncWriteExt;
pub mod prelude;
/// [BauServer] listens for requests on the specified address, ideally following the transaction
/// protocol described in the [crate] documentation.
///
/// The value itself is only a handle: it owns the join handle of the background listener task
/// and carries the `Db`/`DbRef` type parameters of the [BauBot] that task drives.
pub struct BauServer<Db, DbRef>
where
Db: baubot_core::prelude::BauData + Send + Sync + 'static,
DbRef: Deref<Target = Db> + Clone + Send + Sync + 'static,
{
/// Zero-sized marker tying `Db` and `DbRef` to this type; no [BauBot] is stored here.
_baubot: PhantomData<BauBot<Db, DbRef>>,
/// Handle of the spawned listener task, kept so the task can be aborted later.
listener: task::JoinHandle<()>,
}
/// Abort the listener on drop to avoid hanging processes
impl<Db, DbRef> Drop for BauServer<Db, DbRef>
where
Db: baubot_core::prelude::BauData + Send + Sync + 'static,
DbRef: Deref<Target = Db> + Clone + Send + Sync + 'static,
{
/// Aborts the background listener task owned by this [BauServer].
fn drop(&mut self) {
// The listener's accept loop never returns on its own, so it has to be cancelled explicitly.
self.listener.abort();
}
}
impl<Db, DbRef> BauServer<Db, DbRef>
where
Db: baubot_core::prelude::BauData + Send + Sync + 'static,
DbRef: Deref<Target = Db> + Clone + Send + Sync + 'static,
{
/// Creates a new [BauServer] object that adheres to the transaction protocol specified in the
/// [crate] documentation.
pub fn new<S: Into<String>>(db: DbRef, addr: ::core::net::SocketAddr, token: S) -> Self {
// Create baubot
let baubot = Arc::new(BauBot::new(db, token));
// Create listening thread
let listener = task::spawn(Self::listen(baubot, addr));
Self {
_baubot: PhantomData,
listener,
}
}
/// Loop that listens for [net::TcpStream] connections and spawns threads to deal with them.
async fn listen(baubot: Arc<BauBot<Db, DbRef>>, addr: ::core::net::SocketAddr) {
// Create TCP listener
// NOTE: Init tasks should unwrap
let tcp_listener = net::TcpListener::bind(addr).await.unwrap();
// Debug
trace!("Server listening on {:?}", tcp_listener.local_addr());
// Loop
loop {
match tcp_listener.accept().await {
Ok(ok) => {
task::spawn(Self::incoming_handler(baubot.clone(), ok));
}
Err(err) => error!("Unable to accept connection: {err:?}"),
}
}
}
async fn incoming_handler(
baubot: Arc<BauBot<Db, DbRef>>,
(mut tcp_stream, socket_addr): (net::TcpStream, std::net::SocketAddr),
) -> std::io::Result<()> {
let _ = baubot;
trace!("Received connection from {socket_addr:?}");
// Receive stream
let request = read_stream(&tcp_stream).await?;
trace!("Received request: {request}");
// Pass off to baubot notification
let baubot_response_receivers = Self::notify_baubot(baubot, request);
// Check the status of the baubot_response_recievers
match baubot_response_receivers {
// If baubot managed to assemble a set of receivers:
Ok(responses) => Self::await_baubot_responses(&tcp_stream, responses).await?,
// If there was a serialization error, report it
Err(err) => {
let err = BauServerResponse::InvalidData(err);
// NOTE: Safe to unwrap becausee we have checked the serialization pipeline
let err = serde_json::to_string(&err).unwrap();
write_stream(&tcp_stream, &err).await?;
}
}
// close the TCP connection
trace!("Shutting stream down");
tcp_stream.shutdown().await
}
fn notify_baubot(
baubot: Arc<BauBot<Db, DbRef>>,
request: String,
) -> Result<Vec<(String, BauResponseReceiver)>, SerializeError> {
let mut bau_message = BauMessage::builder(&request)?();
let mut baubot_responses = Vec::new();
// If payload requries a response, create handlers
if !bau_message.responses.keyboard.is_empty() {}
for (recipient, baubot_response_sender_field) in bau_message.recipients.iter_mut() {
let (baubot_response_sender, baubot_response_receiver) = sync::oneshot::channel();
baubot_responses.push((recipient.clone(), baubot_response_receiver));
*baubot_response_sender_field = Some(baubot_response_sender);
}
// NOTE: Safe to unwrap because if the baubot has died... this entire thing is dogshit
baubot.send(bau_message).unwrap();
Ok(baubot_responses)
}
async fn await_baubot_responses(
tcp_stream: &net::TcpStream,
responses: Vec<(String, BauResponseReceiver)>,
) -> std::io::Result<()> {
// Create iterator over responses that returns a future
let responses = responses.into_iter().map(|(recipient, receiver)| async {
let receiver = receiver.await;
trace!("Recieved response from baubot: {receiver:?}");
let response = match receiver {
// This indicates a good response
Ok(response) => BauServerResponse::Recipient {
recipient,
response,
},
// If the receiver errors out, it means that a timeout has occured
Err(_) => BauServerResponse::Recipient {
recipient,
response: Err(BauBotError::Timeout),
},
};
// NOTE: Safe to unwrap because we checked the serialization chain
let response = serde_json::to_string(&response).unwrap();
write_stream(&tcp_stream, &response).await
});
// Drive responses
for response in responses {
response.await?;
}
Ok(())
}
}
/// [BauClient] is a helper to connect to a [BauServer] through the transaction protocol
/// described in the [crate] documentation
///
/// `RETRIES` is the number of additional connection attempts made after the first one fails, so
/// up to `RETRIES + 1` attempts are made for every request.
pub struct BauClient<const RETRIES: usize> {
/// Address of the [BauServer] that every request connects to.
addr: ::core::net::SocketAddr,
}
/// [BauClient] connects to a [BauServer]
impl<const RETRIES: usize> BauClient<RETRIES> {
/// Create a new [BauClient] to which adheres to the transaction protocol described in the
/// [crate] documentation.
pub fn new(addr: ::core::net::SocketAddr) -> Self {
Self { addr }
}
/// Creates a [net::TcpStream]. Opaque to the end user.
async fn connect(&self) -> std::io::Result<net::TcpStream> {
// Set an attempt counter
let mut attempt = RETRIES;
// Try to connect until attempts run out
loop {
// Debug
trace!(
"Attempt #{} to connect to {:?}",
RETRIES - attempt,
self.addr
);
// Actual attempted connection
match net::TcpStream::connect(self.addr.clone()).await {
// If successful, return the connection
Ok(connection) => {
trace!("Succesful connection");
break Ok(connection);
}
// If unsuccessful, decrement the attempt counter or break as necssary
Err(err) => {
// Decrement attempt counter if we have not reached 0
if attempt > 0 {
attempt -= 1;
continue;
}
// Break
else {
trace!("Connection err: {err}");
break Err(err);
}
}
}
}
}
/// Sends a [String] payload through the [BauClient] to the [BauServer] and returns a
/// [BauServerResponseReceiver] that we can poll for responses.
pub async fn send_string(
&self,
request: String,
) -> Result<BauServerResponseReceiver, SendError> {
// Create stream
let tcp_stream = self.connect().await?;
// Sanity test the input
trace!("Checking request for validity");
let _ = BauMessage::builder(&request)?;
// Write to the stream
trace!("Attemping to send request to stream: {request}");
let _ = write_stream(&tcp_stream, &request).await?;
// Create senders and receivers and send them away with the tcp_stream
let (bau_response_sender, bau_response_receiver) = sync::mpsc::unbounded_channel();
task::spawn(Self::receive_responses(tcp_stream, bau_response_sender));
Ok(bau_response_receiver)
}
/// Sends a [BauMessage] through the [BauClient] to the [BauServer] and returns a
/// [BauServerResponseReceiver] that we can poll for responses.
///
/// **Note**: Caller should **not** provide any handlers. They will be ignored.
pub async fn send(
&self,
bau_message: BauMessage,
) -> Result<BauServerResponseReceiver, SendError> {
// Sanity test the input
let bau_message = serde_json::to_string(&bau_message).map_err(|err| {
SendError::InvalidData(SerializeError::InvalidJson(format! {"{err:?}"}))
})?;
self.send_string(bau_message).await
}
/// Loop to receive responses from the [BauServer] and send responses to the
/// [BauServerResponseReceiver]. Opaque to the consumer.
async fn receive_responses(
tcp_stream: net::TcpStream,
bau_response_sender: BauServerResponseSender,
) {
trace!("Waiting for responses from BauServer");
// If the stream is closed read_stream will return an error (which we can discard since we
// dont care).
while let Ok(response) = read_stream(&tcp_stream).await {
trace!("Received payload: {response}");
// Break if empty response received because that means the stream closed
if response.is_empty() {
break;
}
// Check if we are able to construct a response from the server
if let Ok(response) = serde_json::from_str::<BauServerResponse>(&response) {
trace!("Sending response to receiver.");
if let Err(_) = bau_response_sender.send(response) {
error!("Reciever went out of scope.");
break;
}
};
}
trace!("Finished receiving responses from BauServer.");
}
}