from gpt_graph.utils.uuid_ex import uuid_ex
import copy
from gpt_graph.utils.get_nested_value import get_nested_value
class Group:
    """
    A group of nodes selected from a node graph and split into sub-groups.

    ``run`` selects nodes with ``filter_cri``, optionally buckets them under
    parent nodes (``parent_filter_cri`` / ``parent_group_key``), splits each
    bucket by ``group_key`` and stores the resulting child ``Group`` objects
    in ``self.contains``. ``get_nodes`` exposes those sub-groups' nodes.
    """

    def __init__(
        self,
        nodes=None,
        node_graph=None,
        filter_cri=None,
        group_key=None,
        parent_filter_cri=None,
        parent_group_key=None,
        name=None,
        type=None,
        gid=None,
        if_yield=None,
    ):
        """
        Args:
            nodes: Initial node list. ``None`` makes ``get_nodes`` call
                ``run`` first.
            node_graph: Object providing ``filter_nodes``; required by ``run``.
            filter_cri: Criteria handed to ``node_graph.filter_nodes``.
            group_key: Attribute path (str) or ``{path: func}`` dict used to
                split nodes into sub-groups.
            parent_filter_cri: Criteria selecting parent nodes.
            parent_group_key: Attribute path used to bucket parent nodes;
                defaults to ``"node_id"``.
            name: Group name, default ``"group"``; sub-groups are named
                ``"<name>.gp<index>"``.
            type: Type tag (opaque here), copied onto sub-groups.
            gid: Group id; sub-groups get ``"<gid>.<index>"`` (or just
                ``"<index>"`` when ``gid`` is falsy).
            if_yield: If truthy, ``get_nodes`` hands out one sub-group's nodes
                per call instead of all of them.
        """
        self.uuid = uuid_ex(obj=self)
        self.nodes = nodes
        self.node_graph = node_graph
        self.filter_cri = filter_cri
        self.group_key = group_key
        self.parent_filter_cri = parent_filter_cri
        self.parent_group_key = parent_group_key or "node_id"
        self.name = name or "group"
        self.type = type
        self.gid = gid
        # Starts as an empty dict; run() replaces it with a list of sub-Groups.
        self.contains = {}
        self.clones = []
        self.prototype = None
        # None lets run() fall back to refreshing nodes from the graph.
        self.if_refresh = None
        self.if_yield = if_yield or False
        # Cursor for get_nodes(if_yield=True); None means "not started".
        self._sub_group_index = None

    def initialize(self):
        """
        Set nodes and contains back to original empty state.

        Only ``nodes`` and ``contains`` are reset; the ``get_nodes`` yield
        cursor (``_sub_group_index``) and all configuration are left as is.
        """
        self.nodes = None
        self.contains = {}

    def clone(self, if_copy_nodes=False):
        """
        Create a copy of this Group and register it in ``self.clones``.

        Resembles ``Component.clone``. The clone is built with ``Group()``, so
        only the attributes listed in ``deep_copy_keys`` (plus ``nodes`` when
        ``if_copy_nodes`` is true) are deep-copied from ``self``; everything
        else (uuid, contains, clones, if_yield, if_refresh, ...) keeps the
        constructor defaults.

        Args:
            if_copy_nodes: Also deep-copy ``self.nodes`` into the clone.

        Returns:
            Group: The new clone, whose ``prototype`` is ``self``.
        """
        deep_copy_keys = [
            "filter_cri",
            "group_key",
            "parent_filter_cri",
            "parent_group_key",
            "name",
            "type",
            "gid",
        ]
        if if_copy_nodes:
            deep_copy_keys.append("nodes")
            # NOTE(review): this flag is set on the prototype (self), not on
            # the clone, and "if_refresh" is not copied — confirm intended.
            self.if_refresh = False
        new_group = Group()
        # Copy all attributes except nodes, contains, prototype, and clones
        for attr, value in vars(self).items():
            if attr in deep_copy_keys:
                setattr(new_group, attr, copy.deepcopy(value))
        # Link the clone back to this group and record it.
        new_group.prototype = self
        self.clones.append(new_group)
        return new_group

    def reset_uuid(self, if_recursive=False):
        """Delegate to ``Closure.reset_uuid`` for this group and return its result."""
        # Local import, presumably to avoid a circular import with closure.
        from gpt_graph.core.closure import Closure

        return Closure.reset_uuid(self, if_recursive=if_recursive)

    def run(
        self,
        filter_cri=None,
        group_key=None,  # can both be str or dict[str: func] for grouping criteria
        parent_filter_cri=None,
        parent_group_key=None,
        nodes=None,
        node_graph=None,
        if_refresh=None,
    ):
        """
        Execute the grouping process and store the resulting sub-groups.

        Each of ``filter_cri``, ``group_key``, ``parent_filter_cri``,
        ``parent_group_key``, ``nodes`` and ``node_graph`` that is not
        ``None`` replaces the stored attribute, so settings persist across
        calls.

        Steps:
            1-2. If ``parent_filter_cri`` is set, select parent nodes and
                 bucket them by ``parent_group_key``.
            3.   Re-select nodes with ``filter_cri`` (when refreshing) or keep
                 ``self.nodes``.
            4.   Bucket the nodes under each parent bucket, or put them all in
                 a single bucket.
            5-6. Split every bucket by ``group_key`` and wrap each piece in a
                 child ``Group``.

        Args:
            filter_cri (dict, optional): Criteria passed to
                ``node_graph.filter_nodes`` to select the nodes to group.
            group_key (str or dict, optional): Attribute path to group by, or
                a ``{path: func}`` dict. With a dict, a node's key is the tuple
                of ``func(value_at_path)`` over all entries, so every node
                lands in exactly one sub-group. A falsy key (``None``, ``""``,
                ``{}``) keeps each bucket as a single sub-group.
            parent_filter_cri (dict, optional): Criteria selecting parent nodes.
            parent_group_key (str, optional): Attribute path used to bucket
                parent nodes.
            nodes (list, optional): Nodes to store on the group. They are only
                used as the working set when the effective ``if_refresh`` is
                false; otherwise they are replaced by
                ``filter_nodes(filter_cri)``.
            node_graph (object, optional): Graph exposing ``filter_nodes``.
            if_refresh (bool, optional): Whether to re-select nodes with
                ``filter_cri``. Defaults to ``self.if_refresh``, or ``True``
                when that is ``None``. Not stored on the instance.

        Returns:
            list: One ``Group`` per final bucket (also stored in
            ``self.contains``), numbered consecutively across all buckets.

        Raises:
            ValueError: If no ``node_graph`` is given or stored.
        """
        if nodes is not None:
            self.nodes = nodes
        if node_graph is not None:
            self.node_graph = node_graph
        if self.node_graph is None:
            raise ValueError("node_graph must be provided")
        # Same "is not None" rule as the other settings, so an explicitly
        # passed falsy criterion (e.g. {}) replaces the stored one.
        if filter_cri is not None:
            self.filter_cri = filter_cri
        if parent_filter_cri is not None:
            self.parent_filter_cri = parent_filter_cri
        if group_key is not None:
            self.group_key = group_key
        if parent_group_key is not None:
            self.parent_group_key = parent_group_key

        # Steps 1-2: parents are selected before the children (original order).
        parent_groups = self._group_parents() if self.parent_filter_cri else {}

        # Step 3: refresh the working node list unless told not to.
        if if_refresh is None:
            if_refresh = True if self.if_refresh is None else self.if_refresh
        if if_refresh:
            self.nodes = self.node_graph.filter_nodes(self.filter_cri)

        # Step 4: bucket nodes under parent keys, or keep one bucket.
        if self.parent_filter_cri:
            child_groups = self._assign_to_parents(parent_groups)
        else:
            child_groups = {"all": self.nodes}

        # Steps 5-6: split each bucket and wrap the pieces; the sub-group
        # index is the running count of groups created so far.
        final_grouped_nodes = []
        for bucket_nodes in child_groups.values():
            for subgroup_nodes in self._split_by_key(bucket_nodes):
                final_grouped_nodes.append(
                    self._make_subgroup(subgroup_nodes, len(final_grouped_nodes))
                )
        self.contains = final_grouped_nodes
        return self.contains

    def _group_parents(self):
        """Bucket nodes matching ``parent_filter_cri`` by ``parent_group_key``."""
        parent_groups = {}
        for parent in self.node_graph.filter_nodes(self.parent_filter_cri):
            key_value = get_nested_value(parent, self.parent_group_key)
            parent_groups.setdefault(key_value, []).append(parent)
        return parent_groups

    def _assign_to_parents(self, parent_groups):
        """
        Map each parent key to the current nodes that ``node_graph`` relates
        to that key's parents.

        The relation is whatever ``filter_nodes(..., parents=...,
        if_inclusive=True)`` implements; the candidates are restricted to the
        ``node_id`` values in ``self.nodes``.
        """
        node_ids = [n["node_id"] for n in self.nodes]
        child_groups = {}
        for key, parents in parent_groups.items():
            child_groups[key] = list(
                self.node_graph.filter_nodes(
                    filter_cri={"node_id": {"$in": node_ids}},
                    if_inclusive=True,
                    parents=parents,
                )
            )
        return child_groups

    def _subgroup_key(self, node):
        """Return the hashable key deciding which sub-group ``node`` joins."""
        if isinstance(self.group_key, dict):
            # One composite key per node: several criteria refine the grouping
            # instead of copying the node into one sub-group per criterion.
            return tuple(
                func(get_nested_value(node, path))
                for path, func in self.group_key.items()
            )
        return get_nested_value(node, self.group_key)

    def _split_by_key(self, nodes):
        """
        Split ``nodes`` into lists by ``self.group_key``, in first-seen key
        order. A falsy ``group_key`` returns ``[nodes]`` unchanged.
        """
        if not self.group_key:
            return [nodes]
        buckets = {}
        for node in nodes:
            buckets.setdefault(self._subgroup_key(node), []).append(node)
        return list(buckets.values())

    def _make_subgroup(self, nodes, index):
        """Build the ``index``-th child Group sharing graph, type and gid prefix."""
        return Group(
            nodes=nodes,
            node_graph=self.node_graph,
            name=f"{self.name}.gp{index}",
            type=self.type,
            gid=f"{self.gid}.{index}" if self.gid else str(index),
        )

    def get_nodes(self, if_yield=None):
        """
        Retrieve the sub-groups' nodes.

        Used in ``Step.run`` as input to ``Step.cp_run_func`` when the Group is
        added in ``Step.input_schema``. Calls ``run`` first if ``self.nodes``
        is ``None``.

        Args:
            if_yield (bool, optional): Hand out one sub-group per call.
                Defaults to ``self.if_yield``.

        Returns:
            If ``if_yield`` is true: the next sub-group's node list, or
            ``None`` once every sub-group has been handed out (the cursor is
            not rewound).
            Otherwise: a list with one node list per sub-group, recomputed on
            every call.
        """
        if if_yield is None:
            if_yield = self.if_yield
        if self.nodes is None:
            self.run(node_graph=self.node_graph)
        if not if_yield:
            # Non-yield mode: return all sub-group node lists every time.
            return [sub_group.nodes for sub_group in self.contains]
        if self._sub_group_index is None:
            self._sub_group_index = 0
        if self._sub_group_index >= len(self.contains):
            return None
        nodes = self.contains[self._sub_group_index].nodes
        self._sub_group_index += 1
        return nodes
