r"""
Description: Processes a binary tif stack, with the option to remove
disconnected voxels based on an undirected graph. The number of clusters
to retain can be specified and connectivity is defined on a 26 point basis,
i.e. faces, edges and corners. Standard outputs include the processed
tif image stack, an aperture map and offset map based on the processed image.
Offset maps are filtered based on gradient steepness to provide a smoother
surface. Data gaps left by zero aperture zones or filtering are filled by
linear and nearest interpolation methods to prevent artificial features.
For usage information run: ``apm_process_image_stack -h``
| Written By: Matthew Stadelman
| Date Written: 2016/08/30
| Last Modified: 2017/04/23
|
"""
import argparse
from argparse import RawDescriptionHelpFormatter as RawDesc
from itertools import product
import os

import numpy as np
import scipy as sp
from scipy import sparse as sprs
from scipy.sparse import csgraph
from scipy.interpolate import griddata

from apmapflow import _get_logger, set_main_logger_level
from apmapflow import DataField, calc_percentile, FractureImageStack
# setting up logger: the package-wide level starts at 'info' (main() switches
# it to 'debug' when -v/--verbose is passed) and every function in this
# script logs through the module-level ``logger`` created here
set_main_logger_level('info')
logger = _get_logger('apmapflow.scripts')
# creating arg parser, the module docstring doubles as the ``-h`` description
# and RawDesc keeps its line breaks intact
parser = argparse.ArgumentParser(description=__doc__, formatter_class=RawDesc)
#
# general behavior flags
parser.add_argument('-f', '--force', action='store_true',
                    help='allows program to overwrite existing files')
parser.add_argument('-v', '--verbose', action='store_true',
                    help='debug messages are printed to the screen')
parser.add_argument('-o', '--output-dir',
                    type=os.path.realpath, default=os.getcwd(),
                    help='''outputs files to the specified
                    directory, sub-directories are created as needed''')
#
# image processing options
parser.add_argument('-n', '--num-clusters', type=int, default=None,
                    help='''number of clusters to retain, ordered by size
                    this option is disabled by default''')
parser.add_argument('-i', '--invert', action='store_true',
                    help='use this flag if your fracture is in black')
#
# alternate output names, a default based on the image name is used if omitted
parser.add_argument('--offset-map-name', default=None,
                    help='alternate name to save the offset map as')
parser.add_argument('--aper-map-name', default=None,
                    help='alternate name to save the aperture map as')
parser.add_argument('--img-stack-name', default=None,
                    help='''alternate name to save the tiff stack as,
                    has no effect if the -n # flag is omitted''')
#
# switches controlling which outputs get generated
parser.add_argument('--gen-cluster-img', action='store_true',
                    help='generates a tiff image colored by cluster number')
parser.add_argument('--no-aper-map', action='store_true',
                    help='do not generate aperture map')
parser.add_argument('--no-offset-map', action='store_true',
                    help='do not generate offset map')
parser.add_argument('--no-img-stack', action='store_true',
                    help='''do not save a processed tif stack,
                    has no effect when the -n # flag is omitted''')
#
# required positional argument
parser.add_argument('image_file', type=os.path.realpath,
                    help='binary TIFF stack image to process')
