import numpy as np
import sympy
import signal
import sys
import itertools
from mpi4py import MPI
from contextlib import contextmanager
import csv
import ast
import gc
from collections import OrderedDict
import pprint
import os
import esr.generation.utils as utils
from esr.generation.custom_printer import ESRPrinter
from esr.fitting.sympy_symbols import (
sympy_locs, square, cube, pow_abs, sqrt_abs, log_abs
)
# Module-level MPI handles, read as globals by the functions in this file.
# The communicator spanning every process in the job.
comm = MPI.COMM_WORLD
# This process's index within ``comm``.
rank = comm.Get_rank()
# Total number of processes in ``comm``.
size = comm.Get_size()
[docs]
class TimeoutException(Exception):
    """Signals that a time-limited section of code overran its allowance."""
[docs]
@contextmanager
def time_limit(seconds):
    """ Check function call does not exceed allotted time

    Installs a SIGALRM handler for the duration of the ``with`` block and arms
    a real-time interval timer. Fractional values of ``seconds`` are supported.
    A value of 0 disables the limit.

    Args:
        :seconds (float): maximum time function can run in seconds

    Raises:
        TimeoutException if time exceeds seconds

    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out")

    # Remember whatever handler was active so it can be put back afterwards.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    try:
        # signal.alarm only accepts whole seconds; setitimer takes floats too.
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        # Always disarm the timer, even if the body raised.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # ``previous_handler`` is None when the old handler was not installed
        # from Python; signal.signal() cannot accept None, so skip in that case.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)
[docs]
def get_max_param(all_fun, verbose=True):
    """ Work out the largest number of free parameters used by any function

    Parameters are labelled a0, a1, ... and are detected by substring search.
    The count is the first label index for which no function contains every
    lower-numbered label as well as that one.

    Args:
        :all_fun (list): list of strings containing functions
        :verbose (bool, default=True): Whether to print result (True) or not (False)

    Returns:
        :max_param (int): maximum number of free parameters in any equation in all_fun

    """

    max_param = 0
    survivors = all_fun.copy()
    while len(survivors) > 0:
        label = 'a%i' % max_param
        # Keep only the functions that also contain the next label
        survivors = [fun for fun in survivors if label in fun]
        if len(survivors) > 0:
            max_param += 1

    if verbose and rank == 0:
        print('\nMax number of parameters:', max_param)
        sys.stdout.flush()

    return max_param
[docs]
def count_params(all_fun, max_param):
    """ Work out how many free parameters each function in a list uses

    The count for a function is one more than the highest index k for which
    the label 'a%i' % k occurs as a substring of the function, or zero if no
    label below max_param occurs.

    Args:
        :all_fun (list): list of strings containing functions
        :max_param (int): maximum number of free parameters in any equation in all_fun

    Returns:
        :nparam (np.array): array of ints containing number of free parameters in corresponding member of all_fun

    """

    labels = ['a%i' % k for k in range(max_param)]
    nparam = np.zeros(len(all_fun), dtype=int)

    for pos, fun in enumerate(all_fun):
        # Search from the highest label downwards; the first hit sets the count
        highest = next(
            (k for k in reversed(range(max_param)) if labels[k] in fun),
            None,
        )
        if highest is not None:
            nparam[pos] = highest + 1

    return nparam
[docs]
def make_changes(all_fun, all_sym, all_inv_subs, str_fun, sym_fun, inv_subs_fun):
    """ Update global variables of functions and symbolic expressions by combining rank
    calculations

    Each rank compares its local strings (str_fun) against the matching
    entries of all_fun, collects the entries that differ, and the differences
    from every rank are then gathered on rank 0 and broadcast so that all
    ranks apply the same updates to all_fun, all_sym and all_inv_subs in place.
    This function makes collective MPI calls, so every rank must call it.

    Args:
        :all_fun (list): list of strings containing all functions
        :all_sym (list): list of sympy objects containing all functions
        :all_inv_subs (list): list of dictionaries giving substitutions to be applied to all functions
        :str_fun (list): list of strings containing functions considered by rank
        :sym_fun (list): list of sympy objects containing functions considered by rank
        :inv_subs_fun (list): list of dictionaries giving substitutions to be applied to functions considered by rank

    Returns:
        :all_fun: list of strings containing all (updated) functions
        :all_sym (list): list of sympy objects containing all (updated) functions
        :all_inv_subs: list of dictionaries giving substitutions to be applied to all (updated) functions

    """

    # Indices of all_fun assigned to this rank; treated below as a contiguous
    # run i[0]..i[-1] (presumably what utils.split_idx returns - confirm)
    i = utils.split_idx(len(all_fun), rank, size)
    if len(i) > 0:
        imin = i[0]
        imax = i[-1] + 1
        # Number of functions held by this rank
        start_idx = imax - imin
    else:
        start_idx = 0
        imin = len(all_fun)
    # Turn the per-rank counts into cumulative offsets into all_fun
    start_idx = comm.gather(start_idx, root=0)
    if rank == 0:
        start_idx = np.array([0] + start_idx)
        start_idx = np.cumsum(start_idx)
    start_idx = comm.bcast(start_idx, root=0)

    # Local positions whose string differs from the global copy
    chidx = [i for i in range(len(str_fun)) if str_fun[i] != all_fun[imin+i]]
    str_changes = [str_fun[c] for c in chidx]
    sym_changes = [sym_fun[c] for c in chidx]
    inv_changes = [inv_subs_fun[c] for c in chidx]

    # Collect every rank's changes on rank 0 ...
    chidx = comm.gather(chidx, root=0)
    str_changes = comm.gather(str_changes, root=0)
    sym_changes = comm.gather(sym_changes, root=0)
    inv_changes = comm.gather(inv_changes, root=0)

    # ... and hand the full set back to everyone
    chidx = comm.bcast(chidx, root=0)
    str_changes = comm.bcast(str_changes, root=0)
    sym_changes = comm.bcast(sym_changes, root=0)
    inv_changes = comm.bcast(inv_changes, root=0)

    for i in range(size):
        # list + numpy integer: numpy broadcasts, giving an array of global indices
        j = chidx[i] + start_idx[i]
        for k in range(len(inv_changes[i])):
            all_fun[j[k]] = str_changes[i][k]
            all_sym[j[k]] = sym_changes[i][k]
            if inv_changes[i][k] is None:
                all_inv_subs[j[k]] = None
            else:
                # Copy so the global entry does not alias the broadcast object
                all_inv_subs[j[k]] = inv_changes[i][k].copy()

    return all_fun, all_sym, all_inv_subs
[docs]
def initial_sympify(all_fun, max_param, verbose=True, parallel=True, track_memory=False, save_sympy=True):
    """Convert list of strings of functions into list of sympy objects

    Each string is parsed with sympy.sympify and replaced by its ESRPrinter
    representation. Strings that fail to parse are replaced by sympy.zoo.
    When parallel is True this function makes collective MPI calls, so every
    rank must call it.

    Args:
        :all_fun (list): list of strings containing functions
        :max_param (int): maximum number of free parameters in any equation in all_fun
        :verbose (bool, default=True): whether to print progress (True) or not (False)
        :parallel (bool, default=True): whether to split equations amongst ranks (True) or each equation considered by all ranks (False)
        :track_memory (bool, default=False): whether to compute and print memory statistics (True) or not (False)
        :save_sympy (bool, default=True): whether to return sympy objects (True) or not (False)

    Returns:
        :str_fun (list): list of strings containing functions
        :sym_fun (OrderedDict): dictionary of sympy objects which can be accessed by their string representations. If save_sympy is False, then sym_fun is None.

    """

    if rank == 0 and verbose:
        if track_memory:
            utils.using_mem("start initial sympify")
            utils.locals_size(locals())
        print('\nSympy simplify')
        sys.stdout.flush()

    x, x0, y = sympy.symbols('x x0 y', positive=True)
    if max_param > 0:
        param_list = ['a%i' % i for i in range(max_param)]
        all_a = sympy.symbols(" ".join(param_list), real=True)
        # sympy.symbols returns a bare Symbol (not a tuple) for a single name
        if max_param == 1:
            all_a = [all_a]
    else:
        param_list = []

    sympy.init_printing(use_unicode=True)
    # NOTE: locs is the imported sympy_locs object itself (not a copy), so the
    # parameter symbols below are added to it in place
    locs = sympy_locs
    if max_param > 0:
        for i in range(len(all_a)):
            locs["a%i" % i] = all_a[i]

    if parallel:
        # Each rank takes the contiguous run of functions i[0]..i[-1]
        i = np.atleast_1d(utils.split_idx(len(all_fun), rank, size))
        if len(i) == 0:
            str_fun = []
        else:
            str_fun = all_fun[i[0]:i[-1]+1]
    else:
        # No copy is made here: str_fun aliases all_fun, so the loop below
        # rewrites the caller's list in place
        str_fun = all_fun

    if save_sympy:
        sym_fun = OrderedDict()
    else:
        sym_fun = None

    p = ESRPrinter()
    for i in range(len(str_fun)):
        try:
            s = sympy.sympify(str_fun[i], locals=locs)
        except Exception:
            print('Making %s a zoo' % str_fun[i])
            s = sympy.zoo
        str_fun[i] = p.doprint(s)
        if save_sympy:
            # Keep only the first sympy object seen for each printed string
            if str_fun[i] not in sym_fun:
                sym_fun[str_fun[i]] = s

    # We have to gather these, although won't do this again
    if parallel:
        # First find which ranks contain which indices
        start_idx = len(str_fun)
        start_idx = comm.gather(start_idx, root=0)
        if rank == 0:
            start_idx = np.array([0] + start_idx, dtype=int)
            start_idx = np.squeeze(np.cumsum(start_idx))
        start_idx = comm.bcast(start_idx, root=0)

        # Now send each rank to everyone else
        all_fun = [None] * start_idx[-1]
        for r in range(size):
            all_fun[start_idx[r]:start_idx[r+1]] = comm.bcast(str_fun, root=r)
        str_fun = all_fun

        if save_sympy:
            # Merge the per-rank dictionaries in rank order; first key wins
            all_sym = OrderedDict()
            for r in range(size):
                sym_keys = comm.bcast(list(sym_fun.keys()), root=r)
                sym_vals = comm.bcast(list(sym_fun.values()), root=r)
                for i in range(len(sym_keys)):
                    key = sym_keys[i]
                    if key not in all_sym:
                        all_sym[key] = sym_vals[i]
            sym_fun = all_sym

    return str_fun, sym_fun
[docs]
def _append_inv_sub(inv_subs_fun, idx, entry):
    """Append ``entry`` to the substitution history of function ``idx``, creating the list if needed."""
    if inv_subs_fun[idx] is None:
        inv_subs_fun[idx] = [entry]
    else:
        inv_subs_fun[idx].append(entry)


def _rank_slices(all_fun, all_sym, all_inv_subs):
    """Return this rank's contiguous share of the three global lists."""
    idx = np.atleast_1d(utils.split_idx(len(all_inv_subs), rank, size))
    if len(idx) == 0:
        return [], [], []
    lo, hi = idx[0], idx[-1] + 1
    return all_fun[lo:hi], all_sym[lo:hi], all_inv_subs[lo:hi]


