diff --git a/client/src/main.rs b/client/src/main.rs index 702e359..d5ac790 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -33,6 +33,10 @@ struct Args { #[clap(short, long)] debug: bool, + /// Request to shutdown the daemon + #[clap(long)] + shutdown: bool, + /// Command to run #[clap(trailing_var_arg = true)] command: Vec, diff --git a/common/src/protocol.rs b/common/src/protocol.rs index 214a401..079197b 100644 --- a/common/src/protocol.rs +++ b/common/src/protocol.rs @@ -60,6 +60,10 @@ pub enum ClientMessage { /// Session ID to get information for. session: String, }, + + /// Request to shutdown the daemon. + #[serde(rename = "shutdown")] + Shutdown, } /// Daemon-to-client response message. diff --git a/daemon/src/socket/mod.rs b/daemon/src/socket/mod.rs index a6c9b1a..e254da7 100644 --- a/daemon/src/socket/mod.rs +++ b/daemon/src/socket/mod.rs @@ -39,32 +39,68 @@ impl Server { } // Create the listener - let listener = UnixListener::bind(&self.socket_path) - .context("Failed to bind Unix domain socket")?; + let listener = + UnixListener::bind(&self.socket_path).context("Failed to bind Unix domain socket")?; info!("Listening on {:?}", self.socket_path); + // Channel for shutdown signal + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::(1); + // Accept connections loop { - match listener.accept().await { - Ok((stream, _addr)) => { - let session_manager = self.session_manager.clone(); - tokio::spawn(async move { - if let Err(e) = handle_client(stream, session_manager).await { - error!("Error handling client: {}", e); + // Check for shutdown signal + if shutdown_rx.try_recv().is_ok() { + info!("Received shutdown signal, stopping server"); + break; + } + + tokio::select! { + // Accept new connections + conn = listener.accept() => { + match conn { + Ok((stream, _addr)) => { + let session_manager = self.session_manager.clone(); + let shutdown_tx = shutdown_tx.clone(); + tokio::spawn(async move { + match handle_client(stream, session_manager).await { + Ok(true) => { + // Client requested shutdown + info!("Client requested shutdown"); + shutdown_tx.send(true).await.ok(); + } + Err(e) => { + error!("Error handling client: {}", e); + } + _ => {} + } + }); } - }); - } - Err(e) => { - error!("Error accepting connection: {}", e); + Err(e) => { + error!("Error accepting connection: {}", e); + } + } } + // Wait for a small duration to prevent CPU spin + _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {} } } + + // Clean up the socket file before exiting + if self.socket_path.exists() { + match fs::remove_file(&self.socket_path) { + Ok(_) => debug!("Removed socket file on shutdown"), + Err(e) => error!("Failed to remove socket file on shutdown: {}", e), + } + } + + Ok(()) } } /// Handle a client connection. -async fn handle_client(stream: UnixStream, session_manager: SessionManager) -> Result<()> { +/// Returns Ok(true) if the daemon should be shut down. +async fn handle_client(stream: UnixStream, session_manager: SessionManager) -> Result { let (reader, writer) = tokio::io::split(stream); let mut reader = BufReader::new(reader); let mut writer = BufWriter::new(writer); @@ -83,6 +119,9 @@ async fn handle_client(stream: UnixStream, session_manager: SessionManager) -> R Ok::<_, anyhow::Error>(()) }); + // Track if the message was a shutdown request + let mut should_shutdown = false; + // Process incoming messages let mut buffer = String::new(); loop { @@ -95,12 +134,22 @@ async fn handle_client(stream: UnixStream, session_manager: SessionManager) -> R match serde_json::from_str::(&buffer) { Ok(msg) => { debug!("Received message: {:?}", msg); - if let Err(e) = process_message(msg, &session_manager, tx.clone()).await { - error!("Error processing message: {}", e); - let error_msg = DaemonMessage::Error { - message: e.to_string(), - }; - tx.send(error_msg).await.ok(); + + match process_message(msg, &session_manager, tx.clone()).await { + Ok(Some(true)) => { + // Shutdown signal received + info!("Shutting down connection due to shutdown request"); + should_shutdown = true; + break; + } + Ok(_) => {} // Continue processing + Err(e) => { + error!("Error processing message: {}", e); + let error_msg = DaemonMessage::Error { + message: e.to_string(), + }; + tx.send(error_msg).await.ok(); + } } } Err(e) => { @@ -122,7 +171,12 @@ async fn handle_client(stream: UnixStream, session_manager: SessionManager) -> R // Cancel the write task write_task.abort(); - Ok(()) + // Return true if a shutdown was requested during processing + if should_shutdown { + Ok(true) + } else { + Ok(false) + } } /// Process a client message. @@ -130,9 +184,14 @@ async fn process_message( msg: ClientMessage, session_manager: &SessionManager, tx: mpsc::Sender, -) -> Result<()> { +) -> Result> { match msg { - ClientMessage::RunCommand { session, cmd, cwd, env } => { + ClientMessage::RunCommand { + session, + cmd, + cwd, + env, + } => { // Get or create the session let session_arc = session_manager.get_or_create_session(&session).await; let mut session_guard = session_arc.lock().await; @@ -145,7 +204,7 @@ async fn process_message( }) .await .context("Failed to send error message")?; - return Ok(()); + return Ok(None); } } @@ -207,7 +266,9 @@ async fn process_message( session, code: exit_code, }; - tx.send(exit_msg).await.context("Failed to send exit message")?; + tx.send(exit_msg) + .await + .context("Failed to send exit message")?; } ClientMessage::Attach { session } => { @@ -226,7 +287,9 @@ async fn process_message( let msg = DaemonMessage::Success { message: "Detached from session".to_string(), }; - tx.send(msg).await.context("Failed to send success message")?; + tx.send(msg) + .await + .context("Failed to send success message")?; } ClientMessage::ListSessions => { @@ -262,9 +325,14 @@ async fn process_message( match session_guard.change_directory(&dir) { Ok(_) => { let msg = DaemonMessage::Success { - message: format!("Changed directory to {}", session_guard.get_cwd().display()), + message: format!( + "Changed directory to {}", + session_guard.get_cwd().display() + ), }; - tx.send(msg).await.context("Failed to send success message")?; + tx.send(msg) + .await + .context("Failed to send success message")?; } Err(e) => { let msg = DaemonMessage::Error { @@ -284,7 +352,7 @@ async fn process_message( message: format!("Session not found: {}", session), }; tx.send(msg).await.context("Failed to send error message")?; - return Ok(()); + return Ok(None); } }; @@ -293,23 +361,39 @@ async fn process_message( let msg = DaemonMessage::SessionDetails { session: info }; tx.send(msg).await.context("Failed to send session info")?; } + + ClientMessage::Shutdown => { + info!("Received shutdown request"); + // Send success message to client before shutting down + tx.send(DaemonMessage::Success { + message: "Daemon shutting down".to_string(), + }) + .await + .context("Failed to send shutdown acknowledgment")?; + + // Return signal to break the main server loop + return Ok(Some(true)); + } } - Ok(()) + Ok(None) } /// Send a message to the client. -async fn send_message( - writer: &mut W, - msg: &DaemonMessage, -) -> Result<()> { +async fn send_message(writer: &mut W, msg: &DaemonMessage) -> Result<()> { // Serialize the message to JSON let json = serde_json::to_string(msg).context("Failed to serialize message")?; - + // Write the message followed by a newline - writer.write_all(json.as_bytes()).await.context("Failed to write message")?; - writer.write_all(b"\n").await.context("Failed to write newline")?; + writer + .write_all(json.as_bytes()) + .await + .context("Failed to write message")?; + writer + .write_all(b"\n") + .await + .context("Failed to write newline")?; writer.flush().await.context("Failed to flush writer")?; - + Ok(()) }