def main():
    r"""
    Driver program to load an image and generate maps. Memory
    requirements when processing a large TIFF stack can be very high.

    The command line is parsed with the module level ``parser``. Existing
    output files are never overwritten unless ``-f`` was supplied, and only
    files this run will actually write are checked. When ``-n`` is given the
    image is reduced to its N largest connected clusters before the aperture
    map, offset map and processed image stack are written.

    Raises
    ------
    FileExistsError
        If an output file that would be written already exists and the
        ``-f`` flag was not used.
    """
    # parsing commandline args
    args = parser.parse_args()
    if args.verbose:
        set_main_logger_level('debug')
    #
    # initializing output filenames as needed and pre-appending the output path
    img_basename = os.path.splitext(os.path.basename(args.image_file))[0]
    if args.aper_map_name is None:
        args.aper_map_name = img_basename + '-aperture-map.txt'
    if args.offset_map_name is None:
        args.offset_map_name = img_basename + '-offset-map.txt'
    if args.img_stack_name is None:
        args.img_stack_name = img_basename + '-processed.tif'
    #
    aper_map_file = os.path.join(args.output_dir, args.aper_map_name)
    offset_map_file = os.path.join(args.output_dir, args.offset_map_name)
    img_stack_file = os.path.join(args.output_dir, args.img_stack_name)
    cluster_img_file = os.path.splitext(img_stack_file)[0] + '-clusters.tif'
    #
    # deciding what gets written, the processed stack is only saved when
    # clusters are being removed so it must not block a run without "-n"
    save_aper_map = not args.no_aper_map
    save_offset_map = not args.no_offset_map
    save_img_stack = bool(args.num_clusters) and not args.no_img_stack
    planned_outputs = (
        (save_aper_map, aper_map_file),
        (save_offset_map, offset_map_file),
        (save_img_stack, img_stack_file),
    )
    out_paths = [path for will_write, path in planned_outputs if will_write]
    #
    # checking paths
    if not args.force:
        for path in out_paths:
            if os.path.exists(path):
                msg = '{} already exists, use "-f" option to overwrite'
                raise FileExistsError(msg.format(path))
    #
    # creating output directories as promised by the "-o" help text, the
    # cluster image is written next to the processed image stack
    if args.num_clusters and args.gen_cluster_img:
        out_paths.append(cluster_img_file)
    for path in out_paths:
        os.makedirs(os.path.dirname(path), exist_ok=True)
    #
    # loading image data
    logger.info('loading image...')
    img_data = FractureImageStack(args.image_file)
    if args.invert:
        logger.debug('inverting image data')
        img_data = ~img_data
    logger.debug('image dimensions: {} {} {}'.format(*img_data.shape))
    #
    # processing image stack based on connectivity
    if args.num_clusters:
        kwargs = {
            'output_img': args.gen_cluster_img,
            'img_name': cluster_img_file,
            'img_shape': img_data.shape
        }
        img_data = process_image(img_data, args.num_clusters, **kwargs)
    #
    # outputing aperture map
    if save_aper_map:
        aper_map = img_data.create_aperture_map()
        logger.info('saving aperture map file')
        np.savetxt(aper_map_file, aper_map, fmt='%d', delimiter='\t')
        del aper_map
    #
    # outputing offset map
    if save_offset_map:
        offset_map = calculate_offset_map(img_data)
        logger.info('saving offset map file')
        np.savetxt(offset_map_file, offset_map, fmt='%f', delimiter='\t')
        del offset_map
    #
    # saving image data
    if save_img_stack:
        logger.info('saving copy of processed image data')
        img_data.save(img_stack_file, overwrite=args.force)
def process_image(img_data, num_clusters, **kwargs):
    r"""
    Processes a tiff stack, retaining voxels based on node connectivity.
    The clusters are sorted by size and the largest N are retained.

    Parameters
    ----------
    img_data : FractureImageStack
        image to process, its ``get_fracture_voxels`` method supplies the
        flattened indices of the fracture voxels.
    num_clusters : int
        number of clusters to keep, counted from the largest down.
    **kwargs
        forwarded unchanged to ``remove_isolated_clusters``.

    Returns
    -------
    FractureImageStack
        new boolean array with the shape of ``img_data`` that is True only
        for voxels belonging to a retained cluster.
    """
    #
    img_dims = img_data.shape
    nonzero_locs = img_data.get_fracture_voxels()
    index_map = generate_index_map(nonzero_locs, img_dims)
    #
    # determing connectivity and removing clusters, large intermediates are
    # released as soon as possible to keep the peak memory use down
    conns = generate_node_connectivity_array(index_map, img_data)
    del img_data, index_map
    nonzero_locs = remove_isolated_clusters(conns,
                                            nonzero_locs,
                                            num_clusters,
                                            **kwargs)
    del conns
    #
    # reconstructing 3-D array
    logger.info('reconstructing processed data back into 3-D array')
    img_data = np.zeros(img_dims, dtype=bool)
    #
    # a freshly allocated array is C-contiguous so reshape(-1) is a view and
    # the flattened (C-order) indices can be assigned directly without first
    # expanding them into three coordinate arrays
    img_data.reshape(-1)[nonzero_locs] = True
    del nonzero_locs
    #
    return img_data.view(FractureImageStack)
