diff --git a/egui_node_graph/src/editor_ui.rs b/egui_node_graph/src/editor_ui.rs index c3f2fd4..4c42df3 100644 --- a/egui_node_graph/src/editor_ui.rs +++ b/egui_node_graph/src/editor_ui.rs @@ -55,7 +55,7 @@ where ValueType: WidgetValueTrait, NodeTemplate: NodeTemplateTrait, - DataType: DataTypeTrait, + DataType: DataTypeTrait, { #[must_use] pub fn draw_graph_editor( @@ -150,13 +150,9 @@ where } /* Draw connections */ - - let connection_color = if ui.visuals().dark_mode { - color_from_hex("#efefef").unwrap() - } else { - color_from_hex("#bbbbbb").unwrap() - }; if let Some((_, ref locator)) = self.connection_in_progress { + let port_type = self.graph.any_param_type(*locator).unwrap(); + let connection_color = port_type.data_type_color(&self.user_state); let start_pos = port_locations[locator]; let (src_pos, dst_pos) = match locator { AnyParameterId::Output(_) => (start_pos, cursor_pos), @@ -166,6 +162,11 @@ where } for (input, output) in self.graph.iter_connections() { + let port_type = self + .graph + .any_param_type(AnyParameterId::Output(output)) + .unwrap(); + let connection_color = port_type.data_type_color(&self.user_state); let src_pos = port_locations[&AnyParameterId::Output(output)]; let dst_pos = port_locations[&AnyParameterId::Input(input)]; draw_connection(ui.painter(), src_pos, dst_pos, connection_color); @@ -298,7 +299,7 @@ where >, UserResponse: UserResponseTrait, ValueType: WidgetValueTrait, - DataType: DataTypeTrait, + DataType: DataTypeTrait, { pub const MAX_NODE_SIZE: [f32; 2] = [200.0, 200.0]; @@ -410,10 +411,11 @@ where let port_right = outer_rect.right(); #[allow(clippy::too_many_arguments)] - fn draw_port( + fn draw_port( ui: &mut Ui, graph: &Graph, node_id: NodeId, + user_state: &UserState, port_pos: Pos2, responses: &mut Vec>, param_id: AnyParameterId, @@ -421,7 +423,7 @@ where ongoing_drag: Option<(NodeId, AnyParameterId)>, is_connected_input: bool, ) where - DataType: DataTypeTrait, + DataType: DataTypeTrait, UserResponse: UserResponseTrait, { let port_type = graph.any_param_type(param_id).unwrap(); @@ -438,7 +440,7 @@ where let port_color = if resp.hovered() { Color32::WHITE } else { - port_type.data_type_color() + port_type.data_type_color(user_state) }; ui.painter() .circle(port_rect.center(), 5.0, port_color, Stroke::none()); @@ -484,6 +486,7 @@ where ui, self.graph, self.node_id, + user_state, pos_left, &mut responses, AnyParameterId::Input(*param), @@ -505,6 +508,7 @@ where ui, self.graph, self.node_id, + user_state, pos_right, &mut responses, AnyParameterId::Output(*param), diff --git a/egui_node_graph/src/traits.rs b/egui_node_graph/src/traits.rs index 376d870..d054968 100644 --- a/egui_node_graph/src/traits.rs +++ b/egui_node_graph/src/traits.rs @@ -15,9 +15,9 @@ pub trait WidgetValueTrait { /// This trait must be implemented by the `DataType` generic parameter of the /// [`Graph`]. This trait tells the library how to visually expose data types /// to the user. -pub trait DataTypeTrait: PartialEq + Eq { +pub trait DataTypeTrait: PartialEq + Eq { // The associated port color of this datatype - fn data_type_color(&self) -> egui::Color32; + fn data_type_color(&self, user_state: &UserState) -> egui::Color32; // The name of this datatype fn name(&self) -> &str; diff --git a/egui_node_graph_example/src/app.rs b/egui_node_graph_example/src/app.rs index 1c2c812..563751c 100644 --- a/egui_node_graph_example/src/app.rs +++ b/egui_node_graph_example/src/app.rs @@ -89,8 +89,8 @@ pub struct MyGraphState { // =========== Then, you need to implement some traits ============ // A trait for the data types, to tell the library how to display them -impl DataTypeTrait for MyDataType { - fn data_type_color(&self) -> egui::Color32 { +impl DataTypeTrait for MyDataType { + fn data_type_color(&self, _user_state: &MyGraphState) -> egui::Color32 { match self { MyDataType::Scalar => egui::Color32::from_rgb(38, 109, 211), MyDataType::Vec2 => egui::Color32::from_rgb(238, 207, 109),