def _merge_duplicates(all_fun, all_sym, all_inv_subs, change_indices, ref_indices, new_inv_subs):
    """Share duplicate matches found by each rank and redirect every duplicate to its reference entry."""
    comm.Barrier()
    change_indices = comm.gather(change_indices, root=0)
    ref_indices = comm.gather(ref_indices, root=0)
    new_inv_subs = comm.gather(new_inv_subs, root=0)
    if rank == 0:
        change_indices = list(itertools.chain(*change_indices))
        ref_indices = list(itertools.chain(*ref_indices))
        new_inv_subs = list(itertools.chain(*new_inv_subs))
    change_indices = comm.bcast(change_indices, root=0)
    ref_indices = comm.bcast(ref_indices, root=0)
    new_inv_subs = comm.bcast(new_inv_subs, root=0)

    for i in range(len(change_indices)):
        # Check we haven't already made the change
        if (ref_indices[i] not in change_indices[:i]) and (change_indices[i] not in change_indices[:i]):
            all_fun[change_indices[i]] = all_fun[ref_indices[i]]
            all_sym[change_indices[i]] = all_sym[ref_indices[i]]
            if all_inv_subs[change_indices[i]] is None:
                all_inv_subs[change_indices[i]] = []
            all_inv_subs[change_indices[i]].append(new_inv_subs[i])


