import contextlib
import json
import os
from qtpy.QtCore import QDir, QPoint, QPointF, Qt, Signal
from qtpy.QtWidgets import QFileDialog, QGraphicsScene
from . import exceptions
from . import style as style_module
from .connection import Connection
from .connection_graphics_object import ConnectionGraphicsObject
from .data_model_registry import DataModelRegistry
from .exceptions import ConnectionDataTypeFailure
from .node import Node
from .node_data import NodeDataModel, NodeDataType
from .node_graphics_object import NodeGraphicsObject
from .port import Port, PortType
from .type_converter import TypeConverter
def locate_node_at(scene_point, scene, view_transform):
    '''
    Find the top-most node whose graphics item intersects a scene point.

    Parameters
    ----------
    scene_point : QPointF
        Point in scene coordinates.
    scene : QGraphicsScene
        Scene to search.
    view_transform : QTransform
        Transform of the view, forwarded to ``QGraphicsScene.items``.

    Returns
    -------
    node : Node or None
        The node owning the first ``NodeGraphicsObject`` found (items are
        requested in descending stacking order), or None if there is none.
    '''
    hits = scene.items(scene_point, Qt.IntersectsItemShape,
                       Qt.DescendingOrder, view_transform)
    for hit in hits:
        if isinstance(hit, NodeGraphicsObject):
            return hit.node
    return None
