Source code for generation.generator

import numpy as np
import itertools
import sys
from mpi4py import MPI
import sympy
from sympy.core.sympify import kernS
import os
import pprint

import esr.generation.simplifier as simplifier
import esr.generation.utils as utils
from esr.fitting.sympy_symbols import sympy_locs

# Module-wide MPI state. get_allowed_shapes does its enumeration on rank 0
# only and then broadcasts the result over ``comm`` to every other rank.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()


def is_float(string):
    """Determine whether a value represents a real number.

    Strings (and bytes) are parsed as arithmetic on numeric literals only,
    e.g. ``"2"``, ``"-0.5"``, ``"1/2"`` or ``"2**-3"``. Names, calls and
    attribute access are rejected, so arbitrary code is never executed
    (the previous implementation used ``eval``, which also treated module
    globals such as ``rank`` as numbers). Non-string inputs, such as the
    ``int``/``float`` values produced by ``Node.assign_op``, are accepted
    whenever ``float()`` accepts them.

    Args:
        :string (str): The string (or already-numeric value) to check

    Returns:
        bool: Whether the value is a float (True) or not (False).
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _evaluate(node):
        # Walk the parsed expression, allowing only numeric literals and
        # the arithmetic operators above.
        if isinstance(node, ast.Expression):
            return _evaluate(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            return binary_ops[type(node.op)](
                _evaluate(node.left), _evaluate(node.right))
        raise ValueError("not an arithmetic expression of numeric literals")

    try:
        if isinstance(string, (str, bytes)):
            value = _evaluate(ast.parse(string.strip(), mode="eval"))
        else:
            value = string
        # float() rejects complex results and overflowing integers.
        float(value)
        return True
    except Exception:
        # Syntax errors, division by zero, overflow, disallowed nodes, ...
        return False
class Node:
    """One vertex of a candidate expression tree (see check_tree).

    ``parent``, ``left`` and ``right`` hold integer positions in the list of
    nodes that make up the tree, or ``None`` when the link is absent.
    """

    def __init__(self, t):
        # Arity of the node: 0 (nullary), 1 (unary) or 2 (binary).
        self.type = t
        # Links to other nodes, stored as indices into the owning list.
        self.parent = self.left = self.right = None
        # Label information; op/val are filled in by assign_op.
        self.op = self.val = None
        self.node_name = None
        self.tree = None

    def copy(self):
        """Return a new Node with the same type and parent/child links.

        Only ``type``, ``parent``, ``left`` and ``right`` are carried over;
        the label attributes of the copy keep their defaults.
        """
        clone = Node(self.type)
        clone.parent, clone.left, clone.right = self.parent, self.left, self.right
        return clone

    def is_used(self):
        """Report whether this node has been linked into a tree.

        A leaf counts as used once it has a parent, a unary node once it has
        a left child, and a binary node once it has either child. Nodes of
        any other type are always reported as used.
        """
        if self.type == 0:
            return self.parent is not None
        if self.type == 1:
            return self.left is not None
        if self.type == 2:
            return not (self.left is None and self.right is None)
        return True

    def assign_op(self, op):
        """Attach the label ``op`` to this node and derive ``val`` from it.

        Labels made of digits (optionally signed) become ``int`` values.
        Labels that pass the ``str.isnumeric`` test after stripping leading
        ``-`` and ``/`` are passed to ``float``. Anything else is kept as the
        string itself.
        """
        self.op = op
        unsigned = op.lstrip("-")
        if unsigned.isdigit():
            self.val = int(op)
        elif unsigned.lstrip("/").isnumeric():
            self.val = float(op)
        else:
            self.val = op
class DecoratedNode:
    """Tree representation of a function, built from a sympy expression
    (via the constructor) or from a list of Node objects (via
    from_node_list). Each instance records its operator name, value and
    links to its parent and children.
    """

    def __init__(self, fun, basis_functions, parent_op=None, parent=None):
        """Recursively build a node (and its subtree) from sympy object ``fun``.

        Args:
            :fun (sympy object or None): expression for this node; when None,
                no attributes are set here (from_node_list fills them later)
            :basis_functions (list): list of lists basis functions.
                basis_functions[0] are nullary, basis_functions[1] are unary
                and basis_functions[2] are binary operators
            :parent_op (str): op name of the parent node (None at the root)
            :parent (DecoratedNode): parent node (None at the root)
        """
        if fun is not None:
            self.expr = fun
            self.type = type(fun)
            self.constant = fun.is_number
            # Number of sympy arguments; note this is NOT updated when the
            # Pow special cases below keep only one child.
            self.degree = len(fun.args)
            self.op = fun.__class__.__name__
            self.parent_op = parent_op
            self.parent = parent
            self.tree = None
            if self.constant:
                self.val = str(fun)
            elif fun.is_symbol:
                self.val = fun.name
            else:
                self.val = None
            # Re-express some powers / products as dedicated unary or binary
            # operators when the basis allows it. Children receive the
            # (possibly renamed) op as their parent_op.
            if self.op == 'Pow' and fun.args[1] == 2 and 'square' in basis_functions[1]:
                self.op = 'Square'
                self.children = [DecoratedNode(
                    fun.args[0], basis_functions, parent_op=self.op, parent=self)]
            elif self.op == 'Pow' and fun.args[1] == 3 and 'cube' in basis_functions[1]:
                self.op = 'Cube'
                self.children = [DecoratedNode(
                    fun.args[0], basis_functions, parent_op=self.op, parent=self)]
            elif self.op == 'Pow' and fun.args[1] == 1/2 and ('sqrt' in basis_functions[1] or 'sqrt_abs' in basis_functions[1]):
                self.op = 'Sqrt'
                self.children = [DecoratedNode(
                    fun.args[0], basis_functions, parent_op=self.op, parent=self)]
            # a * b**(-1) is treated as the division a / b (no basis check here)
            elif self.op == 'Mul' and len(fun.args) == 2 and fun.args[1].__class__.__name__ == 'Pow' and fun.args[1].args[1] == -1:
                self.op = 'Div'
                self.children = [DecoratedNode(fun.args[0], basis_functions, parent_op=self.op, parent=self),
                                 DecoratedNode(fun.args[1].args[0], basis_functions, parent_op=self.op, parent=self)]
            elif self.op == 'Pow' and fun.args[1] == -1 and 'inv' in basis_functions[1]:
                self.op = 'Inv'
                self.children = [DecoratedNode(
                    fun.args[0], basis_functions, parent_op=self.op, parent=self)]
            else:
                if (len(fun.args) > 2):
                    # Split n-ary operations into a binary node using
                    # sympy's as_two_terms().
                    f = fun.as_two_terms()
                    self.children = [DecoratedNode(f[0], basis_functions, parent_op=self.op, parent=self),
                                     DecoratedNode(f[1], basis_functions, parent_op=self.op, parent=self)]
                else:
                    self.children = [DecoratedNode(
                        a, basis_functions, parent_op=self.op, parent=self) for a in fun.args]

    def from_node_list(self, idx, nodes, basis_functions, parent_op=None, parent=None):
        """Populate this node (and recursively its subtree) from ``nodes[idx]``.

        Args:
            :idx (int): index of the Node in ``nodes`` that this object represents
            :nodes (list): list of Node objects; left/right are indices into it
            :basis_functions (list): list of lists basis functions (passed on
                to the children's constructors)
            :parent_op (str): op of the parent node (None at the root)
            :parent (DecoratedNode): parent node (None at the root)
        """
        self.expr = nodes[idx].op
        self.type = nodes[idx].type
        self.constant = is_float(nodes[idx].val)
        # For Node objects the arity (0, 1 or 2) doubles as the degree.
        self.degree = nodes[idx].type
        self.op = nodes[idx].op
        self.parent_op = parent_op
        self.parent = parent
        self.tree = None
        if self.constant:
            self.val = str(nodes[idx].val)
        else:
            self.val = nodes[idx].val
        if nodes[idx].right is not None:
            self.children = [DecoratedNode(
                None, basis_functions), DecoratedNode(None, basis_functions)]
            self.children[0].from_node_list(
                nodes[idx].left, nodes, basis_functions, parent_op=nodes[idx].op, parent=self)
            self.children[1].from_node_list(
                nodes[idx].right, nodes, basis_functions, parent_op=nodes[idx].op, parent=self)
        elif nodes[idx].left is not None:
            self.children = [DecoratedNode(None, basis_functions)]
            self.children[0].from_node_list(
                nodes[idx].left, nodes, basis_functions, parent_op=nodes[idx].op, parent=self)
        else:
            self.children = []

    def is_unity(self):
        """Return True if ``self.val`` converts to a float equal to 1."""
        try:
            f = float(self.val)
            return f == float(1)
        except Exception:
            # val is None or a non-numeric label
            return False

    def count_nodes(self, basis_functions):
        """Count the nodes of the subtree as rendered by to_list.

        Args:
            :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

        Returns:
            :int: length of ``self.to_list(basis_functions)``
        """
        return len(self.to_list(basis_functions))

    def to_list(self, basis_functions):
        """Flatten the subtree into a prefix-order list of labels.

        The operator comes before its operands. Several basis-dependent
        rewrites are applied (e.g. pow(x, 2) -> square, * by one dropped,
        a + (-1)*b -> Sub). The branches are checked in order, so the
        first matching rule wins.

        Args:
            :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

        Returns:
            :list: labels (str) of the subtree in prefix order
        """
        if self.degree == 0:
            return [str(self.val)]
        elif self.degree == 1:
            return [self.op] + self.children[0].to_list(basis_functions)
        # Sqrt(x) instead of pow(x, 1/2)
        elif self.op == "Pow" and (self.children[1].type == sympy.core.numbers.Half) and (("sqrt" in basis_functions[1]) or ("sqrt_abs" in basis_functions[1])):
            if ("sqrt" in basis_functions[1]):
                return ["sqrt"] + self.children[0].to_list(basis_functions)
            else:
                return ["sqrt_abs"] + self.children[0].to_list(basis_functions)
        # Square(x) instead of pow(x, 2) if possible
        elif self.op == "Pow" and (self.children[1].val == str(2)) and "square" in basis_functions[1]:
            return ["square"] + self.children[0].to_list(basis_functions)
        # pow(x,2) instead of Square(x) if necessary
        elif self.op == "Square" and "square" not in basis_functions[1]:
            return ["pow"] + self.children[0].to_list(basis_functions) + ["2"]
        # Cube(x) instead of pow(x, 3)
        elif self.op == "Pow" and (self.children[1].val == str(3)) and "cube" in basis_functions[1]:
            return ["cube"] + self.children[0].to_list(basis_functions)
        # pow(x,3) instead of Cube(x) if necessary
        elif self.op == "Cube" and "cube" not in basis_functions[1]:
            return ["pow"] + self.children[0].to_list(basis_functions) + ["3"]
        # Inv(x) instead of pow(x, -1)
        elif self.op == "Pow" and (self.children[1].type == sympy.core.numbers.NegativeOne) and ("inv" in basis_functions[1]):
            return ["Inv"] + self.children[0].to_list(basis_functions)
        # Deal with * inv = /
        # NOTE(review): this and the next branch return only children[1]'s
        # labels, dropping children[0]; confirm this is intended.
        elif self.op == "Mul" and self.children[0].op == "Pow" and (self.children[1].type == sympy.core.numbers.NegativeOne) and ("/" in basis_functions[2]):
            return ["Mul"] + self.children[1].to_list(basis_functions)
        # Deal with / inv = *
        elif self.op == "Div" and self.children[0].op == "Pow" and (self.children[1].type == sympy.core.numbers.NegativeOne) and ("*" in basis_functions[2]):
            return ["Mul"] + self.children[1].to_list(basis_functions)
        # Multiply or divide by one doesn't do anything
        elif self.op == "Mul" and (self.children[0].is_unity() or self.children[1].is_unity()):
            if self.children[0].is_unity():
                return self.children[1].to_list(basis_functions)
            else:
                return self.children[0].to_list(basis_functions)
        # Don't keep abs after pow or sqrt
        # NOTE(review): a one-argument Abs is already handled by the
        # degree == 1 branch above, and self.parent may be None at the root;
        # this branch looks unreachable in practice.
        elif self.op == "Abs" and self.parent.op in ["Sqrt", "Pow"]:
            return self.children[0].to_list(basis_functions)
        # NOTE(review): children are DecoratedNode objects, which never
        # compare equal to 1, so this branch cannot fire; if it did, the
        # method would return None.
        elif self.op == "Div" and (self.children[0] == 1 or self.children[1] == 1):
            pass
        # a + (-1)*b  ->  Sub(a, b)
        elif self.op == "Add" and self.children[1].op == "Mul" and (self.children[1].children[0].op == "NegativeOne" or self.children[1].children[1].op == "NegativeOne"):
            if self.children[1].children[0].op == "NegativeOne":
                return ["Sub"] + self.children[0].to_list(basis_functions) + self.children[1].children[1].to_list(basis_functions)
            else:
                return ["Sub"] + self.children[0].to_list(basis_functions) + self.children[1].children[0].to_list(basis_functions)
        else:
            # Generic case: operator followed by each child's labels
            r = [self.op]
            for c in self.children:
                r = r + c.to_list(basis_functions)
            return r

    def get_lineage(self):
        """Return, for every node of the subtree (pre-order), its lineage.

        Each lineage is a tuple running from the root's parent_op (None)
        down through the ancestors' ops to this node's own op.

        Returns:
            :list: list of tuples, one per node in the subtree
        """
        p = [self.op, self.parent_op]
        q = self.parent
        while q is not None:
            p += [q.parent_op]
            q = q.parent
        p.reverse()
        p = [tuple(p)]
        for c in self.children:
            p += c.get_lineage()
        return p

    def get_sibling_lineage(self):
        """Return lineages of ops and of vals, extended with the children.

        For this node, the root-to-node chain of ops (and the matching chain
        of vals) is followed by a tuple of the children's ops (vals). A single
        child is padded with None. The same is then collected recursively for
        every child that itself has children.

        Returns:
            :p (list): list of tuples of op names
            :v (list): list of tuples of values, aligned with ``p``
        """
        # First get direct lineage of nodes
        p = [self.op, self.parent_op]
        v = [self.val]
        if self.parent_op is None:
            v += [None]
        else:
            v += [self.parent.val]
        q = self.parent
        while q is not None:
            p += [q.parent_op]
            if q.parent_op is None:
                v += [None]
            else:
                v += [q.parent.val]
            q = q.parent
        p.reverse()
        v.reverse()
        # Now add the siblings
        if len(self.children) == 1:
            p += [(self.children[0].op, None)]
            v += [(self.children[0].val, None)]
        elif len(self.children) > 1:
            p += [tuple([c.op for c in self.children])]
            v += [tuple([c.val for c in self.children])]
        p = [tuple(p)]
        v = [tuple(v)]
        # Leaves get no entry of their own
        for c in self.children:
            if len(c.children) > 0:
                pp, vv = c.get_sibling_lineage()
                p += pp
                v += vv
        return p, v

    def get_siblings(self):
        """Return, for every node of the subtree (pre-order), its sibling ops.

        A node whose parent has more than one child contributes the tuple of
        all its parent's children's ops. Otherwise it contributes
        ``(self.op, 'None')``; note the string 'None', not the None object.

        Returns:
            :list: list of tuples of op names
        """
        if self.parent is not None and len(self.parent.children) > 1:
            p = [tuple([c.op for c in self.parent.children])]
        else:
            p = [(self.op, 'None')]
        for c in self.children:
            p += c.get_siblings()
        return p
def check_tree(s):
    """Try to assemble a tree from a pre-order sequence of node arities.

    Nodes are linked in order. A unary or binary node adopts the next node
    as its left child. A nullary node hands the next node to the nearest
    ancestor that is binary and still lacks a right child. Afterwards every
    unary/binary node must have a left child and every binary node a right
    child.

    Args:
        :s (sequence): arities in pre-order, each 0 (nullary), 1 (unary) or
            2 (binary). May be a string-like sequence or an integer array row;
            get_allowed_shapes passes a row of a numpy integer array.

    Returns:
        :success (bool): whether candidate string can form a valid tree (True) or not (False)
        :part_considered: None when ``len(s) <= 1``. Otherwise the prefix of
            ``s`` up to and including the first node that could not be
            attached, or all of ``s`` when every node was attached. Any
            candidate starting with this prefix fails in the same way.
        :tree (list): list of Node objects corresponding to string s
    """
    tree = [Node(arity) for arity in s]
    length = len(s)
    if length <= 1:
        return True, None, tree

    attached = False
    for pos in range(length - 1):
        nxt = pos + 1
        current = tree[pos]
        if current.type == 2 or current.type == 1:
            # Operators always take the following node as their left child
            current.left = nxt
            tree[nxt].parent = pos
            attached = True
            continue

        # A leaf: climb towards the root looking for an open right slot
        attached = False
        anc = current.parent
        while True:
            holder = tree[anc]
            if holder.type == 2 and holder.right is None:
                holder.right = nxt
                tree[nxt].parent = anc
                attached = True
                break
            if holder.parent is None:
                break
            anc = holder.parent
        if not attached:
            break

    # ``pos`` is the last position examined by the loop above
    part_considered = s[:pos + 2]

    if attached and any(node.left is None
                        for node in tree if node.type == 1 or node.type == 2):
        attached = False
    if attached and any(node.right is None
                        for node in tree if node.type == 2):
        attached = False

    return attached, part_considered, tree
def get_allowed_shapes(compl):
    """Find the shapes of all allowed trees containing compl nodes.

    Rank 0 enumerates every arity string of length ``compl`` and filters it.
    The result is then broadcast with ``comm.bcast``, so every MPI rank must
    call this function.

    Args:
        :compl (int): complexity of tree = number of nodes

    Returns:
        :cand (np.ndarray): 2D integer array with one row per valid tree.
            Each row is a pre-order sequence of 0, 1 and 2 for nullary,
            unary and binary nodes.
    """
    if rank == 0:
        # Make all graphs with this complexity
        cand = np.array([list(t) for t in itertools.product(
            '012', repeat=compl)], dtype=int)

        # Graph cannot start with a type0 node
        if compl > 1:
            cand = cand[cand[:, 0] != 0]

        # Graph must end at a type0 node
        cand = cand[cand[:, -1] == 0]

        # The penultimate node cannot be of type2
        if cand.shape[1] > 1:
            cand = cand[cand[:, -2] != 2]

        msk = np.ones(cand.shape[0], dtype=bool)
        for i in range(cand.shape[0]):
            if not msk[i]:
                # Already excluded because it shares a failing prefix with an
                # earlier candidate; re-checking it cannot change the mask.
                continue
            success, part_considered, _ = check_tree(cand[i, :])
            if not success:
                msk[i] = False
                # Remove every other candidate that starts with the same
                # failing prefix.
                m = np.all(
                    cand[:, :len(part_considered)] == part_considered[None, :],
                    axis=1)
                msk[m] = False

        cand = cand[msk, :]
    else:
        cand = None

    cand = comm.bcast(cand, root=0)

    return cand
def node_to_string(idx, tree, labels):
    """Render the subtree rooted at ``idx`` as a function string.

    Arithmetic binary operators (``*``, ``/``, ``-``, ``+``) are written infix
    with both operands parenthesised. Any other binary label is written as a
    two-argument call ``label(left,right)`` and unary labels as
    ``label(arg)``.

    Args:
        :idx (int): index of tree to consider
        :tree (list): list of Node objects corresponding to the tree
        :labels (list): list of strings giving node labels of tree

    Returns:
        Function as a string ('0' for an empty tree; None for a node type
        other than 0, 1 or 2)
    """
    if len(tree) == 0:
        return '0'

    node = tree[idx]
    if node.type == 0:
        return labels[idx]

    if node.type == 1:
        inner = node_to_string(node.left, tree, labels)
        return labels[idx] + '(' + inner + ')'

    if node.type == 2:
        lhs = node_to_string(node.left, tree, labels)
        rhs = node_to_string(node.right, tree, labels)
        op = labels[idx]
        if op in ('*', '/', '-', '+'):
            return ''.join(['(', lhs, ')', op, '(', rhs, ')'])
        return ''.join([op, '(', lhs, ',', rhs, ')'])

    return None
def string_to_expr(s, kern=False, evaluate=False, locs=None):
    """Convert a string giving function into a sympy object

    Square brackets are turned into parentheses, ``Sqrt`` into ``sqrt`` and
    Mathematica-style ``*^`` exponents into ``*10^`` before parsing.

    Args:
        :s (str): string representation of the function considered
        :kern (bool): whether to use sympy's kernS function or sympify
        :evaluate (bool): whether to use powsimp, factor and subs
        :locs (dict): dictionary of string:sympy objects. If None, will use
            sympy_locs. Only used by sympify; kernS takes no locals.

    Returns:
        :expr (sympy object): expression corresponding to s
    """
    s = s.replace('[', '(')
    s = s.replace(']', ')')
    s = s.replace('Sqrt', 'sqrt')
    s = s.replace('*^', '*10^')

    if locs is None:
        locs = sympy_locs

    if kern:
        expr = kernS(s)
    else:
        expr = sympy.sympify(s, evaluate=False, locals=locs)

    if evaluate:
        # powsimp's second positional parameter is ``deep``. Passing the
        # expression there only worked because it is truthy, so request
        # deep simplification explicitly.
        expr = expr.powsimp(deep=True)
        expr = expr.factor()
        # Replace float ones left over from parsing with exact integers
        expr = expr.subs(1.0, 1)

    return expr
def check_operators(nodes, basis_functions):
    """Check whether all operators in the tree are in the basis

    Each label from ``nodes.to_list`` is mapped to its basis spelling:
    - Add/Sub/Mul/Div become the matching symbol when that symbol is in the
      binary basis.
    - Numbers, sympy numeric names, and ``a<digits>`` become ``'a'``.
    - ``x<digits>`` becomes ``'x'``.
    - Anything else is lower-cased.

    Args:
        :nodes (DecoratedNode): Node representation of the function tree
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

    Returns:
        :all_in_basis (bool): Whether all functions in tree are in basis
    """
    numeric_names = [name.lower() for name in (
        'Number', 'Float', 'Rational', 'Integer', 'AlgebraicNumber',
        'NumberSymbol', 'RealNumber', 'igcd', 'ilcm', 'seterr', 'Zero',
        'One', 'NegativeOne', 'Half', 'NaN', 'Infinity', 'NegativeInfinity',
        'ComplexInfinity', 'Exp1', 'ImaginaryUnit', 'Pi', 'EulerGamma',
        'Catalan', 'GoldenRatio', 'TribonacciConstant', 'mod_inverse')]
    binary_spellings = (('Add', '+'), ('Sub', '-'), ('Mul', '*'), ('Div', '/'))

    def canonical(label):
        # Map one to_list label onto the spelling used in basis_functions
        for sympy_name, symbol in binary_spellings:
            if label == sympy_name and symbol in basis_functions[2]:
                return symbol
        if label.lower() in numeric_names or is_float(label):
            return 'a'
        if label.startswith(('a', 'x')) and label[1:].isdigit():
            return label[0]
        return label.lower()

    # Convert every label first, then compare against the flattened basis
    converted = [canonical(label) for label in nodes.to_list(basis_functions)]
    allowed = [op for group in basis_functions for op in group]
    return all(label in allowed for label in converted)
def string_to_node(s, basis_functions, locs=None, evalf=False, allow_eval=True, check_ops=False):
    """Convert a string giving function into its shortest tree representation

    Up to four parses of ``s`` are attempted, with string_to_expr options
    (kern, evaluate) = (False, True), (False, False), (True, True) and
    (True, False). Parses that raise are ignored. If ``check_ops`` is set and
    at least one parse uses only basis operators, the parses that don't are
    discarded. The parse with the fewest nodes is returned; ties go to the
    earliest option.

    Args:
        :s (str): string representation of the function considered
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators
        :locs (dict): dictionary of string:sympy objects. If None, will create here
        :evalf (bool): whether to run evalf() on function (default=False)
        :allow_eval (bool, default=True): whether to run the (kernS=False and evaluate=True) option
        :check_ops (bool, default=False): whether to check all operators appear in basis functions

    Returns:
        :expr (sympy object): expression of the chosen parse
        :nodes (DecoratedNode): tree of the chosen parse
        :count (int): number of nodes in that tree

    Raises:
        ValueError: if none of the parses succeeded
    """
    # (kern, evaluate) settings; only the first is controlled by allow_eval
    settings = [(False, True), (False, False), (True, True), (True, False)]
    nopt = len(settings)

    expr = [None] * nopt
    nodes = [None] * nopt
    all_in_basis = [False] * nopt
    c = np.full(nopt, np.nan)

    for i, (kern, evaluate) in enumerate(settings):
        if i == 0 and not allow_eval:
            continue
        try:
            expr[i] = string_to_expr(s, kern=kern, evaluate=evaluate, locs=locs)
            if evalf:
                expr[i] = expr[i].evalf()
            nodes[i] = DecoratedNode(expr[i], basis_functions)
            c[i] = nodes[i].count_nodes(basis_functions)
            if check_ops:
                all_in_basis[i] = check_operators(nodes[i], basis_functions)
        except Exception:
            # Best effort: a failed parse is simply not a candidate
            c[i] = np.nan

    # Prefer parses expressible in the basis, if there are any
    if check_ops and any(all_in_basis):
        for i, ok in enumerate(all_in_basis):
            if not ok:
                c[i] = np.nan

    if np.all(np.isnan(c)):
        raise ValueError(f"Could not convert {s!r} into a tree of nodes")

    i = np.nanargmin(c)
    return expr[i], nodes[i], int(c[i])
def update_tree(tree, labels, try_idx, basis_functions):
    """Try to combine exponentials and powers to make simpler representations of functions

    Looks for a log_abs/exp/pow_abs label that sits next to a run of power
    labels (square, cube, sqrt_abs, inv) in the pre-order label list.
    log_abs and pow_abs look at the labels that follow them; exp and pow_abs
    look at the labels that precede them. The run is folded into a single
    multiplicative constant, which is then moved to the other side of the
    exponential/logarithm.

    Args:
        :tree (list): list of Node objects corresponding to tree of function
        :labels (list): list of strings giving node labels of tree
        :try_idx (int): when we have multiple substitutions we can attempt, this indicates which one to try
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

    Returns:
        :new_labels (list): list of strings giving node labels of new tree
            (None if no substitution applies; in the "F - G" case below it is
            a list of alternative label lists)
        :new_shape (list): list of 0, 1 and 2 representing whether nodes in new tree are nullary, unary or binary
            (same structure as new_labels)
        :nadded (int): number of new functions added
    """
    pow_set = ["square", "cube", "sqrt_abs", "inv"]
    # Each power written as "<op><exponent>" so runs of them can be
    # multiplied together as a string, e.g. square, cube -> *(2)*(3)
    pow_num = {"square": '*2', "cube": '*3', "sqrt_abs": '/2', "inv": '*-1'}
    exp_set = ["log_abs", "exp", "pow_abs"]
    # log_abs comes first, exp comes second, pow_abs can go in either order
    exp_ord = {"log_abs": 1, "exp": 2, "pow_abs": 3}

    common_pow = list(set(labels) & set(pow_set))
    common_exp = list(set(labels) & set(exp_set))

    # Parallel lists, one entry per candidate substitution:
    #   special_idx: position of the exp-type label
    #   diff1_idx / diff2_idx: number of power labels absorbed after / before it
    #   num1 / num2: folded constant ('*n' or '/n') for the after / before run
    special_idx = []
    diff1_idx = []
    diff2_idx = []
    num1 = []
    num2 = []

    # See if we have any of the correct patterns in the labels list
    if len(common_pow) != 0 and len(common_exp) != 0:
        for i in range(len(labels)):
            if labels[i] in common_exp:
                # Powers directly after log_abs / pow_abs
                if (exp_ord[labels[i]] == 1) or (exp_ord[labels[i]] == 3):
                    success = False
                    if (i < len(labels) - 1) and (labels[i+1] in common_pow):
                        special_idx.append(i)
                        j = 0
                        success = False
                        s = '*1'
                        # Absorb following powers while the running exponent
                        # (or its reciprocal) stays an integer
                        while not success:
                            j += 1
                            if (i >= len(labels) - j):
                                success = True
                            elif labels[i+j] in common_pow:
                                n = sympy.sympify(
                                    (s + pow_num[labels[i+j]][0] + '(' + pow_num[labels[i+j]][1:] + ')')[1:])
                                if n.is_integer or (1/n).is_integer:
                                    s += pow_num[labels[i+j]][0] + \
                                        '(' + pow_num[labels[i+j]][1:] + ')'
                                else:
                                    success = True
                            else:
                                success = True
                        diff1_idx.append(j-1)
                        diff2_idx.append(0)
                        n = sympy.sympify(s[1:])
                        if n.is_integer:
                            s = '*' + str(n)
                        else:
                            s = '/' + str(1/n)
                        num1.append(s)
                        num2.append(None)
                # Powers directly before exp / pow_abs
                if (exp_ord[labels[i]] == 2) or (exp_ord[labels[i]] == 3):
                    success = False
                    if (i > 0) and (labels[i-1] in common_pow):
                        if i not in special_idx:
                            special_idx.append(i)
                        j = 0
                        success = False
                        s = '*1'
                        while not success:
                            j += 1
                            if (i - j) < 0:
                                success = True
                            elif labels[i-j] in common_pow:
                                n = sympy.sympify(
                                    (s + pow_num[labels[i-j]][0] + '(' + pow_num[labels[i-j]][1:] + ')')[1:])
                                if n.is_integer or (1/n).is_integer:
                                    s += pow_num[labels[i-j]][0] + \
                                        '(' + pow_num[labels[i-j]][1:] + ')'
                                else:
                                    success = True
                            else:
                                success = True
                        # A pow_abs already registered above just gets its
                        # "before" fields filled in
                        if len(diff2_idx) != len(special_idx):
                            diff1_idx.append(0)
                            diff2_idx.append(j-1)
                        else:
                            diff2_idx[-1] = j-1
                        n = sympy.sympify(s[1:])
                        if n.is_integer:
                            s = '*' + str(n)
                        else:
                            s = '/' + str(1/n)
                        if len(num2) != len(special_idx):
                            num1.append(None)
                            num2.append(s)
                        else:
                            num2[-1] = s

    new_shape = None
    new_labels = None
    nadded = 0

    if len(special_idx) > try_idx:
        i = special_idx[try_idx]
        if (exp_ord[labels[i]] == 1):
            n = num1[try_idx]
            d = diff1_idx[try_idx]
        elif (exp_ord[labels[i]] == 2):
            n = num2[try_idx]
            d = diff2_idx[try_idx]
        elif (exp_ord[labels[i]] == 3):
            n1 = num1[try_idx]
            n2 = num2[try_idx]
            d1 = diff1_idx[try_idx]
            d2 = diff2_idx[try_idx]

        if (exp_ord[labels[i]] == 3) and (n1 is None or n1[0] in basis_functions[2]) and (n2 is None or n2[0] in basis_functions[2]):
            # Combine the two numbers
            s = '*1'
            if n1 is not None:
                s += n1
            if n2 is not None:
                s += n2
            s = sympy.sympify(s[1:])
            if s.is_integer:
                n = '*' + str(s)
            else:
                n = '/' + str(1/s)
            orig_parents = np.array([t.parent for t in tree])
            orig_shape = [t.type for t in tree]
            # Start of exponent
            j = np.argwhere(orig_parents[i+2:] == i)
            j = j[0, 0] + i + 2
            # End of exponent
            if tree[i].parent is None:
                k = len(labels)
            else:
                k = np.argwhere(orig_parents[i+2:] <= tree[i].parent)
                if len(k) == 0:
                    k = len(labels)
                else:
                    k = k[0, 0] + i + 2
            if int(n[1:]) == 1:
                new_labels = (
                    # First part of tree (up to the d2 operators which make number)
                    labels[:i-d2] +
                    # The pow comes next
                    [labels[i]] +
                    # Skip the d1 operators which give the number
                    labels[i+d1+1:]
                )
                new_shape = orig_shape[:i-d2] + \
                    [orig_shape[i]] + \
                    orig_shape[i+d1+1:]
            else:
                new_labels = (
                    # First part of tree (up to the d2 operators which make number)
                    labels[:i-d2] +
                    # The pow comes next
                    [labels[i]] +
                    # Skip the d1 operators which give the number
                    labels[i+d1+1:j] +
                    # Add in a * or /
                    [n[0]] +
                    # The original exponent on left of * or /
                    labels[j:k] +
                    # Put number at right of * or /
                    [n[1:]] +
                    # Rest of tree
                    labels[k:]
                )
                new_shape = orig_shape[:i-d2] + \
                    [orig_shape[i]] + \
                    orig_shape[i+d1+1:j] + \
                    [2] + \
                    orig_shape[j:k] + \
                    [0] + \
                    orig_shape[k:]
            nadded += 1
        elif (exp_ord[labels[i]] != 3) and (n[0] in basis_functions[2]):
            orig_parents = np.array([t.parent for t in tree])
            orig_shape = [t.type for t in tree]
            # j: index just past the subtree rooted at i
            if i > 0:
                j = np.argwhere(orig_parents[i+1:] <= tree[i].parent)
                if len(j) == 0:
                    j = len(labels)
                else:
                    j = j[0, 0] + i + 1
            else:
                j = len(labels)
            # log_abs with a negative folded constant under a + or -:
            # move the sign into the parent operator
            if (i > 0) and (labels[orig_parents[i]] in ["+", "-"]) and n.startswith('*-') and exp_ord[labels[i]] == 1:
                inv_op = "-" if (labels[orig_parents[i]] == "+") else "+"
                if (tree[orig_parents[i]].right == i) and (inv_op in basis_functions[2]):
                    if int(n[2:]) == 1:
                        new_labels = (
                            # First part of tree
                            labels[:orig_parents[i]] + [inv_op] +
                            # Left part of + unchanged
                            labels[orig_parents[i]+1:i] +
                            # Skip the d operators which give the number
                            [labels[i]] + labels[i+d+1:]
                        )
                        new_shape = orig_shape[:i] + [orig_shape[i]] + \
                            orig_shape[i+d+1:j] + orig_shape[j:]
                    else:
                        new_labels = (
                            # First part of tree
                            labels[:orig_parents[i]] + [inv_op] +
                            # Left part of + unchanged
                            labels[orig_parents[i]+1:i] +
                            # Add * or / to right of +
                            [n[0]] +
                            # Put "log_abs" at left of * or /
                            [labels[i]] +
                            # Skip the d operators which give the number
                            labels[i+d+1:j] +
                            # Put number at right of * or / and add rest of tree
                            [n[2:]] + labels[j:]
                        )
                        new_shape = orig_shape[:orig_parents[i]] + [2] + \
                            orig_shape[orig_parents[i]+1:i] + \
                            [2] + \
                            [orig_shape[i]] + \
                            orig_shape[i+d+1:j] + \
                            [0] + orig_shape[j:]
                    nadded += 1
                elif (labels[orig_parents[i]] == "+") and (inv_op in basis_functions[2]):
                    # Index of where right side of + ends
                    k = np.argwhere(
                        orig_parents[tree[orig_parents[i]].right+1:] <= tree[i].parent)
                    if len(k) == 0:
                        k = len(labels)
                    else:
                        k = k[0, 0] + tree[orig_parents[i]].right + 1
                    if int(n[2:]) == 1:
                        new_labels = (
                            # First part of tree
                            labels[:orig_parents[i]] + [inv_op] +
                            # Move right part of + to the left
                            labels[tree[orig_parents[i]].right:k] +
                            # Put "log_abs" at start of right of +
                            labels[orig_parents[i]+1:i+1] +
                            # Skip the d operators which give the number
                            labels[i+d+1:tree[orig_parents[i]].right] +
                            # Rest of tree
                            labels[k:]
                        )
                        new_shape = orig_shape[:orig_parents[i]] + [2] + \
                            orig_shape[tree[orig_parents[i]].right:k] + \
                            orig_shape[orig_parents[i]+1:i+1] + \
                            orig_shape[i+d+1:tree[orig_parents[i]].right] + \
                            orig_shape[k:]
                    else:
                        new_labels = (
                            # First part of tree
                            labels[:orig_parents[i]] + [inv_op] +
                            # Move right part of + to the left
                            labels[tree[orig_parents[i]].right:k] +
                            # Add * or / to right of +
                            [n[0]] +
                            # Put "log_abs" at left of * or /
                            labels[orig_parents[i]+1:i+1] +
                            # Skip the d operators which give the number
                            labels[i+d+1:tree[orig_parents[i]].right] +
                            # Put number at right of * or / and add rest of tree
                            [n[2:]] + labels[k:]
                        )
                        new_shape = orig_shape[:orig_parents[i]] + [2] + \
                            orig_shape[tree[orig_parents[i]].right:k] + \
                            [2] + \
                            orig_shape[orig_parents[i]+1:i+1] + \
                            orig_shape[i+d+1:tree[orig_parents[i]].right] + \
                            [0] + orig_shape[k:]
                    nadded += 1
                else:
                    # TWO CASES:
                    # (1) F - G -> (-1)*(n * H + G) /
                    # (2) F - G -> (-n)*H - G
                    # Both alternatives are returned, so new_labels and
                    # new_shape become lists of candidate trees here.
                    new_labels = []
                    new_shape = []
                    if (inv_op in basis_functions[2]):
                        new_labels.append(
                            # First part of tree
                            labels[:orig_parents[i]] +
                            # Add *(-1) before + or -
                            ['*', '-1', inv_op] +
                            # * or / the first term
                            [n[0], n[2:]] +
                            # Put "log_abs" at right of * or /
                            [labels[i]] +
                            # Skip the d operators which give the number
                            labels[i+d+1:]
                        )
                        new_shape.append(
                            orig_shape[:orig_parents[i]] + [2, 0, 2] +
                            [2, 0] + [orig_shape[i]] + orig_shape[i+d+1:]
                        )
                        nadded += 1
                    new_labels.append(
                        # First part of tree
                        labels[:orig_parents[i]+1] +
                        # * or / the first term
                        [n[0], n[1:]] +
                        # Put "log_abs" at right of * or /
                        [labels[i]] +
                        # Skip the d operators which give the number
                        labels[i+d+1:]
                    )
                    new_shape.append(
                        orig_shape[:orig_parents[i]+1] + [2, 0] +
                        [orig_shape[i]] + orig_shape[i+d+1:]
                    )
                    nadded += 1
            else:
                if exp_ord[labels[i]] == 1:
                    # log_abs(powers(x)) -> n * log_abs(x)
                    if int(n[1:]) == 1:
                        new_labels = (
                            # First part of tree
                            labels[:i+1] +
                            # Skip the d operators which give the number
                            labels[i+d+1:]
                        )
                        new_shape = orig_shape[:i+1] + \
                            orig_shape[i+d+1:]
                    else:
                        new_labels = (
                            # First part of tree
                            labels[:i] +
                            # * or / the "log_abs"
                            [n[0]] +
                            # Put "log_abs" at left of tree
                            [labels[i]] +
                            # Skip the d operators which give the number
                            labels[i+d+1:j] +
                            # Put number at right of * or /
                            [n[1:]] +
                            # Rest of tree
                            labels[j:]
                        )
                        new_shape = orig_shape[:i] + \
                            [2] + \
                            [orig_shape[i]] + \
                            orig_shape[i+d+1:j] + \
                            [0] + \
                            orig_shape[j:]
                elif exp_ord[labels[i]] == 2:
                    # powers(exp(x)) -> exp(n * x)
                    if int(n[1:]) == 1:
                        new_labels = (
                            # First part of tree (up to the d operators which make number)
                            labels[:i-d] +
                            # The rest of the tree
                            labels[i:]
                        )
                        new_shape = orig_shape[:i-d] + \
                            orig_shape[i:]
                    else:
                        new_labels = (
                            # First part of tree (up to the d operators which make number)
                            labels[:i-d] +
                            # The exp comes next
                            [labels[i]] +
                            # * or / the argument of "exp"
                            [n[0]] +
                            # First part of argument of "exp" on left of * or /
                            labels[i+1:j] +
                            # Put number at right of * or /
                            [n[1:]] +
                            # Rest of tree
                            labels[j:]
                        )
                        new_shape = orig_shape[:i-d] + \
                            [orig_shape[i]] + \
                            [2] + \
                            orig_shape[i+1:j] + \
                            [0] + \
                            orig_shape[j:]
                nadded += 1

    return new_labels, new_shape, nadded
[docs] def update_sums(tree, labels, try_idx, basis_functions): """Try to combine sums to make simpler representations of functions Args: :tree (list): list of Node objects corresponding to tree of function :labels (list): list of strings giving node labels of tree :try_idx (int): when we have multiple substituions we can attempt, this indicates which one to try :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators Returns: :new_labels (list): list of strings giving node labels of new tree :new_shape (list): list of 0, 1 and 2 representing whether nodes in new tree are nullary, unary or binary :nadded (int): number of new functions added """ new_shape = None new_labels = None nadded = 0 if ("+" not in labels) and ("-" not in labels): return new_labels, new_shape, nadded # Find all the +s or -s which aren't children of other +s or -s plus_idx = [i for i in range( len(labels)) if labels[i] == "+" or labels[i] == "-"] plus_idx = [i for i in plus_idx if (tree[i].parent is None) or ( labels[tree[i].parent] not in ["+", "-"])] if try_idx >= len(plus_idx): return new_labels, new_shape, nadded i = plus_idx[try_idx] orig_parents = np.array([t.parent for t in tree]) orig_parents[0] = -1 # Find all terms in sum (sometimes this happens over multiple layers of tree if nested +s or -s) def get_sum(j, run_anyway): r = run_anyway if labels[j] in ["+", "-"]: sum_list = [] idx_list = [] s, idx, r2 = get_sum(tree[j].left, r) if r2: r = True if len(s) > 0: if type(s[0]) is str: sum_list.append(s) idx_list.append(idx) else: for k in range(len(s)): sum_list.append(s[k]) idx_list.append(idx[k]) s, idx, r2 = get_sum(tree[j].right, r) if r2: r = True if len(s) > 0: if type(s[0]) is str: sum_list.append(s) idx_list.append(idx) else: for k in range(len(s)): sum_list.append(s[k]) idx_list.append(idx[k]) else: r = True elif labels[j] == "*" and labels[tree[j].left].lstrip("-").isdigit(): 
temp_list, temp_idx, r2 = get_sum(tree[j].right, r) if r2: r = True sum_list = [] idx_list = [] for k in range(len(temp_list)): if type(temp_list[k][0]) in [str, np.str_]: sum_list += [temp_list[k]] * \ int(labels[tree[j].left].lstrip("-")) idx_list += [temp_idx[k]] * \ int(labels[tree[j].left].lstrip("-")) else: for s in temp_list[k]: sum_list += s * int(labels[tree[j].left].lstrip("-")) idx_list += temp_idx[k] * \ int(labels[tree[j].left].lstrip("-")) elif labels[j] == "*" and labels[tree[j].right].lstrip("-").isdigit(): temp_list, temp_idx, r2 = get_sum(tree[j].left, r) if r2: r = True sum_list = [] idx_list = [] for k in range(len(temp_list)): temp_idx[k][-1] = tree[j].right + 1 if type(temp_list[k][0]) in [str, np.str_]: sum_list += [temp_list[k]] * \ int(labels[tree[j].right].lstrip("-")) idx_list += [temp_idx[k]] * \ int(labels[tree[j].right].lstrip("-")) else: for s in temp_list[k]: sum_list += s * int(labels[tree[j].right].lstrip("-")) idx_list += temp_idx[k] * \ int(labels[tree[j].right].lstrip("-")) else: k = np.argwhere(orig_parents[j+1:] <= tree[j].parent) + j + 1 if len(k) == 0: k = len(labels) else: k = k[0, 0] sum_list = [labels[j:k]] idx_list = [[j, k, k]] return sum_list, idx_list, r # run_anyway checks to see if we have any 0* in the sum # this won't show up in all_s, but means we will want to # try to rewrite this sum all_s, all_idx, run_anyway = get_sum(i, False) all_start = [tree[j[0]].parent for j in all_idx] orig_parents = [tt.parent for tt in tree] children = [j for j in range(len(tree)) if tree[j].parent == i] # p last_child = children[-1] if last_child == len(tree) - 1: end_idx = len(tree) else: end_idx = np.argwhere(np.array(orig_parents[last_child+1:]) < i) # p if len(end_idx) == 0: end_idx = len(tree) else: end_idx = end_idx[0, 0] + last_child + 1 # Work out the signs of each term in the sum all_sign = [] neg_const = [] orig_parents = np.array([t.parent for t in tree]) for j in range(len(all_start)): n = 1 k = all_start[j] 
neg_const.append(False) if k is not None and labels[k] in ["*", "/"]: if labels[tree[k].left].lstrip("-").isdigit() and labels[tree[k].left].startswith("-"): n *= -1 neg_const[-1] = True elif labels[tree[k].right].lstrip("-").isdigit() and labels[tree[k].right].startswith("-"): n *= -1 neg_const[-1] = True if (labels[k] == "-") and (tree[k].right == all_idx[j][0]): n *= -1 while (tree[k].parent is not None) and (tree[k].parent >= i): if (labels[tree[k].parent] == "-") and (tree[tree[k].parent].right == k): n *= -1 if labels[tree[k].parent] in ["*", "/"]: if labels[tree[tree[k].parent].left].lstrip("-").isdigit() and labels[tree[tree[k].parent].left].startswith("-"): n *= -1 elif labels[tree[tree[k].parent].right].lstrip("-").isdigit() and labels[tree[tree[k].parent].right].startswith("-"): n *= -1 neg_const[-1] = True k = tree[k].parent all_sign.append(n) # Get unique terms in sum s = [] for ss in all_s: if len(ss) > 0 and isinstance(ss[0], list): s.append(tuple(list(**s))) else: s.append(tuple(list(ss))) s = sorted(set(s), key=s.index) s = [list(ss) for ss in s] new_labels = [] new_shape = [] if (len(s) != len(all_s)) or run_anyway: for j in range(len(s)): L = labels[:i] t = [tt.type for tt in tree[:i]] rep = [a for a in range(len(all_s)) if all_s[a] == s[j]] rep_val = np.array([all_sign[a] for a in rep], dtype=int).sum() uni = [all_s.index(s[a]) for a in range(len(s)) if s[a] != s[j]] # Add the unique stuff # Always adding to the right l_uni = [] n_uni = [] t_uni = [] for a in uni: nrep = [all_sign[b] for b in range(len(all_sign)) if (all_s[a] == all_s[b])] len_nrep = len(nrep) nrep = sum(nrep) if neg_const[a] and (len_nrep == 1): # If neg_const try to get the version with the - instead of + (or vice versa) left_idx = tree[tree[all_idx[a][0]].parent].left right_idx = tree[tree[all_idx[a][0]].parent].right if tree[tree[all_idx[a][0]].parent].left == all_idx[a][0]: if nrep == 1: l_uni = labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in 
tree[all_idx[a] [0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] elif nrep != 0: l_uni = [ '*', str(nrep)] + labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] else: if labels[left_idx].lstrip("-").isdigit() and labels[left_idx].startswith("-"): x = labels[left_idx].lstrip("-") if x == str(1) and nrep == 1: l_uni = labels[left_idx + 1:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[left_idx+1:all_idx[a][1]]] + t_uni else: l_uni = labels[all_idx[a][0]:left_idx] + \ ["*", str(abs(nrep))] + \ labels[left_idx+1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:left_idx]] + \ [2, 0] + \ [tt.type for tt in tree[left_idx + 1:all_idx[a][1]]] + t_uni elif right_idx is None: if nrep == 1: l_uni = labels[left_idx - 1:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[left_idx-1:all_idx[a][1]]] + t_uni else: l_uni = labels[all_idx[a][0]:left_idx-1] + \ ["*", str(abs(nrep))] + \ labels[left_idx-1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:left_idx-1]] + \ [2, 0] + \ [tt.type for tt in tree[left_idx - 1:all_idx[a][1]]] + t_uni else: x = labels[right_idx].lstrip("-") if x == str(1) and nrep == 1: l_uni = labels[all_idx[a][0]+1:right_idx] + \ labels[right_idx+1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]+1:right_idx]] + \ [tt.type for tt in tree[right_idx + 1:all_idx[a][1]]] + t_uni elif nrep != 0: l_uni = labels[all_idx[a][0]:right_idx] + \ ["*", str(abs(nrep))] + \ labels[right_idx+1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:right_idx]] + \ [2, 0] + \ [tt.type for tt in tree[right_idx + 1:all_idx[a][1]]] + t_uni if all_sign[a] == 1: n_uni = n_uni + ['+'] else: n_uni = n_uni + ['-'] elif all_sign[a] == 1: if nrep == 1: l_uni = labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a] [0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] elif nrep != 0: 
l_uni = ['*', str(nrep)] + \ labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] else: if abs(nrep) == 1: l_uni = labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a] [0]:all_idx[a][1]]] + t_uni elif nrep != 0: l_uni = ['*', str(abs(nrep))] + \ labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni if nrep > 0: n_uni = n_uni + ['+'] elif nrep < 0: n_uni = n_uni + ['-'] l_rep = [] t_rep = [] if ((len(rep) == 1) or ("*" in basis_functions[2])) and (rep_val != 0): if rep_val != 1: l_rep = ['*', str(rep_val)] t_rep = [2, 0] for a in range(len(s[j])): l_rep.append(s[j][a]) t_rep.append(tree[labels.index(s[j][a])].type) if len(l_uni) == 0: L = L + l_uni + l_rep t = t + t_uni + t_rep else: L = L + n_uni + l_rep + l_uni t = t + [2] * len(n_uni) + t_rep + t_uni else: #  Now remove the +/- 0 if len(n_uni) == 0: L = L + ['0'] t = t + [0] elif n_uni[-1] == '+': L = L + n_uni[:-1] + l_uni t = t + [2] * (len(n_uni)-1) + t_uni else: if n_uni == ["-"] * len(n_uni): # If all the things added are negative, we can change the top node # and make them all +'s provided they are on the right of that node L = L + ["*", "-1"] + ["+"] * (len(n_uni)-1) + l_uni t = t + [2, 0] + [2] * (len(n_uni) - 1) + t_uni else: # Otherwise we can move the right hand side of one of the + nodes # down to bottom left to terminate the tree plus_idx = n_uni.index("+") l_uni = [] n_uni = [] t_uni = [] for k in range(len(uni)): if k != plus_idx: a = uni[k] nrep = [all_sign[b] for b in range( len(all_sign)) if (all_s[a] == all_s[b])] nrep = sum(nrep) if neg_const[a]: # If neg_const try to get the version with the - instead of + (or vice versa) left_idx = tree[all_idx[a][0]].left right_idx = tree[all_idx[a][0]].right if tree[tree[all_idx[a][0]].parent].left == all_idx[a][0]: if nrep == 1: l_uni = labels[all_idx[a] 
[0]:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] elif nrep != 0: l_uni = [ '*', str(nrep)] + labels[all_idx[a][0]:all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] else: if (labels[left_idx].lstrip("-").isdigit() and labels[left_idx].startswith("-")): x = labels[left_idx].lstrip("-") if x == str(1) and nrep == 1: l_uni = labels[left_idx + 1:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[left_idx+1:all_idx[a][1]]] + t_uni elif nrep != 0: l_uni = labels[all_idx[a][0]:left_idx] + \ ["*", str(abs(nrep))] + \ labels[left_idx + 1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:left_idx]] + \ [2, 0] + \ [tt.type for tt in tree[left_idx + 1:all_idx[a][1]]] + t_uni elif right_idx is None: if nrep == 1: l_uni = labels[left_idx - 1:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[left_idx-1:all_idx[a][1]]] + t_uni else: l_uni = labels[all_idx[a][0]:left_idx-1] + \ ["*", str(abs(nrep))] + \ labels[left_idx - 1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:left_idx-1]] + \ [2, 0] + \ [tt.type for tt in tree[left_idx - 1:all_idx[a][1]]] + t_uni else: x = labels[right_idx].lstrip("-") if x == str(1) and nrep == 1: l_uni = labels[all_idx[a][0]+1:right_idx] + \ labels[right_idx + 1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]+1:right_idx]] + \ [tt.type for tt in tree[right_idx + 1:all_idx[a][1]]] + t_uni elif nrep != 0: l_uni = labels[all_idx[a][0]:right_idx] + \ ["*", str(abs(nrep))] + \ labels[right_idx + 1:all_idx[a][1]] + l_uni t_uni = [tt.type for tt in tree[all_idx[a][0]:right_idx]] + \ [2, 0] + \ [tt.type for tt in tree[right_idx + 1:all_idx[a][1]]] + t_uni if all_sign[a] == 1: n_uni = n_uni + ['+'] else: n_uni = n_uni + ['-'] elif all_sign[a] == 1: if nrep == 1: l_uni = labels[all_idx[a] [0]:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in 
tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] elif nrep != 0: l_uni = ['*', str(nrep)] + \ labels[all_idx[a][0] :all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni + ['+'] else: if abs(nrep) == 1: l_uni = labels[all_idx[a] [0]:all_idx[a][1]] + l_uni t_uni = [ tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni elif nrep != 0: l_uni = ['*', str(abs(nrep))] + \ labels[all_idx[a][0] :all_idx[a][1]] + l_uni t_uni = [ 2, 0] + [tt.type for tt in tree[all_idx[a][0]:all_idx[a][1]]] + t_uni n_uni = n_uni if nrep > 0: n_uni = n_uni + ['+'] elif nrep < 0: n_uni = n_uni + ['-'] a = uni[plus_idx] nrep = [all_sign[b] for b in range( len(all_sign)) if (all_s[a] == all_s[b])] nrep = sum(nrep) if nrep == 1: L = L + n_uni + \ labels[all_idx[a][0]:all_idx[a][1]] + l_uni t = t + [2] * len(n_uni) + \ [tt.type for tt in tree[all_idx[a] [0]:all_idx[a][1]]] + t_uni else: L = L + n_uni + ['*', str(nrep)] + \ labels[all_idx[a][0]:all_idx[a][1]] + l_uni t = t + [2] * len(n_uni) + [2, 0] + \ [tt.type for tt in tree[all_idx[a] [0]:all_idx[a][1]]] + t_uni L += labels[end_idx:] t += [tt.type for tt in tree[end_idx:]] new_labels.append(L) new_shape.append(t) nadded += 1 if nadded == 1: new_labels = new_labels[0] new_shape = new_shape[0] return new_labels, new_shape, nadded
def find_additional_trees(tree, labels, basis_functions):
    """For a given tree, try to find all simpler representations of the function
    by combining sums, exponentials and powers

    The search has two phases. Each phase repeats full passes over every tree
    found so far until a pass adds nothing new:

    1. Rewrites proposed by ``update_tree`` (log, exp and sqrt changes). Every
       proposed labelling not already known is kept.
    2. Rewrites proposed by ``update_sums``. A proposal is kept only if
       ``simplifier.initial_sympify`` returns a single expression for the pair
       (parent, proposal). Otherwise it is reported with ``print`` and dropped;
       it is also dropped if sympy raises.

    Args:
        :tree (list): list of Node objects corresponding to tree of function
        :labels (list): list of strings giving node labels of tree
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

    Returns:
        :new_tree (list): list of equivalent trees, given as lists of Node objects (the input tree is first)
        :new_labels (list): list of lists of strings giving node labels of new_tree
    """
    found_trees = [tree]
    found_labels = [labels]
    # Per-tree counter handed to update_tree / update_sums as their try index.
    attempt = [0]

    def _split(L, s, n):
        # A single rewrite is returned bare (n == 1); several rewrites are
        # returned as parallel sequences indexed 0..n-1.
        if n == 1:
            yield L, s
        else:
            for j in range(n):
                yield L[j], s[j]

    def _accept(t, cand_labels):
        found_trees.append(t)
        found_labels.append(cand_labels)
        attempt.append(0)

    def _accept_if_equivalent(parent, t, cand_labels):
        # Ask sympy whether the parent tree and the candidate collapse to one
        # expression before trusting the sum rewrite.
        parent_labels = found_labels[parent]
        max_param = max(1, sum(1 for a in parent_labels if a.startswith('a')))
        f = [node_to_string(0, found_trees[parent], parent_labels),
             node_to_string(0, t, cand_labels)]
        try:
            _, sym = simplifier.initial_sympify(
                f, max_param, verbose=False, parallel=False)
            if len(sym) == 1:
                _accept(t, cand_labels)
            else:
                print('Maybe bad (not keeping):', parent_labels,
                      '\t', sym[0], '\t', sym[1])
        except Exception:
            print('Failed sympy (not keeping):', parent_labels, '\t', cand_labels)

    # Phase 1: log, exp and sqrt changes
    previous = 0
    while len(found_trees) != previous:
        previous = len(found_trees)
        for idx in range(previous):
            L, s, n = update_tree(found_trees[idx], found_labels[idx],
                                  attempt[idx], basis_functions)
            if s is not None and L not in found_labels:
                for cand_labels, cand_shape in _split(L, s, n):
                    if cand_labels not in found_labels:
                        _, _, t = check_tree(cand_shape)
                        _accept(t, cand_labels)
            attempt[idx] += 1

    # Phase 2: sum changes, restarting every tree's counter at zero
    attempt[:] = [0] * len(attempt)
    previous = 0
    while len(found_trees) != previous:
        previous = len(found_trees)
        for idx in range(previous):
            L, s, n = update_sums(found_trees[idx], found_labels[idx],
                                  attempt[idx], basis_functions)
            if s is not None and L not in found_labels:
                for cand_labels, cand_shape in _split(L, s, n):
                    if cand_labels not in found_labels:
                        _, _, t = check_tree(cand_shape)
                        _accept_if_equivalent(idx, t, cand_labels)
            # Only advance the counter when at most one rewrite came back.
            if n <= 1:
                attempt[idx] += 1

    return found_trees, found_labels
def shape_to_functions(s, basis_functions):
    """Find all possible functions formed from the given list of 0s, 1s and 2s
    defining a tree and basis functions

    Each node of arity k (0, 1 or 2) is filled, in every possible way, with a
    member of ``basis_functions[k]``. In each nullary assignment, occurrences
    of the bare parameter ``'a'`` are renamed ``a0``, ``a1``, ... in order of
    appearance.

    The search for additional equivalent trees is split across MPI ranks
    using ``utils.split_idx``. Those extra results are gathered to rank 0, and
    ``extra_fun`` / ``extra_orig`` are then broadcast back to every rank.

    Args:
        :s (np.ndarray): array comprised of 0, 1 and 2 representing tree of nullary, unary and binary nodes
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

    Returns:
        :all_fun (list): list of strings containing all functions generated directly from tree
        :all_tree (list): list of label arrays for the functions in all_fun (rank 0 only; None on other ranks)
        :extra_fun (list): list of strings containing functions generated by combining sums, exponentials and powers of the functions in all_fun
        :extra_tree (list): list of label lists for the functions in extra_fun (collected on rank 0 only; not broadcast)
        :extra_orig (list): list of strings corresponding to original versions of extra_fun, as found in all_fun
    """
    is_leaf = (s == 0)
    is_unary = (s == 1)
    is_binary = (s == 2)

    leaf_choices = [list(c) for c in
                    itertools.product(basis_functions[0], repeat=np.sum(is_leaf))]
    unary_choices = [list(c) for c in
                     itertools.product(basis_functions[1], repeat=np.sum(is_unary))]
    binary_choices = [list(c) for c in
                      itertools.product(basis_functions[2], repeat=np.sum(is_binary))]

    # Rename parameters so appear in order
    for choice in leaf_choices:
        n_params = 0
        for slot, name in enumerate(choice):
            if name == 'a':
                choice[slot] = 'a%i' % n_params
                n_params += 1

    _, _, tree = check_tree(s)

    n_total = len(leaf_choices) * len(unary_choices) * len(binary_choices)
    all_fun = [None] * n_total
    all_tree = [None] * n_total if rank == 0 else None

    labels = np.empty(len(s), dtype='U100')
    leaf_choices = np.array(leaf_choices)
    unary_choices = np.array(unary_choices)
    binary_choices = np.array(binary_choices)

    extra_tree, extra_fun, extra_orig = [], [], []

    # Contiguous slice of the flattened assignment index owned by this rank
    owned = utils.split_idx(n_total, rank, size)
    if len(owned) == 0:
        lo, hi = 0, 0
    else:
        lo, hi = owned[0], owned[-1] + 1

    grid = itertools.product(range(len(leaf_choices)),
                             range(len(unary_choices)),
                             range(len(binary_choices)))
    for pos, (i, j, k) in enumerate(grid):
        labels[:] = None
        labels[is_leaf] = leaf_choices[i, :]
        labels[is_unary] = unary_choices[j, :]
        labels[is_binary] = binary_choices[k, :]
        if rank == 0:
            all_tree[pos] = labels.copy()
        all_fun[pos] = node_to_string(0, tree, labels)

        if not (lo <= pos < hi):
            continue
        new_tree, new_labels = find_additional_trees(
            tree, list(labels), basis_functions)
        # Index 0 is the input tree itself; only the rewrites are "extra".
        for m in range(1, len(new_tree)):
            extra_tree.append(new_labels[m].copy())
            extra_fun.append(node_to_string(0, new_tree[m], new_labels[m]))
            extra_orig.append(all_fun[pos])

    extra_tree = comm.gather(extra_tree, root=0)
    extra_fun = comm.gather(extra_fun, root=0)
    extra_orig = comm.gather(extra_orig, root=0)
    if rank == 0:
        extra_tree = list(itertools.chain.from_iterable(extra_tree))
        extra_fun = list(itertools.chain.from_iterable(extra_fun))
        extra_orig = list(itertools.chain.from_iterable(extra_orig))
    extra_fun = comm.bcast(extra_fun, root=0)
    extra_orig = comm.bcast(extra_orig, root=0)
    comm.Barrier()

    return all_fun, all_tree, extra_fun, extra_tree, extra_orig
def labels_to_shape(labels, basis_functions):
    """Find the representation of the shape of a tree given its labels

    Args:
        :labels (list): list of strings giving node labels of tree
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators

    Returns:
        :s (list): list of ints (0, 1 or 2), one per label, giving the arity of each node

    Raises:
        ValueError: if a label is neither a basis function, a parameter name
            of the form ``a<digits>``, nor accepted by ``is_float``.
    """
    # If a name appears in several arity groups, the later group wins.
    basis_dict = {f: arity
                  for arity, group in enumerate(basis_functions)
                  for f in group}

    s = []
    for t in labels:
        arity = basis_dict.get(t)
        if arity is not None:
            s.append(arity)
        # Parameters (a0, a1, ...) and numeric constants are leaves.
        # NOTE(review): is_float() calls eval() on the label; do not feed
        # labels from untrusted sources through this function.
        elif (t.startswith('a') and t[1:].isdigit()) or is_float(t):
            s.append(0)
        else:
            raise ValueError('Unrecognised label %r in labels_to_shape' % (t,))
    return s
def aifeyn_complexity(tree, param_list):
    """Compute contribution to description length from describing tree

    Each label is classified as an integer constant (e.g. ``'3'``, ``'-2'``),
    a parameter (a member of ``param_list``), or an operator (anything else).
    The result is ``len(tree) * log(nop) + sum(log|k|)``, where:

    - ``nop`` is the number of distinct operators, plus one if any parameter
      or integer is present;
    - the sum runs over the integer constants, with 0 treated as 1 so it
      contributes nothing.

    Args:
        :tree (list): list of strings giving node labels of tree
        :param_list (list): list of strings of all possible parameter names

    Returns:
        :aifeyn (float): the contribution to description length from describing tree
    """
    operators = []
    integers = []
    for label in tree:
        if label.lstrip("-").isdigit():
            integers.append(int(label))
        elif label not in param_list:
            operators.append(label)

    # Parameters and integers together count as one extra symbol.
    extra_symbol = 0 if len(operators) == len(tree) else 1
    n_symbols = len(set(operators)) + extra_symbol

    magnitudes = np.abs(np.array(integers))
    magnitudes[magnitudes == 0] = 1  # log(1) = 0 instead of log(0)
    return len(tree) * np.log(n_symbols) + np.sum(np.log(magnitudes))
def _write_pretty_trees(path, trees):
    """Append ``pprint`` of ``str(t)`` for every ``t`` in ``trees`` to the file at ``path``."""
    width = 80
    with open(path, 'a') as f:
        printer = pprint.PrettyPrinter(width=width, stream=f)
        for t in trees:
            s = str(t)
            # Widen the printer for long strings (presumably so pprint keeps
            # each tree on a single line).
            if len(s + '\n') > width / 2:
                width = 2 * len(s)
                printer = pprint.PrettyPrinter(width=width, stream=f)
            printer.pprint(s)


def _write_aifeyn(path, trees, param_list):
    """Append one ``aifeyn_complexity`` value per tree, one per line, to ``path``."""
    with open(path, 'a') as f:
        for tree in trees:
            print(aifeyn_complexity(tree, param_list), file=f)


def _concatenate_files(sources, destination):
    """Overwrite ``destination`` with the bytes of each file in ``sources``, in order.

    Done in Python rather than via a shell ``cat``, so that directory names
    containing spaces or shell metacharacters work, and so that errors raise
    instead of being silently ignored.
    """
    import shutil

    with open(destination, 'wb') as out:
        for src in sources:
            with open(src, 'rb') as inp:
                shutil.copyfileobj(inp, out)


def generate_equations(compl, basis_functions, dirname):
    """Generate all equations at a given complexity for a set of basis functions and save results to file

    On rank 0 this (re)creates, in ``dirname``, the files ``orig_trees_<compl>.txt``,
    ``extra_trees_<compl>.txt``, ``orig_aifeyn_<compl>.txt`` and
    ``extra_aifeyn_<compl>.txt``. It also writes the concatenations
    ``trees_<compl>.txt`` (orig + extra trees) and ``aifeyn_<compl>.txt``
    (orig + extra aifeyn values).

    Args:
        :compl (int): complexity of functions to consider
        :basis_functions (list): list of lists basis functions. basis_functions[0] are nullary, basis_functions[1] are unary and basis_functions[2] are binary operators
        :dirname (str): directory path to save results in

    Returns:
        :all_fun (list): list of strings containing all functions generated
        :extra_orig (list): list of strings containing functions generated by combining sums, exponentials and powers of the functions in all_fun as they appear in all_fun
    """
    shapes = get_allowed_shapes(compl)

    # Trees per topology: for each arity i,
    # len(basis_functions[i]) ** (number of nodes of arity i).
    nfun = np.empty((shapes.shape[0], 3))
    for i in range(3):
        nfun[:, i] = len(basis_functions[i]) ** np.sum(shapes == i, axis=1)
    nfun = np.prod(nfun, axis=1)

    if rank == 0:
        print('\nNumber of topologies:', shapes.shape[0])
        for i in range(shapes.shape[0]):
            print(shapes[i, :], int(nfun[i]))
        sys.stdout.flush()

    nfun = np.sum(nfun)
    if rank == 0:
        print('\nOriginal number of trees:', int(nfun))
        sys.stdout.flush()

    all_fun = [None] * len(shapes)
    extra_fun = [None] * len(shapes)
    extra_orig = [None] * len(shapes)

    sys.stdout.flush()
    comm.Barrier()

    def out_path(fname):
        return '%s/%s_%i.txt' % (dirname, fname, compl)

    # Clear the files
    if rank == 0:
        for fname in ['orig_trees', 'extra_trees', 'orig_aifeyn', 'extra_aifeyn']:
            with open(out_path(fname), 'w'):
                pass

    for i in range(len(shapes)):
        if rank == 0:
            print('%i of %i' % (i + 1, len(shapes)))
            sys.stdout.flush()

        all_fun[i], all_tree, extra_fun[i], extra_tree, extra_orig[i] = \
            shape_to_functions(shapes[i], basis_functions)

        max_param = simplifier.get_max_param(all_fun[i], verbose=False)
        param_list = ['a%i' % j for j in range(max_param)]

        # all_tree / extra_tree are only populated on rank 0.
        if rank == 0:
            _write_pretty_trees(out_path('orig_trees'), all_tree)
            _write_pretty_trees(out_path('extra_trees'), extra_tree)
            _write_aifeyn(out_path('orig_aifeyn'), all_tree, param_list)
            _write_aifeyn(out_path('extra_aifeyn'), extra_tree, param_list)

    if rank == 0:
        sys.stdout.flush()
        _concatenate_files([out_path('orig_trees'), out_path('extra_trees')],
                           out_path('trees'))
        _concatenate_files([out_path('orig_aifeyn'), out_path('extra_aifeyn')],
                           out_path('aifeyn'))

    all_fun = list(itertools.chain.from_iterable(all_fun))
    extra_fun = list(itertools.chain.from_iterable(extra_fun))
    extra_orig = list(itertools.chain.from_iterable(extra_orig))
    all_fun = all_fun + extra_fun

    if rank == 0:
        print('\nNew number of trees:', len(all_fun))
        sys.stdout.flush()

    return all_fun, extra_orig