def sympy_simplify(all_fun, all_sym, all_inv_subs, max_param, expand_fun=True, tmax=1, check_perm=False):
    """Simplify equations and find duplicates.

    The functions are split amongst the MPI ranks and each rank, in turn:
    (1) merges simple combinations of two parameters into one,
    (2) absorbs numerical factors/powers and simple functions of a parameter
    into the parameter, (3) optionally looks for functions which are
    permutations of the parameters of another function, (4) looks for
    functions which match another under a -> -a, and (5) relabels parameters
    so the lowest indices are used. The substitutions applied are recorded
    (as strings) in the inverse-substitution lists. This function makes
    collective MPI calls, so every rank must call it with the same arguments.

    Args:
        :all_fun (list): list of strings containing all functions
        :all_sym (list): list of sympy objects containing all functions
        :all_inv_subs (list): list of dictionaries giving substitutions to be applied to all functions
        :max_param (int): maximum number of free parameters in any equation in all_fun
        :expand_fun (bool, default=True): whether to run the sympy expand options (True) or not (False)
        :tmax (float, default=1.): maximum time in seconds to run any one part of simplification procedure for a given function
        :check_perm (bool, default=False): whether to check all possible permutations and inverses of constants (True) or not (False)

    Returns:
        :all_fun: list of strings containing all (updated) functions
        :all_sym (list): list of sympy objects containing all (updated) functions
        :all_inv_subs: list of dictionaries giving substitutions to be applied to all (updated) functions

    """

    if max_param == 0:
        return all_fun, all_sym, all_inv_subs
    if len(all_fun) == 0:
        return all_fun, all_sym, all_inv_subs

    esrp = ESRPrinter()

    def render(expression):
        """Print with the ESR printer, expanding first when expand_fun is set."""
        return esrp.doprint(expression.expand() if expand_fun else expression)

    if max_param > 0:
        param_list = ['a%i' % i for i in range(max_param)]
        all_a = sympy.symbols(" ".join(param_list), real=True)
        # sympy.symbols returns a bare Symbol (not a tuple) for a single name
        if max_param == 1:
            all_a = [all_a]
    else:
        param_list = []
        all_a = []

    str_fun, sym_fun, inv_subs_fun = _rank_slices(all_fun, all_sym, all_inv_subs)

    identity_subs = {a: a for a in all_a}

    # ------------------------------------------------------------------
    # Do some substitutions to simplify
    # ------------------------------------------------------------------
    comm.Barrier()
    comb = list(itertools.combinations(np.flip(np.arange(max_param)), 2))
    if max_param > 1:
        for c in comb:
            a_p = all_a[c[0]]
            a_q = all_a[c[1]]
            abs_p = sympy.Abs(a_p)
            abs_q = sympy.Abs(a_q)

            # Second number = 0 if normal subs, = 1 if abs subs
            # NOTE(review): a_q / abs_p and a_p / abs_q each appear twice
            # below, as in the original list - possibly a different
            # combination was intended; confirm.
            all_expr = [[a_p + a_q, 0],
                        [a_p - a_q, 0],
                        [a_q - a_p, 0],
                        [a_p * a_q, 0],
                        [a_p / a_q, 0],
                        [a_q / a_p, 0],
                        [a_p + abs_q, 0],
                        [a_p - abs_q, 0],
                        [abs_q - a_p, 0],
                        [a_p * abs_q, 0],
                        [a_p / abs_q, 0],
                        [a_q / abs_p, 0],
                        [a_q + abs_p, 0],
                        [a_q - abs_p, 0],
                        [abs_p - a_q, 0],
                        [a_q * abs_p, 0],
                        [a_q / abs_p, 0],
                        [a_p / abs_q, 0],
                        [abs_p * abs_q, 1],
                        [abs_p + abs_q, 1],
                        [abs_p - abs_q, 0],
                        [abs_q - abs_p, 0],
                        [abs_p / abs_q, 1],
                        [abs_q / abs_p, 1],
                        [pow_abs(a_q, a_p), 1],
                        [pow_abs(a_p, a_q), 1],
                        [pow_abs(a_q, abs_p), 1],
                        [pow_abs(a_p, abs_q), 1],
                        ]

            for i in range(len(str_fun)):
                orig_fun = str_fun[i]
                orig_sym = sym_fun[i]
                try:
                    with time_limit(tmax):
                        if (a_p in sym_fun[i].free_symbols) and (a_q in sym_fun[i].free_symbols):
                            # Make sure symbols only appear once in sym version
                            # NOTE(review): the ``v = 0`` branch can never be
                            # reached (the elif is only entered when a_q does
                            # not appear exactly once), so a merge only happens
                            # when both parameters appear exactly once, and in
                            # that case keep is False. Kept as written.
                            if sym_fun[i].count(a_q) == 1:
                                if sym_fun[i].count(a_p) == 1:
                                    v = 1
                                    keep = False
                                else:
                                    v = None  # Don't have to ignore this combination
                                    keep = True
                            elif sym_fun[i].count(a_p) == 1:
                                if sym_fun[i].count(a_q) == 1:
                                    v = 0
                                    keep = False
                                else:
                                    v = None  # Don't have to ignore this combination
                                    keep = True
                            else:
                                v = None

                            if v is not None:
                                for expr in all_expr:
                                    if sym_fun[i].has(expr[0]):
                                        if expr[1] == 0:
                                            sym_fun[i] = sym_fun[i].subs(
                                                expr[0], all_a[c[v]])
                                            inverse = {expr[0]: all_a[c[v]]}
                                        else:
                                            sym_fun[i] = sym_fun[i].subs(
                                                expr[0], sympy.Abs(all_a[c[v]], evaluate=False))
                                            inverse = {expr[0]: sympy.Abs(all_a[c[v]])}
                                        # A nan entry marks a substitution that
                                        # is not recorded explicitly
                                        _append_inv_sub(
                                            inv_subs_fun, i,
                                            str(inverse) if keep else str(np.nan))
                                        str_fun[i] = render(sym_fun[i])
                except TimeoutException:
                    print('TIMED OUT:', orig_fun)
                    str_fun[i] = orig_fun
                    sym_fun[i] = orig_sym

    # ------------------------------------------------------------------
    # See if multiples of constants appear
    # ------------------------------------------------------------------
    comm.Barrier()
    for i in range(len(str_fun)):
        orig_fun = str_fun[i]
        orig_sym = sym_fun[i]
        try:
            with time_limit(tmax):
                # Can't use force=True since sometimes pulls out factor -1 and computes log(-1)
                if expand_fun:
                    sym_fun[i] = sympy.expand_log(sym_fun[i])
                numbers = [atom for atom in sym_fun[i].atoms()
                           if atom.is_number and atom.is_finite]
                even = [n for n in numbers if n.is_Integer and n.is_even]
                odd = [n for n in numbers if n.is_Integer and n.is_odd]
                for j in range(len(param_list)):
                    if str_fun[i].count(param_list[j]) > 0:
                        a = all_a[j]
                        abs_a = sympy.Abs(a, evaluate=False)
                        # Each entry: [pattern, replacement, mode, inverse substitution (as str)]
                        # mode 0: record the inverse directly
                        # mode 1: only accept if the substitution can be undone
                        all_expr = [
                            [n*a, a, 0, str({a: a/n})] for n in numbers]
                        all_expr += [[a**n, abs_a, 0,
                                      str({a: pow_abs(a, 1/n)})] for n in even]
                        all_expr += [[a**n, a, 0,
                                      str({a: a ** (1/n)})] for n in odd]
                        all_expr += [[a**n * sympy.Abs(a), abs_a, 0,
                                      str({a: pow_abs(a, 1/(n+1))})] for n in even]
                        all_expr += [[a**n * sympy.Abs(a), a, 0,
                                      str({a: pow_abs(a, 1/(n+1)) * sympy.sign(a)})] for n in odd]
                        all_expr += [[square(a), abs_a, 0, str({a: sqrt_abs(a)})],
                                     [cube(a), a, 0, str({a: a**(1/3)})],
                                     [square(sympy.Abs(a)), abs_a, 0,
                                      str({a: sqrt_abs(a)})],
                                     [cube(sympy.Abs(a)), abs_a, 0,
                                      str({a: pow_abs(a, 1/3)})],
                                     [sqrt_abs(a), abs_a, 0, str({a: square(a)})],
                                     [log_abs(a), a, 1, str({a: sympy.exp(a)})],
                                     [sympy.exp(a), abs_a, 0, str({a: log_abs(a)})]
                                     ]
                        for expr in all_expr:
                            if sym_fun[i].has(expr[0]):
                                ss = str(sym_fun[i]).replace(" ", "")
                                ee = str(expr[0]).replace(" ", "")
                                # Make sure variable only appears in this form in the sym version
                                if ss.count(param_list[j]) in [1, ss.count(ee), ee.count(param_list[j])]:
                                    f0 = sym_fun[i].copy()
                                    sym_fun[i] = sym_fun[i].subs(
                                        {expr[0]: expr[1]})
                                    # Defined up front so the diagnostic below
                                    # can always be printed
                                    f1 = None
                                    try:
                                        if expr[2] == 0:
                                            if 'zoo' in str(expr[3]):
                                                # Don't make this substitution
                                                sym_fun[i] = f0.copy()
                                            else:
                                                _append_inv_sub(
                                                    inv_subs_fun, i, expr[3])
                                        elif expr[2] == 1:
                                            # These cases can be tricky with Abs
                                            s = {expr[1]: expr[0]}
                                            f1 = sym_fun[i].subs(
                                                {sympy.Abs(a): a}).subs(s)
                                            if f0.equals(f1):
                                                # It worked, so append original subs
                                                _append_inv_sub(
                                                    inv_subs_fun, i, expr[3])
                                            else:
                                                s = {expr[1]: sympy.Abs(
                                                    expr[0], evaluate=False)}
                                                f2 = sym_fun[i].subs(
                                                    {sympy.Abs(a): a}).subs(s)
                                                if f0.equals(f2):
                                                    _append_inv_sub(
                                                        inv_subs_fun, i, expr[3])
                                                else:
                                                    # Can't undo the simplification, so we won't do it
                                                    sym_fun[i] = f0.copy()
                                    except TimeoutException:
                                        # A timeout is not a failed comparison:
                                        # let the outer handler restore the
                                        # function
                                        raise
                                    except Exception:
                                        print('Bad comparison:', f0, f1)
                                        sys.stdout.flush()
                                        sym_fun[i] = f0.copy()

                                    str_fun[i] = render(sym_fun[i])
                                    break
        except TimeoutException:
            print('TIMED OUT:', orig_fun)
            str_fun[i] = orig_fun
            sym_fun[i] = orig_sym

    comm.Barrier()
    all_fun, all_sym, all_inv_subs = make_changes(all_fun, all_sym, all_inv_subs,
                                                  str_fun, sym_fun, inv_subs_fun)
    str_fun, sym_fun, inv_subs_fun = _rank_slices(all_fun, all_sym, all_inv_subs)

    # ------------------------------------------------------------------
    # Check permutations and inverses of constants
    # ------------------------------------------------------------------
    change_indices = []
    ref_indices = []
    new_inv_subs = []

    comm.Barrier()
    if max_param > 1 and check_perm:
        for i in range(len(str_fun)):
            orig_fun = str_fun[i]
            orig_sym = sym_fun[i]
            s = list(sym_fun[i].free_symbols)
            s = list(set(s).intersection(all_a))
            perm = list(itertools.permutations(
                np.flip(np.arange(len(s))), len(s)))
            # The identity permutation is not a change, so drop it
            perm.remove(tuple(range(len(s))))
            try_subs = [{s[k]: s[p[k]]
                         for k in range(len(p)) if k != p[k]} for p in perm]
            try:
                with time_limit(tmax):
                    for p in range(len(try_subs)):
                        if all(a in sym_fun[i].free_symbols for a in try_subs[p]):
                            expr = sym_fun[i].subs(
                                try_subs[p], simultaneous=True)
                            str_expand = render(expr)
                            if str_expand in all_fun:
                                m = all_fun.index(str_expand)
                                n = all_fun.index(str_fun[i])
                                if n != m:
                                    change_indices.append(n)
                                    ref_indices.append(m)
                                    new_inv_subs.append(str(try_subs[p]))
                                    break
            except TimeoutException:
                print('TIMED OUT:', orig_fun)
                str_fun[i] = orig_fun
                sym_fun[i] = orig_sym

    # Called unconditionally so every rank always takes part in the
    # collectives; with nothing found this leaves the lists untouched
    _merge_duplicates(all_fun, all_sym, all_inv_subs,
                      change_indices, ref_indices, new_inv_subs)
    str_fun, sym_fun, inv_subs_fun = _rank_slices(all_fun, all_sym, all_inv_subs)

    # ------------------------------------------------------------------
    # See if function with a0 -> -a0 already in list
    # ------------------------------------------------------------------
    comm.Barrier()
    change_indices = []
    ref_indices = []
    new_inv_subs = []

    for i in range(len(str_fun)):
        orig_fun = str_fun[i]
        orig_sym = sym_fun[i]
        try:
            with time_limit(tmax):
                for j in range(len(all_a)):
                    if param_list[j] in str_fun[i]:
                        expr = sym_fun[i].subs(all_a[j], -all_a[j])
                        # NOTE(review): this section compares str() output,
                        # whereas the other sections use ESRPrinter - confirm
                        # that is intended.
                        if str(expr) != str_fun[i]:
                            if expand_fun:
                                str_expand = str(expr.expand())
                            else:
                                str_expand = str(expr)
                            if str_expand in all_fun:
                                m = all_fun.index(str_expand)
                                n = all_fun.index(str_fun[i])
                                if n != m:
                                    change_indices.append(n)
                                    ref_indices.append(m)
                                    new_inv_subs.append(
                                        str({all_a[j]: -all_a[j]}))
                                    break
        except TimeoutException:
            print('TIMED OUT:', orig_fun)
            str_fun[i] = orig_fun
            sym_fun[i] = orig_sym

    _merge_duplicates(all_fun, all_sym, all_inv_subs,
                      change_indices, ref_indices, new_inv_subs)
    str_fun, sym_fun, inv_subs_fun = _rank_slices(all_fun, all_sym, all_inv_subs)

    # ------------------------------------------------------------------
    # Check parameters are in correct order
    # ------------------------------------------------------------------
    comm.Barrier()
    for i in range(len(str_fun)):
        orig_fun = str_fun[i]
        orig_sym = sym_fun[i]
        try:
            with time_limit(tmax):
                free_names = [str(v) for v in sym_fun[i].free_symbols]
                common = list(set(param_list).intersection(free_names))
                if len(common) > 0:
                    common.sort()
                    # With n parameters present the last one should be a(n-1)
                    if common[-1] != param_list[len(common)-1]:
                        common = [int(v[1:]) for v in common]
                        s = {all_a[common[k]]: all_a[k]
                             for k in range(len(common))}
                        sym_fun[i] = sym_fun[i].subs(s, simultaneous=True)
                        # Update the string whether or not we expand
                        str_fun[i] = render(sym_fun[i])
                        if s != identity_subs:
                            _append_inv_sub(inv_subs_fun, i, str(s))
        except TimeoutException:
            print('TIMED OUT:', orig_fun)
            str_fun[i] = orig_fun
            sym_fun[i] = orig_sym

    # If we find a zoo, let's make this a nan
    # NOTE(review): sympy.core.numbers.NaN is the class, not the sympy.nan
    # instance; expand_or_factor tests identity against the same object, so
    # this is left as is - confirm it is intended.
    for i in range(len(sym_fun)):
        if sympy.zoo in sym_fun[i].atoms():
            sym_fun[i] = sympy.core.numbers.NaN
            str_fun[i] = str(sympy.core.numbers.NaN)

    comm.Barrier()
    all_fun, all_sym, all_inv_subs = make_changes(all_fun, all_sym, all_inv_subs,
                                                  str_fun, sym_fun, inv_subs_fun)

    return all_fun, all_sym, all_inv_subs