class FlowSceneModel:
    '''
    A model representing a flow scene

    Emits the following signals upon connection/node creation/deletion::

        connection_created : Signal(Connection)
        connection_deleted : Signal(Connection)
        node_created : Signal(Node)
        node_deleted : Signal(Node)
    '''

    connection_created = Signal(Connection)
    connection_deleted = Signal(Connection)
    node_created = Signal(Node)
    node_deleted = Signal(Node)

    def __init__(self, registry=None, **kwargs):
        '''
        Parameters
        ----------
        registry : DataModelRegistry, optional
            Registry used to create node data models by name. A new
            ``DataModelRegistry()`` is created when None.
        **kwargs
            Forwarded unchanged to the next ``__init__`` in the MRO.
        '''
        super().__init__(**kwargs)
        self._connections = []
        self._nodes = {}
        if registry is None:
            registry = DataModelRegistry()
        self._registry = registry

        # this connection should come first
        self.connection_created.connect(self._setup_connection_signals)
        self.connection_created.connect(self._send_connection_created_to_nodes)
        self.connection_deleted.connect(self._send_connection_deleted_to_nodes)

    @property
    def registry(self) -> DataModelRegistry:
        """
        Registry

        Returns
        -------
        value : DataModelRegistry
        """
        return self._registry

    @registry.setter
    def registry(self, registry: DataModelRegistry):
        self._registry = registry

    @property
    def nodes(self) -> dict:
        """
        All nodes in the scene

        Returns
        -------
        value : dict
            A shallow copy of the internal mapping.
            Key: uuid
            Value: Node
        """
        return dict(self._nodes)

    @property
    def connections(self) -> list:
        """
        All connections in the scene

        Returns
        -------
        conn : list of Connection
            A shallow copy of the internal list.
        """
        return list(self._connections)

    def clear_scene(self):
        """
        Delete every connection, then remove every node.
        """
        # Manual node cleanup. Simply clearing the holding datastructures
        # doesn't work, the code crashes when there are both nodes and
        # connections in the scene. (The data propagation internal logic tries
        # to propagate data through already freed connections.)
        for conn in list(self._connections):
            self.delete_connection(conn)

        for node in list(self._nodes.values()):
            self.remove_node(node)

    def save(self, file_name=None):
        """
        Save the scene state as JSON.

        Parameters
        ----------
        file_name : str, optional
            Destination path. When None, a save-file dialog is shown. If no
            file name results (e.g. an empty string), nothing is written.
            A ``.flow`` suffix is appended when missing.
        """
        if file_name is None:
            file_name, _ = QFileDialog.getSaveFileName(
                None, "Save Flow Scene", QDir.homePath(),
                "Flow Scene Files (*.flow)")

        if file_name:
            file_name = str(file_name)
            if not file_name.endswith(".flow"):
                file_name += ".flow"

            with open(file_name, 'wt') as f:
                json.dump(self.__getstate__(), f)

    def load(self, file_name=None):
        """
        Replace the scene contents with the state stored in a JSON file.

        Parameters
        ----------
        file_name : str, optional
            Source path. When None, an open-file dialog is shown. If the
            path does not exist, the method returns without changing the
            scene.
        """
        if file_name is None:
            file_name, _ = QFileDialog.getOpenFileName(
                None, "Open Flow Scene", QDir.homePath(),
                "Flow Scene Files (*.flow)")

        if not os.path.exists(file_name):
            return

        with open(file_name, 'rt') as f:
            doc = json.load(f)

        self.__setstate__(doc)

    def __getstate__(self) -> dict:
        """
        Save scene state to a dictionary

        Returns
        -------
        value : dict
            ``{"nodes": [...], "connections": [...]}``. Connections whose
            state is empty/falsy are omitted.
        """
        nodes_json_array = [node.__getstate__()
                            for node in self._nodes.values()]

        connection_json_array = []
        for connection in self._connections:
            connection_json = connection.__getstate__()
            if connection_json:
                connection_json_array.append(connection_json)

        return {
            "nodes": nodes_json_array,
            "connections": connection_json_array,
        }

    def restore_connection(self, connection_json: dict) -> Connection:
        """
        Restore a connection. To be overridden in a subclass.

        The base implementation does nothing and returns None.

        Parameters
        ----------
        connection_json : dict

        Returns
        -------
        value : Connection
        """

    def __setstate__(self, doc: dict):
        """
        Load scene state from a dictionary

        The scene is cleared first. All nodes are restored before any
        connection, since connections refer to nodes.

        Parameters
        ----------
        doc : dict
            Dictionary of settings, with "nodes" and "connections" keys
        """
        self.clear_scene()

        for node in doc["nodes"]:
            self.restore_node(node)

        for connection in doc["connections"]:
            self.restore_connection(connection)

    def _setup_connection_signals(self, conn: Connection):
        """
        Setup connection signals

        Re-emits the connection's ``connection_made_incomplete`` signal as
        this scene's ``connection_deleted``.

        Parameters
        ----------
        conn : Connection
        """
        conn.connection_made_incomplete.connect(
            self.connection_deleted.emit, Qt.UniqueConnection)

    def _send_connection_created_to_nodes(self, conn: Connection):
        """
        Send connection created to nodes

        Parameters
        ----------
        conn : Connection
        """
        input_node, output_node = conn.nodes
        assert input_node is not None
        assert output_node is not None
        output_node.model.output_connection_created(conn)
        input_node.model.input_connection_created(conn)

    def _send_connection_deleted_to_nodes(self, conn: Connection):
        """
        Send connection deleted to nodes

        Parameters
        ----------
        conn : Connection
        """
        input_node, output_node = conn.nodes
        assert input_node is not None
        assert output_node is not None
        output_node.model.output_connection_deleted(conn)
        input_node.model.input_connection_deleted(conn)

    def iterate_over_nodes(self):
        """
        Generator: Iterate over nodes
        """
        yield from self._nodes.values()

    def iterate_over_node_data(self):
        """
        Generator: Iterate over node data
        """
        for node in self._nodes.values():
            yield node.model

    def iterate_over_node_data_dependent_order(self):
        """
        Generator: Iterate over node data dependent order

        Yields each node's model such that a node is only yielded after
        every node feeding its input ports has been yielded. "Leaf" nodes
        (no input ports, or all input ports unconnected) come first.

        Raises
        ------
        RuntimeError
            If the remaining nodes can never be ordered. This happens when
            an input comes from a node that is never visited (for example,
            a cycle or a node not in this scene).
        """
        # Track visited nodes by id so every upstream node is checked,
        # not just the most recently visited one.
        visited_ids = set()

        # A leaf node is a node with no input ports, or all possible input
        # ports empty
        def is_node_leaf(node):
            return all(not port.connections
                       for port in node[PortType.input].values())

        def are_node_inputs_visited_before(node):
            for port in node[PortType.input].values():
                for conn in port.connections:
                    other = conn.get_node(PortType.output)
                    # An unattached (partial) end is not a dependency.
                    if other is not None and other.id not in visited_ids:
                        return False
            return True

        # Iterate over "leaf" nodes
        for node in list(self._nodes.values()):
            if is_node_leaf(node):
                yield node.model
                visited_ids.add(node.id)

        # Iterate over dependent nodes
        while True:
            pending = [node for node in self._nodes.values()
                       if node.id not in visited_ids]
            if not pending:
                return

            progressed = False
            for node in pending:
                if are_node_inputs_visited_before(node):
                    yield node.model
                    visited_ids.add(node.id)
                    progressed = True

            if not progressed:
                # Without this guard the loop would spin forever.
                raise RuntimeError(
                    f'Unable to order {len(pending)} node(s) by data '
                    f'dependency; their inputs come from nodes that are '
                    f'never visited (is there a cycle?)'
                )

    def to_digraph(self):
        '''
        Create a networkx digraph

        Every node is added as a graph node. Edges come from ``conn.nodes``
        for each connection attached to each node, so edge direction follows
        the order of ``Connection.nodes``.

        Returns
        -------
        digraph : networkx.DiGraph
            The generated DiGraph

        Raises
        ------
        ImportError
            If networkx is unavailable
        '''
        import networkx
        graph = networkx.DiGraph()
        graph.add_nodes_from(self._nodes.values())

        for node in self._nodes.values():
            graph.add_edges_from(conn.nodes
                                 for conn in node.state.all_connections)

        return graph

    def remove_node(self, node: Node):
        """
        Remove node

        Emits ``node_deleted``, deletes every connection attached to the
        node, cleans the node up and drops it from the scene.

        Parameters
        ----------
        node : Node
        """
        self.node_deleted.emit(node)
        for conn in list(node.state.all_connections):
            self.delete_connection(conn)

        node._cleanup()
        del self._nodes[node.id]

    def _restore_node(self, node_json: dict) -> Node:
        """
        Restore a node from a state dictionary

        Unlike ``restore_node``, ``node_placed`` is not emitted.

        Parameters
        ----------
        node_json : dict

        Returns
        -------
        value : Node
        """
        with self._new_node_context(node_json["model"]["name"]) as node:
            # Previously a no-op body: the node was created but its saved
            # state was never applied.
            node.__setstate__(node_json)

        return node

    @contextlib.contextmanager
    def _new_node_context(self, data_model_name, *, emit_placed=False):
        '''
        Context manager: creates Node/yields it, handling necessary Signals

        The node's data model is created from the registry by name. After the
        ``with`` body completes, the node is registered in the scene. If
        requested, ``node_placed`` is emitted, followed by ``node_created``.
        If the body raises, the node is not registered.

        NOTE: ``node_placed`` is not declared on this class; emit_placed=True
        relies on a subclass (e.g. FlowScene) providing it.
        '''
        data_model = self._registry.create(data_model_name)
        node = Node(data_model)
        yield node

        self._nodes[node.id] = node
        if emit_placed:
            self.node_placed.emit(node)
        self.node_created.emit(node)

    def restore_node(self, node_json: dict) -> Node:
        """
        Restore a node from a state dictionary

        Parameters
        ----------
        node_json : dict

        Returns
        -------
        value : Node
        """
        name = node_json["model"]["name"]
        with self._new_node_context(name, emit_placed=True) as node:
            node.__setstate__(node_json)

        return node

    def delete_connection(self, connection: Connection):
        """
        Delete connection

        Connections not tracked by the scene are ignored.

        Parameters
        ----------
        connection : Connection
        """
        try:
            self._connections.remove(connection)
        except ValueError:
            # Already removed (or never tracked): nothing to undo.
            pass
        else:
            connection.remove_from_nodes()
            connection._cleanup()