def calculate_offset_map(img_data):
    r"""
    Handles calculation of an offset map based on image data

    The raw map is requested from ``img_data.create_offset_map`` with NaN
    marking cells that have no data. Those cells are interpolated by
    ``patch_holes`` and the result is smoothed by ``filter_high_gradients``.

    Parameters
    ----------
    img_data : FractureImageStack
        image stack providing the ``create_offset_map`` method.

    Returns
    -------
    the patched and filtered offset map.
    """
    #
    logger.info('creating initial offset map')
    raw_map = img_data.create_offset_map(no_data_fill=np.nan)
    #
    logger.info('interpolating missing data due to zero aperture zones')
    patched_map = patch_holes(raw_map)
    #
    return filter_high_gradients(patched_map)
def generate_index_map(nonzero_locs, shape):
    r"""
    Determines the i,j,k indices of the flattened array

    Parameters
    ----------
    nonzero_locs : integer ndarray
        flattened (C-order) indices into an array of dimensions ``shape``.
    shape : tuple of int
        dimensions of the array being indexed.

    Returns
    -------
    ndarray of int16 with shape (nonzero_locs.size, len(shape))
        one row of coordinates per entry in ``nonzero_locs``. Coordinates
        are stored as int16, so each dimension must fit in that type.
    """
    #
    logger.info('creating index map of non-zero values...')
    #
    # unraveling once instead of once per axis, each axis is down-cast before
    # stacking so the full-width coordinates are never stacked together
    coords = np.unravel_index(nonzero_locs, shape)
    index_map = np.stack([axis.astype(np.int16) for axis in coords], axis=1)
    #
    return index_map