[docs]
def expand_or_factor(all_sym, tmax=1, method='expand'):
    """Run the sympy expand or factor functions

    The entries of all_sym are split amongst the MPI ranks. Each rank
    transforms its share, and entries whose printed form no longer matches
    their key are gathered and applied on every rank. The keys of all_sym are
    left unchanged; only the values are replaced. This function makes
    collective MPI calls, so every rank must call it.

    Args:
        :all_sym (OrderedDict): dictionary of sympy objects which can be accessed by their string representations.
        :tmax (float, default=1.): maximum time in seconds to run any one part of expand/factor procedure for a given function
        :method (str, default='expand'): whether to run expand ('expand') or powsimp followed by factor ('factor'). All other options are ignored, leaving all_sym unchanged

    Returns:
        :all_sym (OrderedDict): dictionary of (updated) sympy objects which can be accessed by their string representations.

    """

    vals = list(all_sym.values())
    keys = list(all_sym.keys())

    i = np.atleast_1d(utils.split_idx(len(vals), rank, size))
    change_vals = []
    change_idx = []

    p = ESRPrinter()

    # An unrecognised method does no work, but the collectives below are
    # still executed so that all ranks stay in step
    known_method = method in ('expand', 'factor')

    if known_method and len(i) > 0:
        for j in range(i[0], i[-1]+1):
            if vals[j] is sympy.core.numbers.NaN:
                continue
            try:
                with time_limit(tmax):
                    if method == 'expand':
                        v = vals[j].expand()
                    else:
                        v = vals[j].powsimp()
                        v = v.factor()
                    # Only report entries which have actually changed
                    if p.doprint(v) != keys[j]:
                        change_idx.append(j)
                        change_vals.append(v)
            except TimeoutException:
                print('Terminated expanding:', j, rank, vals[j])

    change_vals = comm.gather(change_vals, root=0)
    change_idx = comm.gather(change_idx, root=0)
    if rank == 0:
        change_vals = list(itertools.chain(*change_vals))
        change_idx = list(itertools.chain(*change_idx))
    change_vals = comm.bcast(change_vals, root=0)
    change_idx = comm.bcast(change_idx, root=0)

    for idx, new_val in zip(change_idx, change_vals):
        all_sym[keys[idx]] = new_val

    return all_sym
