diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index 358d0f1..401b462 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -15,11 +15,17 @@ pub type PortLocations = std::collections::HashMap; #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NodeResponse { ConnectEventStarted(NodeId, AnyParameterId), - ConnectEventEnded(AnyParameterId), + ConnectEventEnded { + output: OutputId, + input: InputId, + }, CreatedNode(NodeId), SelectNode(NodeId), DeleteNode(NodeId), - DisconnectEvent(InputId), + DisconnectEvent { + output: OutputId, + input: InputId, + }, /// Emitted when a node is interacted with, and should be raised RaiseNode(NodeId), User(UserResponse), @@ -52,10 +58,10 @@ where ValueType = ValueType, >, UserResponse: UserResponseTrait, - ValueType: WidgetValueTrait, + ValueType: WidgetValueTrait, NodeTemplate: NodeTemplateTrait, - DataType: DataTypeTrait, + DataType: DataTypeTrait, { #[must_use] pub fn draw_graph_editor( @@ -152,13 +158,9 @@ where } /* Draw connections */ - - let connection_color = if ui.visuals().dark_mode { - color_from_hex("#efefef").unwrap() - } else { - color_from_hex("#bbbbbb").unwrap() - }; if let Some((_, ref locator)) = self.connection_in_progress { + let port_type = self.graph.any_param_type(*locator).unwrap(); + let connection_color = port_type.data_type_color(&self.user_state); let start_pos = port_locations[locator]; let (src_pos, dst_pos) = match locator { AnyParameterId::Output(_) => (start_pos, cursor_pos), @@ -168,6 +170,11 @@ where } for (input, output) in self.graph.iter_connections() { + let port_type = self + .graph + .any_param_type(AnyParameterId::Output(output)) + .unwrap(); + let connection_color = port_type.data_type_color(&self.user_state); let src_pos = port_locations[&AnyParameterId::Output(output)]; let dst_pos = port_locations[&AnyParameterId::Input(input)]; draw_connection(ui.painter(), src_pos, dst_pos, connection_color); @@ -175,29 +182,17 @@ where /* Handle responses from drawing nodes */ + // Some responses generate additional responses when processed. These + // are stored here to report them back to the user. + let mut extra_responses: Vec> = Vec::new(); + for response in delayed_responses.iter().copied() { match response { NodeResponse::ConnectEventStarted(node_id, port) => { self.connection_in_progress = Some((node_id, port)); } - NodeResponse::ConnectEventEnded(locator) => { - let in_out = match ( - self.connection_in_progress - .map(|(_node, param)| param) - .take() - .expect("Cannot end drag without in-progress connection."), - locator, - ) { - (AnyParameterId::Input(input), AnyParameterId::Output(output)) - | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => { - Some((input, output)) - } - _ => None, - }; - - if let Some((input, output)) = in_out { - self.graph.add_connection(output, input) - } + NodeResponse::ConnectEventEnded { input, output } => { + self.graph.add_connection(output, input) } NodeResponse::CreatedNode(_) => { //Convenience NodeResponse for users @@ -206,7 +201,12 @@ where self.selected_node = Some(node_id); } NodeResponse::DeleteNode(node_id) => { - self.graph.remove_node(node_id); + let removed = self.graph.remove_node(node_id); + extra_responses.extend( + removed + .into_iter() + .map(|(input, output)| NodeResponse::DisconnectEvent { input, output }), + ); self.node_positions.remove(node_id); // Make sure to not leave references to old nodes hanging if self.selected_node.map(|x| x == node_id).unwrap_or(false) { @@ -214,15 +214,11 @@ where } self.node_order.retain(|id| *id != node_id); } - NodeResponse::DisconnectEvent(input_id) => { - let corresp_output = self - .graph - .connection(input_id) - .expect("Connection data should be valid"); - let other_node = self.graph.get_input(input_id).node(); - self.graph.remove_connection(input_id); + NodeResponse::DisconnectEvent { input, output } => { + let other_node = self.graph.get_input(input).node(); + self.graph.remove_connection(input); self.connection_in_progress = - Some((other_node, AnyParameterId::Output(corresp_output))); + Some((other_node, AnyParameterId::Output(output))); } NodeResponse::RaiseNode(node_id) => { let old_pos = self @@ -239,6 +235,11 @@ where } } + // Push any responses that were generated during response handling. + // These are only informative for the end-user and need no special + // treatment here. + delayed_responses.extend(extra_responses); + /* Mouse input handling */ // This locks the context, so don't hold on to it for too long. @@ -299,8 +300,8 @@ where ValueType = ValueType, >, UserResponse: UserResponseTrait, - ValueType: WidgetValueTrait, - DataType: DataTypeTrait, + ValueType: WidgetValueTrait, + DataType: DataTypeTrait, { pub const MAX_NODE_SIZE: [f32; 2] = [200.0, 200.0]; @@ -325,15 +326,12 @@ where let mut responses = Vec::new(); let background_color; - let titlebar_color; let text_color; if ui.visuals().dark_mode { background_color = color_from_hex("#3f3f3f").unwrap(); - titlebar_color = background_color.lighten(0.8); text_color = color_from_hex("#fefefe").unwrap(); } else { background_color = color_from_hex("#ffffff").unwrap(); - titlebar_color = background_color.lighten(0.8); text_color = color_from_hex("#505050").unwrap(); } @@ -363,6 +361,7 @@ where .text_style(TextStyle::Button) .color(text_color), )); + ui.add_space(8.0); // The size of the little cross icon }); ui.add_space(margin.y); title_height = ui.min_size().y; @@ -375,7 +374,13 @@ where if self.graph.connection(param_id).is_some() { ui.label(param_name); } else { - self.graph[param_id].value.value_widget(¶m_name, ui); + responses.extend( + self.graph[param_id] + .value + .value_widget(¶m_name, ui) + .into_iter() + .map(NodeResponse::User), + ); } let height_after = ui.min_rect().bottom(); input_port_heights.push((height_before + height_after) / 2.0); @@ -406,10 +411,11 @@ where let port_right = outer_rect.right(); #[allow(clippy::too_many_arguments)] - fn draw_port( + fn draw_port( ui: &mut Ui, graph: &Graph, node_id: NodeId, + user_state: &UserState, port_pos: Pos2, responses: &mut Vec>, param_id: AnyParameterId, @@ -417,7 +423,7 @@ where ongoing_drag: Option<(NodeId, AnyParameterId)>, is_connected_input: bool, ) where - DataType: DataTypeTrait, + DataType: DataTypeTrait, UserResponse: UserResponseTrait, { let port_type = graph.any_param_type(param_id).unwrap(); @@ -434,14 +440,21 @@ where let port_color = if resp.hovered() { Color32::WHITE } else { - port_type.data_type_color() + port_type.data_type_color(user_state) }; ui.painter() .circle(port_rect.center(), 5.0, port_color, Stroke::none()); if resp.drag_started() { if is_connected_input { - responses.push(NodeResponse::DisconnectEvent(param_id.assume_input())); + let input = param_id.assume_input(); + let corresp_output = graph + .connection(input) + .expect("Connection data should be valid"); + responses.push(NodeResponse::DisconnectEvent { + input: param_id.assume_input(), + output: corresp_output, + }); } else { responses.push(NodeResponse::ConnectEventStarted(node_id, param_id)); } @@ -454,7 +467,13 @@ where && resp.hovered() && ui.input().pointer.any_released() { - responses.push(NodeResponse::ConnectEventEnded(param_id)); + match (param_id, origin_param) { + (AnyParameterId::Input(input), AnyParameterId::Output(output)) + | (AnyParameterId::Output(output), AnyParameterId::Input(input)) => { + responses.push(NodeResponse::ConnectEventEnded { input, output }); + } + _ => { /* Ignore in-in or out-out connections */ } + } } } } @@ -480,6 +499,7 @@ where ui, self.graph, self.node_id, + user_state, pos_left, &mut responses, AnyParameterId::Input(*param), @@ -501,6 +521,7 @@ where ui, self.graph, self.node_id, + user_state, pos_right, &mut responses, AnyParameterId::Output(*param), @@ -511,7 +532,7 @@ where } // Draw the background shape. - // NOTE: This code is a bit more involve than it needs to be because egui + // NOTE: This code is a bit more involved than it needs to be because egui // does not support drawing rectangles with asymmetrical round corners. let (shape, outline) = { @@ -524,7 +545,10 @@ where let titlebar = Shape::Rect(RectShape { rect: titlebar_rect, rounding, - fill: titlebar_color, + fill: self.graph[self.node_id] + .user_data + .titlebar_color(ui, self.node_id, self.graph, user_state) + .unwrap_or_else(|| background_color.lighten(0.8)), stroke: Stroke::none(), }); diff --git a/egui_node_graph/src/graph_impls.rs b/egui_node_graph/src/graph_impls.rs index 7f898a6..9c26881 100644 --- a/egui_node_graph/src/graph_impls.rs +++ b/egui_node_graph/src/graph_impls.rs @@ -63,18 +63,36 @@ impl Graph { output_id } - pub fn remove_node(&mut self, node_id: NodeId) { - self.connections - .retain(|i, o| !(self.outputs[*o].node == node_id || self.inputs[i].node == node_id)); - let inputs: SVec<_> = self[node_id].input_ids().collect(); - for input in inputs { + /// Removes a node from the graph with given `node_id`. This also removes + /// any incoming or outgoing connections from that node + /// + /// This function returns the list of connections that has been removed + /// after deleting this node as input-output pairs. Note that one of the two + /// ids in the pair (the one on `node_id`'s end) will be invalid after + /// calling this function. + pub fn remove_node(&mut self, node_id: NodeId) -> Vec<(InputId, OutputId)> { + let mut disconnect_events = vec![]; + + self.connections.retain(|i, o| { + if self.outputs[*o].node == node_id || self.inputs[i].node == node_id { + disconnect_events.push((i, *o)); + false + } else { + true + } + }); + + // NOTE: Collect is needed because we can't borrow the input ids while + // we remove them inside the loop. + for input in self[node_id].input_ids().collect::>() { self.inputs.remove(input); } - let outputs: SVec<_> = self[node_id].output_ids().collect(); - for output in outputs { + for output in self[node_id].output_ids().collect::>() { self.outputs.remove(output); } self.nodes.remove(node_id); + + disconnect_events } pub fn remove_connection(&mut self, input_id: InputId) -> Option { diff --git a/egui_node_graph/src/traits.rs b/egui_node_graph/src/traits.rs index 5be8908..2f5c6d6 100644 --- a/egui_node_graph/src/traits.rs +++ b/egui_node_graph/src/traits.rs @@ -4,18 +4,50 @@ use super::*; /// [`Graph`]. The trait allows drawing custom inline widgets for the different /// types of the node graph. pub trait WidgetValueTrait { - fn value_widget(&mut self, param_name: &str, ui: &mut egui::Ui); + type Response; + /// This method will be called for each input parameter with a widget. The + /// return value is a vector of custom response objects which can be used + /// to implement handling of side effects. If unsure, the response Vec can + /// be empty. + fn value_widget(&mut self, param_name: &str, ui: &mut egui::Ui) -> Vec; } /// This trait must be implemented by the `DataType` generic parameter of the /// [`Graph`]. This trait tells the library how to visually expose data types /// to the user. -pub trait DataTypeTrait: PartialEq + Eq { - // The associated port color of this datatype - fn data_type_color(&self) -> egui::Color32; +pub trait DataTypeTrait: PartialEq + Eq { + /// The associated port color of this datatype + fn data_type_color(&self, user_state: &UserState) -> egui::Color32; - // The name of this datatype - fn name(&self) -> &str; + /// The name of this datatype. Return type is specified as Cow because + /// some implementations will need to allocate a new string to provide an + /// answer while others won't. + /// + /// ## Example (borrowed value) + /// Use this when you can get the name of the datatype from its fields or as + /// a &'static str. Prefer this method when possible. + /// ```rust + /// pub struct DataType { name: String } + /// + /// impl DataTypeTrait<()> for DataType { + /// fn name(&self) -> std::borrow::Cow { + /// Cow::Borrowed(&self.name) + /// } + /// } + /// ``` + /// + /// ## Example (owned value) + /// Use this when you can't derive the name of the datatype from its fields. + /// ```rust + /// pub struct DataType { some_tag: i32 } + /// + /// impl DataTypeTrait<()> for DataType { + /// fn name(&self) -> std::borrow::Cow { + /// Cow::Owned(format!("Super amazing type #{}", self.some_tag)) + /// } + /// } + /// ``` + fn name(&self) -> std::borrow::Cow; } /// This trait must be implemented for the `NodeData` generic parameter of the @@ -43,6 +75,18 @@ where ) -> Vec> where Self::Response: UserResponseTrait; + + /// Set background color on titlebar + /// If the return value is None, the default color is set. + fn titlebar_color( + &self, + _ui: &egui::Ui, + _node_id: NodeId, + _graph: &Graph, + _user_state: &Self::UserState, + ) -> Option { + None + } } /// This trait can be implemented by any user type. The trait tells the library diff --git a/egui_node_graph_example/src/app.rs b/egui_node_graph_example/src/app.rs index 6e0379a..e8e47dc 100644 --- a/egui_node_graph_example/src/app.rs +++ b/egui_node_graph_example/src/app.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{borrow::Cow, collections::HashMap}; use eframe::egui::{self, DragValue, TextStyle}; use egui_node_graph::*; @@ -89,18 +89,18 @@ pub struct MyGraphState { // =========== Then, you need to implement some traits ============ // A trait for the data types, to tell the library how to display them -impl DataTypeTrait for MyDataType { - fn data_type_color(&self) -> egui::Color32 { +impl DataTypeTrait for MyDataType { + fn data_type_color(&self, _user_state: &MyGraphState) -> egui::Color32 { match self { MyDataType::Scalar => egui::Color32::from_rgb(38, 109, 211), MyDataType::Vec2 => egui::Color32::from_rgb(238, 207, 109), } } - fn name(&self) -> &str { + fn name(&self) -> Cow<'_, str> { match self { - MyDataType::Scalar => "scalar", - MyDataType::Vec2 => "2d vector", + MyDataType::Scalar => Cow::Borrowed("scalar"), + MyDataType::Vec2 => Cow::Borrowed("2d vector"), } } } @@ -250,7 +250,8 @@ impl NodeTemplateIter for AllMyNodeTemplates { } impl WidgetValueTrait for MyValueType { - fn value_widget(&mut self, param_name: &str, ui: &mut egui::Ui) { + type Response = MyResponse; + fn value_widget(&mut self, param_name: &str, ui: &mut egui::Ui) -> Vec { // This trait is used to tell the library which UI to display for the // inline parameter widgets. match self { @@ -270,6 +271,8 @@ impl WidgetValueTrait for MyValueType { }); } } + // This allows you to return your responses from the inline widgets. + Vec::new() } } @@ -380,7 +383,7 @@ impl eframe::App for NodeGraphExample { Err(err) => format!("Execution error: {}", err), }; ctx.debug_painter().text( - egui::pos2(10.0, 10.0), + egui::pos2(10.0, 35.0), egui::Align2::LEFT_TOP, text, TextStyle::Button.resolve(&ctx.style()),