if __name__ == "__main__":
    # Small self-contained demo of Group on an in-memory node graph.

    class NodeGraph:
        """Toy node graph: a flat list of dict nodes with equality filtering."""

        def __init__(self):
            self.nodes = []

        def add_node(self, node):
            self.nodes.append(node)

        def filter_nodes(self, filter_cri):
            """
            Return nodes whose fields equal every key/value in ``filter_cri``.

            ``Group.run`` passes ``None`` when no filter is set, so ``None`` is
            treated as "match everything". Parent-based filtering (the
            ``parents`` / ``if_inclusive`` keywords) is not modelled here.
            """
            criteria = filter_cri or {}
            return [
                node
                for node in self.nodes
                if all(node.get(k) == v for k, v in criteria.items())
            ]

    def print_subgroups(title, subgroups):
        """Print a title, then each sub-group's gid and its node_ids."""
        print(title)
        for subgroup in subgroups:
            print(f"Subgroup {subgroup.gid}: {[node['node_id'] for node in subgroup.nodes]}")

    # Create a sample node graph
    node_graph = NodeGraph()
    node_graph.add_node({"node_id": 1, "type": "A", "value": 10})
    node_graph.add_node({"node_id": 2, "type": "A", "value": 20})
    node_graph.add_node({"node_id": 3, "type": "B", "value": 30})
    node_graph.add_node({"node_id": 4, "type": "B", "value": 40})

    # Example 1: Simple grouping by type
    group = Group(node_graph=node_graph)
    print_subgroups("Example 1: Grouping by type", group.run(group_key="type"))

    # Example 2: Filtering and grouping (the filter is now stored on `group`)
    print_subgroups(
        "\nExample 2: Filtering type 'A' and grouping by value",
        group.run(filter_cri={"type": "A"}, group_key="value"),
    )

    # Example 3: Complex grouping using a dictionary of functions.
    # A fresh Group is used because Group.run keeps the previous filter_cri
    # when none is passed, which would restrict this example to type 'A'.
    def value_category(value):
        return "Low" if value < 25 else "High"

    print_subgroups(
        "\nExample 3: Grouping by value category (Low/High)",
        Group(node_graph=node_graph).run(group_key={"value": value_category}),
    )

    # Example 4: Using get_nodes method (fresh, unfiltered Group)
    type_group = Group(node_graph=node_graph)
    type_group.run(group_key="type")  # Run grouping first
    print("\nExample 4: Using get_nodes method")
    for i, subgroup_nodes in enumerate(type_group.get_nodes(if_yield=False)):
        print(f"Subgroup {i}: {[node['node_id'] for node in subgroup_nodes]}")

    # Example 5: Using get_nodes method with yield
    print("\nExample 5: Using get_nodes method with yield")
    type_group.if_yield = True
    while True:
        nodes = type_group.get_nodes()
        if nodes is None:
            break
        print(f"Yielded nodes: {[node['node_id'] for node in nodes]}")