[docs]
def do_sympy(all_fun, all_sym, compl, search_tmax, expand_tmax, dirname, track_memory=False):
    """Run the duplicate checking procedure

    Two sets of rounds are run, each repeated until the number of unique
    function strings stops changing: first with expand_fun=False, then (after
    expanding the sympy objects) with expand_fun=True. After every round rank
    0 writes the indices of functions with recorded substitutions, and the
    substitutions themselves, to files in dirname. Finally the sympy objects
    are factorised. This function makes collective MPI calls, so every rank
    must call it.

    Args:
        :all_fun (list): list of strings containing all functions
        :all_sym (OrderedDict): dictionary of sympy objects which can be accessed by their string representations.
        :compl (int): integer inserted into the names of the output files (presumably the complexity of the functions being processed - confirm against caller)
        :search_tmax (float): maximum time in seconds to run any one part of simplification procedure for a given function
        :expand_tmax (float): maximum time in seconds to run any one part of expand/factor procedure for a given function
        :dirname (str): directory path to save results in
        :track_memory (bool, default=False): whether to compute and print memory statistics (True) or not (False)

    Returns:
        :all_fun (list): list of strings containing all (updated) functions
        :all_sym (dict): dictionary of (updated) sympy objects which can be accessed by their string representations.
        :count (int): number of rounds of optimisation which were performed

    """

    if rank == 0 and track_memory:
        utils.using_mem("start do_sympy")
        utils.locals_size(locals())

    max_param = get_max_param(all_fun)

    # Split by number of parameters
    old_nuniq = 0
    new_nuniq = len(all_fun)
    count = 0

    # Initial optimisation
    while old_nuniq != new_nuniq:

        # Substitution histories start afresh each round
        all_inv_subs = [None] * len(all_fun)
        old_nuniq = new_nuniq
        if rank == 0:
            print('Optimisation', count, old_nuniq, len(all_fun))
            sys.stdout.flush()
        if rank == 0 and track_memory:
            utils.using_mem("start")
            utils.locals_size(locals())

        # (1) Get unique functions and matches
        if rank == 0:
            print('\tGetting unique functions')
            sys.stdout.flush()
        # As used below: uniq's keys are the unique strings and its values
        # index all_fun; match maps a string to its position in uniq_fun
        uniq, match = utils.get_unique_indexes(all_fun)
        uniq_fun = list(uniq.keys())
        if rank == 0:
            if track_memory:
                utils.using_mem("end")
                utils.locals_size(locals())
            print('\tGetting unique sympy')
            sys.stdout.flush()
        # all_sym is a list aligned with uniq_fun until the dict is rebuilt below
        all_sym = [all_sym[u] for u in uniq_fun]
        if rank == 0:
            print('\tGetting unique inverse subs')
            sys.stdout.flush()
        uniq_inv_subs = [all_inv_subs[i] for i in uniq.values()]
        del uniq
        gc.collect()

        # (2) Simplify the unique functions
        add_inv_subs = [None] * len(uniq_inv_subs)
        nparam = count_params(uniq_fun, max_param)
        for i in range(max_param+1):
            if rank == 0:
                print('\t\tnparam = %i' % i)
                sys.stdout.flush()
            # Permutations are only checked after the first round
            check_perm = (count != 0)
            # Select the unique functions with exactly i parameters
            m = nparam == i
            j = np.atleast_1d(np.squeeze(np.argwhere(m)))
            f = [uniq_fun[jj] for jj in j]
            e = [all_sym[jj] for jj in j]
            t = [None if uniq_inv_subs[jj] is None else uniq_inv_subs[jj].copy()
                 for jj in j]
            f, e, t = sympy_simplify(
                f, e, t, i, expand_fun=False, tmax=search_tmax, check_perm=check_perm)
            for k in range(len(t)):
                old_fun = uniq_fun[j[k]]
                uniq_fun[j[k]] = f[k]
                all_sym[j[k]] = e[k]
                # Keep only the substitutions added by this call
                if uniq_inv_subs[j[k]] is None:
                    add_inv_subs[j[k]] = t[k]
                else:
                    add_inv_subs[j[k]] = t[k][len(uniq_inv_subs[j[k]]):]
        del f, e, t, m, j, k, uniq_inv_subs, nparam
        gc.collect()

        # (3) Make replacements to full functions list
        for i in range(len(all_fun)):
            old_fun = all_fun[i]
            m = match[all_fun[i]]
            all_fun[i] = uniq_fun[m]
            if add_inv_subs[m] is not None and len(add_inv_subs[m]) > 0:
                if all_inv_subs[i] is None:
                    all_inv_subs[i] = add_inv_subs[m].copy()
                else:
                    all_inv_subs[i] = all_inv_subs[i].copy() + \
                        add_inv_subs[m].copy()
        del add_inv_subs, match
        gc.collect()

        if rank == 0:
            print('\tMaking dict')
            sys.stdout.flush()
        # Back to a mapping keyed by the (updated) function strings
        all_sym = dict(zip(uniq_fun, all_sym))
        # Only rank 0 counts; the result is shared so all ranks agree on when to stop
        if rank == 0:
            new_nuniq = len(set(all_fun))
        else:
            new_nuniq = None
        new_nuniq = comm.bcast(new_nuniq, root=0)

        if rank == 0:
            print('\tPrinting inv_subs to file')
            # Indices of the functions which have substitutions recorded
            data = [i for i in range(len(all_inv_subs))
                    if all_inv_subs[i] is not None]
            with open(dirname + '/inv_idx_%i_round_%i.txt' % (compl, count), "w") as f:
                for i in data:
                    print(i, file=f)
            print('\tPrinting inv to file')
            # One ';'-separated row of substitutions per index written above
            data = [all_inv_subs[i] for i in data]
            with open(dirname + '/inv_subs_%i_round_%i.txt' % (compl, count), "w") as f:
                writer = csv.writer(f, delimiter=';')
                writer.writerows(data)
            del data
            gc.collect()
            if track_memory:
                utils.using_mem("end of round")
                utils.locals_size(locals())
        count += 1

    # Number of rounds used above; file numbering continues from here
    round1_count = count
    if rank == 0 and track_memory:
        utils.using_mem("END")
        utils.locals_size(locals())

    # Expand functions
    if rank == 0:
        print('\nExpanding')
        sys.stdout.flush()
    all_sym = expand_or_factor(all_sym, tmax=expand_tmax, method='expand')

    count = 0
    old_nuniq = 0

    # Now replace functions by their expanded form
    if rank == 0:
        print('\nRewriting')
        sys.stdout.flush()

    while old_nuniq != new_nuniq:

        # Substitution histories start afresh each round
        all_inv_subs = [None] * len(all_fun)
        old_nuniq = new_nuniq
        if rank == 0:
            print('Optimisation', count, new_nuniq, len(all_fun))
            sys.stdout.flush()
        if rank == 0 and track_memory:
            utils.using_mem("start")
            utils.locals_size(locals())

        # (1) Get unique functions and matches
        if rank == 0:
            print('\tGetting unique functions')
            sys.stdout.flush()
        uniq, match = utils.get_unique_indexes(all_fun)
        uniq_fun = list(uniq.keys())
        if rank == 0:
            if track_memory:
                utils.using_mem("end")
                utils.locals_size(locals())
            print('\tGetting unique sympy')
            sys.stdout.flush()
        # all_sym is a list aligned with uniq_fun until the dict is rebuilt below
        all_sym = [all_sym[u] for u in uniq_fun]
        if rank == 0:
            print('\tGetting unique inverse subs')
            sys.stdout.flush()
        uniq_inv_subs = [all_inv_subs[i] for i in uniq.values()]
        del uniq
        gc.collect()

        # (2) Simplify the unique functions
        add_inv_subs = [None] * len(uniq_inv_subs)
        nparam = count_params(uniq_fun, max_param)
        for i in range(max_param+1):
            if rank == 0:
                print('\t\tnparam = %i' % i)
                sys.stdout.flush()
            # NOTE(review): check_perm is assigned here but not passed to
            # sympy_simplify below, so that call uses its default - confirm
            # whether check_perm=check_perm was intended.
            check_perm = True
            # Select the unique functions with exactly i parameters
            m = nparam == i
            j = np.atleast_1d(np.squeeze(np.argwhere(m)))
            f = [uniq_fun[jj] for jj in j]
            e = [all_sym[jj] for jj in j]
            t = [None if uniq_inv_subs[jj] is None else uniq_inv_subs[jj].copy()
                 for jj in j]
            f, e, t = sympy_simplify(
                f, e, t, i, expand_fun=True, tmax=search_tmax)
            for k in range(len(t)):
                old_fun = uniq_fun[j[k]]
                uniq_fun[j[k]] = f[k]
                all_sym[j[k]] = e[k]
                # Keep only the substitutions added by this call
                if uniq_inv_subs[j[k]] is None:
                    add_inv_subs[j[k]] = t[k]
                else:
                    add_inv_subs[j[k]] = t[k][len(uniq_inv_subs[j[k]]):]
        del f, e, t, m, j, k, uniq_inv_subs, nparam
        gc.collect()

        # (3) Make replacements to full functions list
        for i in range(len(all_fun)):
            old_fun = all_fun[i]
            m = match[all_fun[i]]
            all_fun[i] = uniq_fun[m]
            if add_inv_subs[m] is not None and len(add_inv_subs[m]) > 0:
                if all_inv_subs[i] is None:
                    all_inv_subs[i] = add_inv_subs[m].copy()
                else:
                    all_inv_subs[i] = all_inv_subs[i].copy() + \
                        add_inv_subs[m].copy()
        del add_inv_subs, match
        gc.collect()

        if rank == 0:
            print('\tMaking dict')
            sys.stdout.flush()
        # Back to a mapping keyed by the (updated) function strings
        all_sym = dict(zip(uniq_fun, all_sym))
        # Only rank 0 counts; the result is shared so all ranks agree on when to stop
        if rank == 0:
            new_nuniq = len(set(all_fun))
        else:
            new_nuniq = None
        new_nuniq = comm.bcast(new_nuniq, root=0)

        if rank == 0:
            print('\tPrinting inv_subs to file')
            # Indices of the functions which have substitutions recorded
            data = [i for i in range(len(all_inv_subs))
                    if all_inv_subs[i] is not None]
            with open(dirname + '/inv_idx_%i_round_%i.txt' % (compl, round1_count + count), "w") as f:
                for i in data:
                    print(i, file=f)
            print('\tPrinting inv to file')
            # One ';'-separated row of substitutions per index written above
            data = [all_inv_subs[i] for i in data]
            with open(dirname + '/inv_subs_%i_round_%i.txt' % (compl, round1_count + count), "w") as f:
                writer = csv.writer(f, delimiter=';')
                writer.writerows(data)
            del data
            gc.collect()
            if track_memory:
                utils.using_mem("end of round")
                utils.locals_size(locals())
        count += 1

    if rank == 0:
        print('\nFinal factorisation')
        sys.stdout.flush()
    all_sym = expand_or_factor(all_sym, tmax=expand_tmax, method='factor')

    if rank == 0 and track_memory:
        utils.using_mem("END")
        utils.locals_size(locals())

    return all_fun, all_sym, round1_count + count