class FlowScene(FlowSceneModel, QGraphicsScene):
    '''
    A QGraphicsScene holding a flow graph of nodes and connections.

    Node and connection creation/deletion signals are inherited from
    FlowSceneModel; the signals below report interaction events.
    '''

    connection_hover_left = Signal(Connection)
    connection_hovered = Signal(Connection, QPoint)

    # Node has been added to the scene.
    # Connect to this signal if you need the correct position of the node.
    node_placed = Signal(Node)

    # node_context_menu(node, scene_position, screen_position)
    node_context_menu = Signal(Node, QPointF, QPoint)
    node_double_clicked = Signal(Node)
    node_hover_left = Signal(Node)
    node_hovered = Signal(Node, QPoint)
    node_moved = Signal(Node, QPointF)

    def __init__(self, registry=None, style=None, parent=None,
                 allow_node_creation=True, allow_node_deletion=True):
        '''
        Create a new flow scene

        Parameters
        ----------
        registry : DataModelRegistry, optional
            A new registry is created when None.
        style : StyleCollection, optional
            Defaults to ``style_module.default_style``.
        parent : QObject, optional
        allow_node_creation : bool, optional
        allow_node_deletion : bool, optional
        '''
        # Let FlowSceneModel handle the registry (including the None
        # default). This way a provided registry is kept even if it is falsy.
        super().__init__(registry=registry, parent=parent)
        if style is None:
            style = style_module.default_style
        self._style = style
        # Each flag is assigned from its own argument (previously swapped).
        self.allow_node_creation = allow_node_creation
        self.allow_node_deletion = allow_node_deletion
        self.setItemIndexMethod(QGraphicsScene.NoIndex)

    def _cleanup(self):
        'Remove all connections and nodes from the scene'
        self.clear_scene()

    @property
    def allow_node_creation(self):
        'Whether node creation is allowed (stored as bool)'
        return self._allow_node_creation

    @allow_node_creation.setter
    def allow_node_creation(self, allow):
        self._allow_node_creation = bool(allow)

    @property
    def allow_node_deletion(self):
        'Whether node deletion is allowed (stored as bool)'
        return self._allow_node_deletion

    @allow_node_deletion.setter
    def allow_node_deletion(self, allow):
        self._allow_node_deletion = bool(allow)

    @property
    def style_collection(self) -> style_module.StyleCollection:
        'The style collection for the scene'
        return self._style

    def locate_node_at(self, point, transform):
        '''
        Return the top-most node at ``point`` (scene coordinates), or None.
        '''
        return locate_node_at(point, self, transform)

    def create_connection(self, port_a: Port, port_b: Port = None, *,
                          converter: TypeConverter = None,
                          check_cycles=True) -> Connection:
        """
        Create a connection

        Parameters
        ----------
        port_a : Port
            The first port, either input or output
        port_b : Port, optional
            The second port, opposite of the type of port_a
        converter : TypeConverter, optional
            The type converter to use for data propagation
        check_cycles : bool, optional
            Ensures that creating the connection would not introduce a cycle

        Returns
        -------
        value : Connection

        Raises
        ------
        NodeConnectionFailure
            If it is not possible to create the connection
        ConnectionDataTypeFailure
            If port data types are not compatible
        ConnectionCycleFailure
            If check_cycles is set and the connection would introduce a cycle
        """
        if port_a is not None and port_b is not None:
            in_port = port_a if port_a.port_type == PortType.input else port_b
            out_port = port_b if port_a.port_type == PortType.input else port_a
            if in_port.data_type.id != out_port.data_type.id:
                if not converter:
                    # If not specified, try to get it from the registry
                    converter = self.registry.get_type_converter(
                        out_port.data_type, in_port.data_type)
                if (not converter
                        or converter.type_in != out_port.data_type
                        or converter.type_out != in_port.data_type):
                    raise ConnectionDataTypeFailure(
                        f'{in_port.data_type} and {out_port.data_type} are not compatible'
                    )

        connection = Connection(port_a=port_a, port_b=port_b,
                                style=self._style, converter=converter)
        if port_a is not None:
            port_a.add_connection(connection)
        if port_b is not None:
            port_b.add_connection(connection)

        if port_a and port_b and check_cycles:
            # In the case of a fully-specified connection, ensure adding the
            # connection would not create a cycle in the graph. For
            # partially-specified connections (i.e., one port only), the
            # validation happens in the NodeConnectionInteraction
            node_a, node_b = port_a.node, port_b.node
            if node_a.has_connection_by_port_type(node_b, port_b.port_type):
                # Detach the rejected connection from its ports, as
                # delete_connection does. Otherwise it would stay
                # attached while the scene does not track it.
                connection.remove_from_nodes()
                raise exceptions.ConnectionCycleFailure(
                    f'Connecting {node_a} and {node_b} would introduce a '
                    f'cycle in the graph'
                )

        cgo = ConnectionGraphicsObject(self, connection)
        # after self function connection points are set to node port
        connection.graphics_object = cgo
        self._connections.append(connection)

        if not port_a or not port_b:
            # This connection isn't truly created yet. It's only partially
            # created. Thus, don't send the connection_created(...) signal.
            connection.connection_completed.connect(self.connection_created.emit)
        else:
            in_port, out_port = connection.ports
            out_port.node.on_data_updated(out_port)
            self.connection_created.emit(connection)

        return connection

    def create_connection_by_index(
            self, node_in: Node, port_index_in: int,
            node_out: Node, port_index_out: int,
            converter: TypeConverter = None) -> Connection:
        """
        Create connection

        Parameters
        ----------
        node_in : Node
        port_index_in : int
            Index into node_in's input ports
        node_out : Node
        port_index_out : int
            Index into node_out's output ports
        converter : TypeConverter, optional
            When None, ``create_connection`` looks one up in the registry
            if the port data types differ.

        Returns
        -------
        value : Connection
        """
        port_in = node_in[PortType.input][port_index_in]
        port_out = node_out[PortType.output][port_index_out]
        return self.create_connection(port_out, port_in, converter=converter)

    def restore_connection(self, connection_json: dict) -> Connection:
        """
        Restore a connection.

        Parameters
        ----------
        connection_json : dict
            Must contain "in_id", "out_id", "in_index" and "out_index". It
            may contain "converter" with "in"/"out" type descriptions.

        Returns
        -------
        value : Connection
        """
        node_in_id = connection_json["in_id"]
        node_out_id = connection_json["out_id"]
        port_index_in = connection_json["in_index"]
        port_index_out = connection_json["out_index"]
        node_in = self._nodes[node_in_id]
        node_out = self._nodes[node_out_id]

        def get_converter():
            converter = connection_json.get("converter", None)
            if converter is None:
                return None

            in_type = NodeDataType(
                id=converter["in"]["id"],
                name=converter["in"]["name"],
            )

            out_type = NodeDataType(
                id=converter["out"]["id"],
                name=converter["out"]["name"],
            )
            return self._registry.get_type_converter(out_type, in_type)

        connection = self.create_connection_by_index(
            node_in, port_index_in,
            node_out, port_index_out,
            converter=get_converter())

        # Note: the connection_created(...) signal has already been sent by
        # create_connection(...)
        return connection

    def create_node(self, data_model: NodeDataModel) -> Node:
        """
        Create a node in the scene

        Only ``data_model.name`` is used; a new model instance is created
        from the registry under that name.

        Parameters
        ----------
        data_model : NodeDataModel

        Returns
        -------
        value : Node
        """
        with self._new_node_context(data_model.name) as node:
            ngo = NodeGraphicsObject(self, node)
            node.graphics_object = ngo

        return node

    def restore_node(self, node_json: dict) -> Node:
        """
        Restore a node from a state dictionary

        Parameters
        ----------
        node_json : dict

        Returns
        -------
        value : Node
        """
        # NOTE: Overrides FlowSceneModel.restore_node. The restored state
        # (including position) is applied before leaving the context, so
        # node_placed is emitted as in the base implementation.
        with self._new_node_context(node_json["model"]["name"],
                                    emit_placed=True) as node:
            node.graphics_object = NodeGraphicsObject(self, node)
            node.__setstate__(node_json)

        return node

    def auto_arrange(self, layout='bipartite', scale=700, align='horizontal',
                     **kwargs):
        '''
        Automatically arrange nodes with networkx, if available

        Parameters
        ----------
        layout : str, optional
            One of 'bipartite', 'circular', 'kamada_kawai', 'random',
            'shell', 'spring', 'spectral'.
        scale : float, optional
            Multiplier applied to the layout coordinates.
        align : str, optional
            Alignment passed to the bipartite layout (ignored otherwise).
        **kwargs
            Passed to the networkx layout function. For 'bipartite', a
            ``nodes`` entry defaults to the nodes with no incoming edges in
            the scene's digraph.

        Raises
        ------
        ImportError
            If networkx is unavailable
        ValueError
            If the layout type is unknown
        '''
        import networkx
        dig = self.to_digraph()

        layouts = {
            name: getattr(networkx.layout, '{}_layout'.format(name))
            for name in ('bipartite', 'circular', 'kamada_kawai', 'random',
                         'shell', 'spring', 'spectral')
        }

        try:
            layout_func = layouts[layout]
        except KeyError:
            raise ValueError('Unknown layout type {}'.format(layout)) from None

        if dig.number_of_nodes() == 0:
            # Nothing to arrange.
            return

        if layout == 'bipartite':
            # bipartite_layout requires the node set of one partition and
            # accepts an alignment; without these the default call fails.
            kwargs.setdefault(
                'nodes', [node for node in dig if dig.in_degree(node) == 0])
            kwargs.setdefault('align', align)

        positions = layout_func(dig, **kwargs)
        for node, pos in positions.items():
            pos_x, pos_y = pos
            node.position = (pos_x * scale, pos_y * scale)

    def selected_nodes(self) -> list:
        """
        Selected nodes

        Returns
        -------
        value : list of Node
        """
        return [item.node for item in self.selectedItems()
                if isinstance(item, NodeGraphicsObject)]