def generate_node_connectivity_array(index_map, data_array):
    r"""
    Generates a node connectivity array based on faces, edges and corner
    adjacency

    Parameters
    ----------
    index_map : integer ndarray with shape (N, 3)
        i,j,k coordinates of every node (fracture voxel) to process.
    data_array : FractureImageStack
        3-D image the coordinates refer to. A neighbor only forms a
        connection when its voxel is truthy in this array. The
        ``index_int_type`` attribute sets the dtype of the initial, empty
        connection array.

    Returns
    -------
    ndarray with shape (M, 2)
        unique connections as pairs of flattened (C-order) voxel indices.
        The smaller index is always in column 0 and the rows are sorted.
    """
    #
    logger.info('generating network connections...')
    #
    # setting up some constants, product() yields (0, 0, 0) first so dropping
    # the first entry leaves the offsets to the 26 surrounding voxels
    dims = data_array.shape
    conn_map = np.array(list(product([0, -1, 1], repeat=3)), dtype=int)[1:]
    num_nbrs = conn_map.shape[0]
    #
    # creating slice list to process data chunks
    chunk_size = 10000
    num_nodes = index_map.shape[0]
    slice_list = [slice(start, min(start + chunk_size, num_nodes))
                  for start in range(0, num_nodes, chunk_size)]
    #
    # chunks are gathered in a list and joined once at the end, seeding the
    # list with an empty array keeps the dtype promotion with index_int_type
    chunks = [np.empty((0, 2), dtype=data_array.index_int_type)]
    logger.debug('\tnumber of slices to process: {}'.format(len(slice_list)))
    percent = 10
    for n, sect in enumerate(slice_list):
        # getting coordinates of nodes and their neighbors
        nodes = index_map[sect]
        nbrs = nodes[:, np.newaxis, :] + conn_map[np.newaxis, :, :]
        nbrs = nbrs.reshape(-1, 3)
        #
        # calculating the flattened index of the central nodes, repeated so
        # each neighbor row has its central node alongside it
        centers = np.ravel_multi_index(tuple(nodes.T), dims)
        centers = np.repeat(centers, num_nbrs)
        #
        # removing neighbors with negative indices or outside of bounds
        in_bounds = np.all((nbrs >= 0) & (nbrs < dims), axis=1)
        nbrs = nbrs[in_bounds]
        centers = centers[in_bounds]
        #
        # removing neighbors that are not part of the fracture
        is_node = data_array[nbrs[:, 0], nbrs[:, 1], nbrs[:, 2]]
        is_node = np.asarray(is_node, dtype=bool)
        nbrs = nbrs[is_node]
        centers = centers[is_node]
        if centers.size:
            # calculating flattened index of remaining neighbor nodes
            nbr_ids = np.ravel_multi_index(tuple(nbrs.T), dims)
            #
            # ensuring conns[0] is always < conns[1] for duplicate removal
            pairs = np.stack((np.minimum(centers, nbr_ids),
                              np.maximum(centers, nbr_ids)), axis=1)
            #
            # NOTE(review): the uint32 cast would truncate flattened indices
            # above 2**32 - 1, confirm images that large are not expected
            chunks.append(pairs.astype(np.uint32))
        #
        # reporting whenever a new 10% mark has been reached or passed
        progress = int(n / len(slice_list) * 100)
        if progress >= percent:
            msg = '\tprocessed slice {:5d}, {}% complete'
            logger.debug(msg.format(n, progress))
            percent = (progress // 10 + 1) * 10
    #
    conns = np.concatenate(chunks, axis=0)
    del chunks
    #
    # every connection is found once from each end so duplicates are removed
    logger.info('removing duplicate connections...')
    dim0 = conns.shape[0]
    if dim0:
        conns = np.unique(conns, axis=0)
    logger.debug('\tremoved {} duplicates'.format(dim0 - conns.shape[0]))
    #
    return conns
def generate_adjacency_matrix(conns, nonzero_locs):
    r"""
    Generates an adjacency matrix based on a connectivity array

    Parameters
    ----------
    conns : integer ndarray with shape (M, 2)
        connections given as pairs of flattened voxel indices. This array
        is modified in place: on return it holds positions into
        ``nonzero_locs`` instead of the flattened indices.
    nonzero_locs : integer ndarray
        flattened indices of every node, all values in ``conns`` must
        appear in it.

    Returns
    -------
    scipy.sparse.csr_matrix with shape (nonzero_locs.size, nonzero_locs.size)
        symmetric matrix holding a weight of 1.0 for each connection.
    """
    num_blks = nonzero_locs.size
    if num_blks == 0:
        # without nodes there is nothing to index or connect
        return sprs.csr_matrix((0, 0))
    #
    msg = 're-indexing connections array from absolute to relative indicies'
    logger.info(msg)
    #
    # lookup table from flattened index to position in nonzero_locs, any
    # entry that is not a node keeps the uint32 maximum as a sentinel
    sentinel = np.iinfo(np.uint32).max
    mapper = np.full(int(nonzero_locs.max()) + 1, sentinel, dtype=np.uint32)
    mapper[nonzero_locs] = np.arange(num_blks, dtype=np.uint32)
    conns[:, 0] = mapper[conns[:, 0]]
    conns[:, 1] = mapper[conns[:, 1]]
    del mapper
    #
    # each connection is entered in both directions to make the graph
    # undirected, hence the weights vector is conns.size (2 * M) long
    logger.info('creating adjacency matrix...')
    row = np.concatenate((conns[:, 0], conns[:, 1]))
    col = np.concatenate((conns[:, 1], conns[:, 0]))
    weights = np.ones(conns.size)
    #
    # Generate sparse adjacency matrix in 'coo' format and convert to csr
    adj_mat = sprs.coo_matrix((weights, (row, col)), (num_blks, num_blks))
    #
    return adj_mat.tocsr()
def remove_isolated_clusters(conns, nonzero_locs, num_to_keep, **kwargs):
    r"""
    Identifies and removes all disconnected clusters except the number of
    groups specified by "num_to_keep". num_to_keep=N retains the N largest
    clusters

    Parameters
    ----------
    conns : integer ndarray with shape (M, 2)
        connections between nodes, handed to ``generate_adjacency_matrix``
        which re-indexes it in place.
    nonzero_locs : integer ndarray
        flattened indices of every node.
    num_to_keep : int
        number of clusters to retain, capped at the number found.
    **kwargs
        ``output_img`` (bool) requests a cluster image, in which case
        ``img_shape`` and ``img_name`` are passed to ``save_cluster_image``.

    Returns
    -------
    ndarray
        the entries of ``nonzero_locs`` that belong to a retained cluster.
    """
    # no nodes means no clusters to identify or remove
    if nonzero_locs.size == 0:
        return nonzero_locs
    #
    adj_mat = generate_adjacency_matrix(conns, nonzero_locs)
    #
    logger.info('determining connected components...')
    cs_ids = csgraph.connected_components(csgraph=adj_mat, directed=False)[1]
    del adj_mat
    #
    # ordering the cluster ids from the largest cluster to the smallest
    groups, counts = np.unique(cs_ids, return_counts=True)
    order = np.argsort(counts)[::-1]
    groups, counts = groups[order], counts[order]
    del order
    num_to_keep = min(num_to_keep, groups.size)
    #
    msg = '\t{} component groups for {} total nodes'
    logger.debug(msg.format(groups.size, cs_ids.size))
    msg = '\tlargest group number: {}, size {}'
    logger.debug(msg.format(groups[0], counts[0]))
    msg = '\t{} % of nodes contained in largest group'
    logger.debug(msg.format(counts[0]/cs_ids.size*100))
    msg = '\t{} % of nodes contained in {} retained groups'
    num = np.sum(counts[:num_to_keep])/cs_ids.size*100
    logger.debug(msg.format(num, num_to_keep))
    #
    # creating image colored by clusters if desired
    if kwargs.get('output_img', False):
        save_cluster_image(cs_ids,
                           groups,
                           counts,
                           nonzero_locs,
                           kwargs.get('img_shape'),
                           kwargs.get('img_name'))
    #
    # keeping only the nodes whose cluster id is among the retained groups
    keep = np.isin(cs_ids, groups[:num_to_keep])
    del cs_ids, groups, counts
    #
    num = nonzero_locs.size
    nonzero_locs = nonzero_locs[keep]
    msg = '\tremoved {} disconnected nodes'
    logger.debug(msg.format(num - nonzero_locs.size))
    #
    return nonzero_locs
def save_cluster_image(cs_ids, groups, counts, locs, img_shape, img_name):
    r"""
    Saves an 8 bit image colored by cluster number

    Parameters
    ----------
    cs_ids : ndarray
        cluster id of every node.
    groups : ndarray
        unique cluster ids, in the order they should be colored; the first
        16 at most get their own gray level.
    counts : ndarray
        number of nodes in each entry of ``groups``.
    locs : integer ndarray
        flattened (C-order) image index of every node, aligned with
        ``cs_ids``.
    img_shape : tuple of int
        dimensions of the image to create.
    img_name : str
        file to save the image as, an existing file is overwritten.
    """
    logger.info('creating tiff image file colored by cluster number')
    #
    msg = '\t{} % of nodes covered in {} colored groups'
    num_cs = min(16, groups.size)
    num = np.sum(counts[:num_cs])/cs_ids.size*100
    logger.debug(msg.format(num, num_cs))
    #
    # setting the top 16 groups separated by increments of 8 and the rest are
    # 255; all num_cs groups are colored so the image matches the log message
    data = np.full(cs_ids.size, 255, dtype=np.uint8)
    for n, cs_id in enumerate(groups[:num_cs]):
        data[cs_ids == cs_id] = 67 + n * 8
    #
    # voxels that are not nodes stay 0, a freshly allocated array is
    # C-contiguous so reshape(-1) is a view accepting the flattened indices
    img_data = np.zeros(img_shape, dtype=np.uint8)
    img_data.reshape(-1)[locs] = data
    #
    # save image data
    img_data = img_data.view(FractureImageStack)
    logger.info('saving image cluster data to file: %s', img_name)
    img_data.save(img_name, overwrite=True)
def patch_holes(data_map):
    r"""
    Fills in any areas with a non finite value. Cells are linearly
    interpolated from the finite cells first, and whatever the linear pass
    could not fill then takes the value of the nearest finite cell.

    Parameters
    ----------
    data_map : float ndarray
        map containing NaN or infinite cells to fill. It is modified in
        place.

    Returns
    -------
    ndarray
        the same ``data_map`` object with the non finite cells filled.
    """
    #
    # getting coordinates of all valid data points
    valid = np.isfinite(data_map)
    if valid.all():
        # nothing to fill, skipping the triangulation entirely
        return data_map
    points = np.nonzero(valid)
    values = data_map[valid]
    #
    # linearly interpolating data to fill gaps, cells the linear method
    # cannot interpolate come back as NaN and are handled below
    holes = np.nonzero(~valid)
    msg = '\tattempting to fill %d values with a linear interpolation'
    logger.debug(msg, holes[0].size)
    intrp = griddata(points, values, holes, fill_value=np.nan, method='linear')
    data_map[holes] = intrp
    #
    # performing a nearest interpolation any remaining regions, this uses the
    # original valid cells only and not the values interpolated above
    holes = np.nonzero(~np.isfinite(data_map))
    msg = '\tfilling %d remaining values with a nearest interpolation'
    logger.debug(msg, holes[0].size)
    if holes[0].size:
        intrp = griddata(points, values, holes, fill_value=0, method='nearest')
        data_map[holes] = intrp
    #
    return data_map
def filter_high_gradients(data_map):
    r"""
    Filters the offset field to reduce the number of very steep gradients.
    The 99th percentile of the gradient magnitude is used as a threshold
    and every cell whose gradient along either axis is beyond +- that value
    is removed. Cells left outside of the largest connected region are
    removed as well and all removed cells are re-interpolated.

    Parameters
    ----------
    data_map : float ndarray (2-D)
        offset map to filter. It is modified in place.

    Returns
    -------
    ndarray
        the filtered map. NOTE(review): one is added to every cell before
        filtering and never subtracted again, so the returned values are
        shifted by +1 relative to the input - confirm this is intended.
    """
    #
    logger.info('filtering offset map to remove steeply sloped cells')
    #
    zdir_grad, xdir_grad = np.gradient(data_map)
    mag = np.sqrt(zdir_grad**2 + xdir_grad**2)
    #
    # presumably shifted so a valid offset of 0 is not mistaken for one of
    # the cells zeroed out below - TODO confirm, see the note under Returns
    data_map += 1
    data_vector = np.ravel(data_map)
    #
    # setting regions outside of 99th percentile to 0 for cluster removal
    val = calc_percentile(99, np.ravel(mag))
    steep = (zdir_grad < -val) | (zdir_grad > val)
    steep |= (xdir_grad < -val) | (xdir_grad > val)
    data_map[steep] = 0
    #
    logger.debug('\tremoving clusters isolated by high gradients')
    offsets = DataField(data_map)
    adj_mat = offsets.create_adjacency_matrix()
    cs_ids = csgraph.connected_components(csgraph=adj_mat, directed=False)[1]
    #
    # only the cluster with the most cells is kept
    groups, counts = np.unique(cs_ids, return_counts=True)
    largest = groups[np.argsort(counts)][-1]
    #
    data_vector[np.where(cs_ids != largest)[0]] = np.nan
    data_map = np.reshape(data_vector, data_map.shape)
    #
    # re-interpolating for the nan regions
    logger.debug('\tpatching holes left by cluster removal')
    patch_holes(data_map)
    #
    return data_map