[docs]
def get_all_dup(max_param):
    """Build the string forms of the self-inverse parameter transformations,
    for use in simplify_inv_subs(inv_subs, all_dup)

    The transformations are, in order: sign flips {a: -a}, reciprocals
    {a: 1/a}, and the swap of each pair of parameters written with both key
    orderings.

    Args:
        :max_param (int): maximum number of parameters to consider

    Returns:
        :all_dup (list): list of string representations of the dictionaries giving substitutions which are self-inverse

    """

    if max_param == 0:
        return []

    labels = ['a%i' % k for k in range(max_param)]
    params = sympy.symbols(" ".join(labels), real=True)
    # sympy.symbols returns a bare Symbol (not a tuple) for a single name
    if max_param == 1:
        params = [params]

    pairs = list(itertools.combinations(np.flip(np.arange(max_param)), 2))

    sign_flips = [str({a: -a}) for a in params]
    reciprocals = [str({a: 1/a}) for a in params]
    # Key order changes the printed string, so both orderings are listed
    swaps_first = [str({params[hi]: params[lo], params[lo]: params[hi]})
                   for hi, lo in pairs]
    swaps_second = [str({params[lo]: params[hi], params[hi]: params[lo]})
                    for hi, lo in pairs]

    return sign_flips + reciprocals + swaps_first + swaps_second
[docs]
def simplify_inv_subs(inv_subs, all_dup):
    """Drop adjacent pairs of identical self-inverse substitutions, e.g. two
    consecutive {a0: -a0} or {a0: a1, a1: a0} or {a0: 1/a0}, since together
    they cancel

    The list is scanned once from the left; a cancelled pair is skipped as a
    whole, so three identical entries in a row leave one behind.

    Args:
        :inv_subs (list): list of dictionaries giving substitutions to check
        :all_dup (list): list of dictionaries giving substitutions which are self-inverse

    Returns:
        :all_subs (list): list of dictionaries giving substitutions without consecutive self-inverses. None if nothing is left, and the input itself if it was None or empty

    """

    if inv_subs is None or len(inv_subs) == 0:
        return inv_subs

    total = len(inv_subs)
    survivors = []
    pos = 0
    while pos < total:
        current = inv_subs[pos]
        has_next = pos + 1 < total
        if has_next and current in all_dup and inv_subs[pos + 1] == current:
            # This entry and the next undo each other
            pos += 2
            continue
        survivors.append(current)
        pos += 1

    return survivors if survivors else None
[docs]
def count_lines(fname):
    """
    Work out how many lines a text file contains by iterating over it.

    Args:
        :fname (str): file name to count lines in

    Returns:
        :int: number of lines in the file

    """

    n_lines = 0
    with open(fname, 'r') as handle:
        for _ in handle:
            n_lines += 1
    return n_lines
[docs]
def get_line_range(n_lines):
    """
    Work out the half-open range [imin, imax) of line indices which belongs
    to the calling rank.

    The lines are shared as evenly as possible: the first n_lines % size
    ranks each receive one more line than the others.

    Args:
        :n_lines (int): total number of lines in the file

    Returns:
        :tuple: (imin, imax) where imin is the first line index for this rank and
            imax is the exclusive last line index for this rank

    """

    base, extra = divmod(n_lines, size)
    counts = [base + (1 if r < extra else 0) for r in range(size)]
    # Starting line of each rank is the total handed to the ranks before it
    offsets = np.cumsum([0] + counts[:-1])
    imin = offsets[rank]
    imax = imin + counts[rank]  # exclusive
    return imin, imax
[docs]
def load_subs(fname, max_param, use_sympy=True, bcast_res=True):
    """Load the substitutions required to convert between all and unique functions

    The file is divided into contiguous line ranges (see get_line_range) and
    each rank parses only its own range. Every line is a ';'-separated list of
    substitutions written like {a0: -a0}; an entry which is the literal text
    nan is loaded as np.nan. Lines with no content produce no entry.

    Args:
        :fname (str): file name containing the substitutions
        :max_param (int): maximum number of parameters to consider
        :use_sympy (bool, default=True): whether to convert substitutions to sympy objects (True) or leave as strings (False)
        :bcast_res (bool, default=True): whether to merge the chunks and give every rank the full result (True)
            or leave each rank with only the chunk it parsed (False)

    Returns:
        :all_subs (dict): dict of substitutions required to convert between all and unique functions, keyed by
            the (0-based) line index in fname. Each value is a list whose items are either a dictionary with
            sympy objects as keys and values (use_sympy=True), a string version of this dictionary
            (use_sympy=False), or np.nan. If bcast_res=True, then all ranks have this dict, otherwise
            all ranks receive a chunk of the dict corresponding to their rank.

    """
    # Only rank 0 counts the lines; the total is shared so all ranks agree on the split
    n_lines = count_lines(fname) if rank == 0 else None
    n_lines = comm.bcast(n_lines, root=0)
    first, stop = get_line_range(n_lines)

    all_subs = {}  # keyed by line index so that chunks can be merged later
    with open(fname, 'r') as f:
        for lineno, raw in enumerate(f):
            if lineno >= stop:
                break
            if lineno < first:
                continue
            entries = raw.strip().split(';')
            if entries != ['']:
                all_subs[lineno] = entries

    names = ['a%i' % i for i in range(max_param)]
    symbols = sympy.symbols(" ".join(names), real=True)
    if max_param == 1:
        # sympy.symbols returns a bare Symbol rather than a tuple for one name
        symbols = [symbols]

    # NOTE: locs is the shared sympy_locs mapping itself, not a copy, so the
    # parameter symbols registered here stay registered after this call.
    locs = sympy_locs
    if max_param > 0:
        for idx, sym in enumerate(symbols):
            locs["a%i" % idx] = sym

    # Quote keys and values so that "{a0: -a0}" becomes the Python literal
    # "{'a0': '-a0'}", which ast.literal_eval can read safely.
    quoting = (("{", "{'"), ("}", "'}"), (", ", "', '"), (": ", "': '"))
    for entries in all_subs.values():
        for j, text in enumerate(entries):
            for old, new in quoting:
                text = text.replace(old, new)
            if text == 'nan':
                entries[j] = np.nan
                continue
            quoted = ast.literal_eval(text)
            keys = [sympy.sympify(k, locals=locs) for k in quoted.keys()]
            vals = [sympy.sympify(v, locals=locs) for v in quoted.values()]
            parsed = dict(zip(keys, vals))
            entries[j] = parsed if use_sympy else str(parsed)

    comm.Barrier()

    if not bcast_res:
        return all_subs

    # Merge every rank's chunk on rank 0, then hand the full dict to all ranks
    pieces = comm.gather(all_subs, root=0)
    if rank == 0:
        merged = {}
        for piece in pieces:
            merged.update(piece)
        all_subs = merged
    return comm.bcast(all_subs, root=0)
[docs]
def convert_params(p_meas, fish_meas, inv_subs, n=4):
    """Convert parameters from those in unique function to those in actual function

    The substitutions in inv_subs are applied in order to the symbolic
    parameter vector (a0, a1, ...), giving the new parameters as functions of
    the measured ones. The Fisher matrix is transformed with the inverse of
    the Jacobian J of that map, F_new = J^-T F J^-1.

    Args:
        :p_meas (list): list of measured parameters in unique function
        :fish_meas (list): flattened upper triangle (row-major, n*(n+1)/2 entries) of the Hessian of
            -log(likelihood) at the maximum likelihood point
        :inv_subs (list): list of substitutions required to convert between all and unique functions.
            Any NaN entry means the conversion is not available.
        :n (int, default=4): the number of dimensions of the array from which fish_meas was computed using

    Returns:
        :p_new (list): list of parameters for the actual function (array of NaN if inv_subs contains NaN)
        :diag_fish (np.array): the diagonal entries of the Fisher matrix of the actual function at the maximum
            likelihood point (array of NaN if inv_subs contains NaN)

    """
    max_param = len(p_meas)

    # Test for NaN by value. `np.nan in inv_subs` only matches the np.nan
    # object itself (membership checks identity, then ==, and NaN != NaN), so
    # it misses NaNs which have been pickled (e.g. sent through MPI) or
    # created elsewhere.
    if any(isinstance(s, (float, np.floating)) and np.isnan(s)
           for s in inv_subs):
        return np.full(max_param, np.nan), np.full(max_param, np.nan)

    # Rebuild the symmetric n x n matrix from its upper triangle: wherever the
    # filled entry is zero, the transposed entry is used instead.
    fish = np.zeros((n, n))
    fish[np.triu_indices(n)] = fish_meas
    fish = np.where(fish, fish, fish.T)
    fish = fish[:max_param, :max_param]

    param_list = ['a%i' % i for i in range(max_param)]
    all_a = sympy.symbols(" ".join(param_list), real=True)
    if max_param == 1:
        # sympy.symbols returns a bare Symbol rather than a tuple for one name
        all_a = [all_a]

    # Apply the substitutions in order to the symbolic parameter vector
    p = sympy.Array(sympy.symbols(" ".join(param_list), real=True))
    for sub in inv_subs:
        p = p.subs(sub, simultaneous=True)
    jac = sympy.Matrix(p).jacobian(all_a)

    # Evaluate the new parameters and the Jacobian at the measured point
    if max_param == 1:
        p_lam = sympy.lambdify(all_a[0], str(p))
    else:
        p_lam = sympy.lambdify(all_a[:max_param], p)
    p_new = p_lam(*p_meas)

    j_lam = sympy.lambdify(all_a[:max_param], jac)
    j = j_lam(*p_meas)

    jinv = np.linalg.inv(j)
    fish_new = np.dot(jinv.T, np.dot(fish, jinv))
    diag_fish = np.diagonal(fish_new).copy()

    return p_new, diag_fish
[docs]
def check_results(dirname, compl, tmax=10):
    """Check that all functions can be recovered by applying the substitutions to the unique functions.
    If not, define a new unique function and save results to file.

    Reads all_equations_<compl>.txt, inv_subs_<compl>.txt, unique_equations_<compl>.txt and
    matches_<compl>.txt from dirname. The checks are shared between the MPI ranks, so every rank
    must call this function. Rank 0 then rewrites unique_equations_<compl>.txt,
    inv_subs_<compl>.txt and matches_<compl>.txt in place (using the shell tools sed and mv for
    the first of these).

    Args:
        :dirname (str): name of directory containing all the functions to consider
        :compl (int): complexity of functions to consider
        :tmax (float, default=10.): maximum time in seconds to run the substitutions. This is handed to
            signal.alarm by time_limit, so should be a whole number of seconds.

    Returns:
        None

    """
    # Rank 0 reads every equation and finds how many parameters a0, a1, ...
    # can occur; that count is then shared with all ranks.
    if rank == 0:
        print('\tLoading all equations', flush=True)
        with open(dirname + '/all_equations_%i.txt' % compl, 'r') as f:
            all_fun = f.read().splitlines()
        max_param = get_max_param(all_fun)
    else:
        all_fun = None
        max_param = None
    max_param = comm.bcast(max_param, root=0)

    if rank == 0:
        print('\tLoading inverse subs')
        with open(dirname + '/inv_subs_%i.txt' % compl, 'r') as f:
            reader = csv.reader(f, delimiter=';')
            inv_subs = [row for row in reader]
        # Only need functions with non-trivial inverse subs
        shufidx = np.array(
            [i for i in range(len(inv_subs)) if len(inv_subs[i]) != 0])
        # Fixed seed so that the shuffle is the same on every run
        np.random.seed(1234)
        # Shuffle to make each rank more similar
        np.random.shuffle(shufidx)
        # From here on all_fun / inv_subs are in shuffled order; shufidx maps a
        # shuffled position back to the line number in the original files.
        all_fun = [all_fun[ii] for ii in shufidx]
        nfun = len(all_fun)
        inv_subs = [inv_subs[ii] for ii in shufidx]
    else:
        nfun = None
    comm.Barrier()
    nfun = comm.bcast(nfun, root=0)

    # Work out the slice of the shuffled list belonging to each rank and
    # collect all the slice bounds on rank 0.
    # NOTE(review): the += 1 suggests utils.split_idx returns an inclusive
    # upper index - confirm against its definition.
    imin, imax = utils.split_idx(nfun, rank, size)
    imax += 1
    imin = comm.gather(imin, root=0)
    imax = comm.gather(imax, root=0)
    if rank == 0:
        all_fun = [all_fun[imin[i]:imax[i]] for i in range(size)]
        inv_subs = [inv_subs[imin[i]:imax[i]] for i in range(size)]
    else:
        all_fun = None
        inv_subs = None
    # Each rank now receives only its own chunk of equations and substitutions
    all_fun = comm.scatter(all_fun, root=0)
    inv_subs = comm.scatter(inv_subs, root=0)
    all_nparam = count_params(all_fun, max_param)

    # The unique equations are needed in full on every rank
    if rank == 0:
        print('\tLoading unique equations', flush=True)
        with open(dirname + '/unique_equations_%i.txt' % compl, 'r') as f:
            uniq_fun = f.read().splitlines()
    else:
        uniq_fun = None
    uniq_fun = comm.bcast(uniq_fun, root=0)
    uniq_nparam = count_params(uniq_fun, max_param)

    # matches[i] is used below as the index into uniq_fun of the unique
    # equation paired with all_fun[i]. It is reordered like all_fun and split
    # between the ranks.
    # NOTE(review): the split here uses np.array_split whereas all_fun uses
    # utils.split_idx; the two must give identical chunks - confirm.
    if rank == 0:
        print('\tLoading matches')
        matches = np.loadtxt(dirname + '/matches_%i.txt' % compl).astype(int)
        matches = matches[shufidx]
        matches = np.array_split(matches, size)
    else:
        matches = None
    matches = comm.scatter(matches, root=0)

    param_list = ['a%i' % i for i in range(max_param)]
    all_a = sympy.symbols(" ".join(param_list), real=True)
    if max_param == 1:
        # sympy.symbols returns a bare Symbol rather than a tuple for one name
        all_a = [all_a]
    # NOTE: locs is the shared sympy_locs mapping itself, not a copy, so the
    # parameter symbols are added to that module-level object.
    locs = sympy_locs
    if max_param > 0:
        for i in range(len(all_a)):
            locs["a%i" % i] = all_a[i]

    # Entries are [index into the shuffled list, equation string]
    to_change = []
    # Offset of this rank's chunk within the shuffled list
    imin, imax = utils.split_idx(nfun, rank, size)
    for i in range(len(all_fun)):
        if rank == 0 and (i % 100) == 0:
            print(i, len(all_fun))
        # Equations whose parameter count differs from that of their unique
        # partner are not checked
        if all_nparam[i] != uniq_nparam[matches[i]]:
            continue
        s1 = sympy.sympify(all_fun[i], locals=locs)
        try:
            s2 = sympy.sympify(uniq_fun[matches[i]], locals=locs)
        except Exception:
            print(
                f'Could not check {uniq_fun[matches[i]]} so will keep equation')
            s2 = None
        # Start from the identity parameter vector (a0, a1, ...)
        p = sympy.Array(sympy.symbols(" ".join(param_list), real=True))
        try:
            if s2 is None:
                raise ValueError
            with time_limit(tmax):
                for j in range(len(inv_subs[i])):
                    # Quote keys and values so that the text "{a0: -a0}"
                    # becomes the literal "{'a0': '-a0'}" for ast.literal_eval
                    inv_subs[i][j] = inv_subs[i][j].replace("{", "{'")
                    inv_subs[i][j] = inv_subs[i][j].replace("}", "'}")
                    inv_subs[i][j] = inv_subs[i][j].replace(", ", "', '")
                    inv_subs[i][j] = inv_subs[i][j].replace(": ", "': '")
                    d = ast.literal_eval(inv_subs[i][j])
                    k = list(d.keys())
                    v = list(d.values())
                    k = [sympy.sympify(kk, locals=locs) for kk in k]
                    v = [sympy.sympify(vv, locals=locs) for vv in v]
                    p = p.subs(dict(zip(k, v)), simultaneous=True)
                # Apply the accumulated substitution to the equation and
                # compare with its unique partner: cheap string comparison
                # first, symbolic equality only if the strings differ
                sub = {all_a[j]: p[j] for j in range(len(all_a))}
                s1 = s1.subs(sub, simultaneous=True)
                if (not str(s1) == str(s2)) and (not s1.equals(s2)):
                    raise ValueError
        except Exception:
            # Any failure (parse error, timeout, mismatch) means this equation
            # will be stored as a new unique equation
            to_change.append([i+imin, all_fun[i]])

    del inv_subs, all_fun
    gc.collect()

    to_change = comm.gather(to_change, root=0)

    # Everything below touches the files on disk and is done by rank 0 only
    if rank == 0:
        to_change = list(itertools.chain(*to_change))
        # Change indices to how they were before
        for r in to_change:
            r[0] = shufidx[r[0]]
        del shufidx
        print('\nNeed to change %i functions' % len(to_change))
        for r in to_change:
            print(r)

        # Re-read the equations from file using the un-shuffled indices
        print('\nLoading all equations', flush=True)
        with open(dirname + '/all_equations_%i.txt' % compl, 'r') as f:
            all_fun = f.read().splitlines()
        for r in to_change:
            r[1] = all_fun[r[0]]
        del all_fun
        gc.collect()

        print('\nAppending new unique equations')
        with open(dirname + '/unique_equations_%i.txt' % compl, 'r') as f:
            uniq_fun = f.read().splitlines()
        nuniq = len(uniq_fun)
        new_fun = [r[1] for r in to_change]
        # NOTE(review): new_uniq is used as a mapping whose keys are the new
        # unique equations, and new_match is indexed by equation string below -
        # confirm against utils.get_unique_indexes.
        new_uniq, new_match = utils.get_unique_indexes(new_fun)
        new_uniq_fun = list(new_uniq.keys())
        with open(dirname + '/unique_equations_%i.txt' % compl, 'w') as f:
            # The width is grown whenever an equation is longer than half of
            # it, so that pprint keeps each string on a single line
            w = 80
            pp = pprint.PrettyPrinter(width=w, stream=f)
            for s in uniq_fun:
                if len(s + '\n') > w / 2:
                    w = 2 * len(s)
                    pp = pprint.PrettyPrinter(width=w, stream=f)
                pp.pprint(s)
            for s in new_uniq_fun:
                if len(s + '\n') > w / 2:
                    w = 2 * len(s)
                    pp = pprint.PrettyPrinter(width=w, stream=f)
                pp.pprint(s)
        del uniq_fun
        gc.collect()

        # pprint wrote each equation as a quoted string; sed removes the last
        # and the first character of every line (the quotes) into a temporary
        # file, which then replaces the original
        s = "sed 's/.$//; s/^.//' %s/%s%i.txt > %s/temp_%i.txt" % (
            dirname, 'unique_equations_', compl, dirname, compl)
        os.system(s)
        s = "mv %s/temp_%i.txt %s/%s%i.txt" % (dirname,
                                               compl, dirname, 'unique_equations_', compl)
        os.system(s)

        # Equations which are now unique themselves need no inverse subs
        print('\nChanging inverse subs')
        with open(dirname + '/inv_subs_%i.txt' % compl, 'r') as f:
            reader = csv.reader(f, delimiter=';')
            inv_subs = [row for row in reader]
        for r in to_change:
            inv_subs[r[0]] = ""
        with open(dirname + '/inv_subs_%i.txt' % compl, 'w') as f:
            writer = csv.writer(f, delimiter=';')
            writer.writerows(inv_subs)
        del inv_subs
        gc.collect()

        # Point each changed equation at its new unique equation, which was
        # appended after the nuniq existing ones
        print('\nChanging matches')
        matches = np.loadtxt(dirname + '/matches_%i.txt' % compl).astype(int)
        for i in range(len(to_change)):
            matches[to_change[i][0]] = nuniq + new_match[to_change[i][1]]
        # NOTE(review): no fmt is given, so np.savetxt writes these integers
        # in its default floating-point format - confirm this is intended.
        np.savetxt(dirname + '/matches_%i.txt' % compl, matches)

    return