# Copyright 2017-2023 Tom Eulenfeld, MIT license
"""Preprocessing and correlation"""
import yam.stack
from yam.util import _filter, IterTime, smooth as smooth_func, _time2sec
from obspy.signal.cross_correlation import correlate as obscorr
from functools import partial
import itertools
import logging
import multiprocessing
import os
import numpy as np
import obspy
from obspy.core import Stream
from obspy.geodetics import gps2dist_azimuth
import scipy.signal
from scipy.signal import freqz, iirfilter, hilbert
_USE_FFTWS = os.environ.get('YAM_FFTW', 'no').lower() not in ('no', 'false')
if _USE_FFTWS:
try:
import pyfftw
pyfftw.interfaces.cache.enable()
pyfftw.interfaces.cache.set_keepalive_time(100)
except ImportError:
_USE_FFTWS = False
if _USE_FFTWS:
scipy.fftpack = pyfftw.interfaces.scipy_fftpack
np.fft = pyfftw.interfaces.numpy_fft
from pyfftw.interfaces.scipy_fftpack import (fft, ifft, fftshift,
ifftshift, next_fast_len)
from pyfftw.interfaces.numpy_fft import rfft, irfft, rfftfreq
else:
from numpy.fft import rfft, irfft, rfftfreq
from scipy.fftpack import fft, ifft, fftshift, ifftshift, next_fast_len
log = logging.getLogger('yam.correlate')
def start_parallel_jobs_inner_loop(tasks, do_work, njobs=0):
if njobs == 0:
results = [do_work(task) for task in tasks]
else:
pool = multiprocessing.Pool(njobs)
results = pool.map(do_work, tasks)
pool.close()
pool.join()
return results
def _fill_array(data, mask=None, fill_value=None):
"""
Mask numpy array and/or fill array value without demasking.
Additionally set fill_value to value.
If data is not a MaskedArray and mask is None returns silently data.
:param mask: apply mask to array
:param fill_value: fill value
"""
if mask is not None and mask is not False:
data = np.ma.MaskedArray(data, mask=mask, copy=False)
if np.ma.is_masked(data) and fill_value is not None:
data._data[data.mask] = fill_value
np.ma.set_fill_value(data, fill_value)
# elif not np.ma.is_masked(data):
# data = np.ma.filled(data)
return data
[docs]def time_norm(tr, method,
clip_factor=None, clip_set_zero=None,
clip_value=2, clip_std=True, clip_mode='clip',
mute_parts=48, mute_factor=2, plugin=None,
plugin_options={}):
"""
Calculate normalized data, see e.g. Bensen et al. (2007)
:param tr: Trace to manipulate
:param str method:
1bit: reduce data to +1 if >0 and -1 if <0\n
clip: clip data to value or multiple of root mean square (rms)\n
mute_envelope: calculate envelope and set data to zero where envelope
is larger than specified
plugin: use own function
:param mask_zeros: mask values that are set to zero, they will stay zero
in the further processing
:param float clip_value: value for clipping or list of lower and upper
value
:param bool clip_std: Multiply clip_value with rms of data
:param bool clip_mode:
'clip': clip data
'zero': set clipped data to zero
'mask': set clipped data to zero and mask it
:param int mute_parts: mean of the envelope is calculated by dividing the
envelope into several parts, the mean calculated in each part and
the median of this averages defines the mean envelope
:param float mute_factor: mean of envelope multiplied by this
factor defines the level for muting
:param str plugin: function in the form module:func
:param dict plugin_options: kwargs passed to plugin
:return: normalized data
"""
data = tr.data
data = _fill_array(data, fill_value=0)
mask = np.ma.getmask(data)
if method == '1bit':
np.sign(data, out=data)
elif method == 'clip':
if clip_factor is not None:
from warnings import warn
msg = 'clip_factor is deprecated, use clip_value instead'
warn(msg, DeprecationWarning)
clip_value = clip_factor
if clip_set_zero is not None:
from warnings import warn
msg = 'clip_set_zero is deprecated, use clip_mode instead'
warn(msg, DeprecationWarning)
clip_mode = 'zero' if clip_set_zero else 'clip'
from collections.abc import Iterable
if not isinstance(clip_value, Iterable):
clip_value = [-clip_value, clip_value]
if clip_std:
std = np.std(data)
clip_value = [clip_value[0] * std, clip_value[1] * std]
if clip_mode == 'clip':
np.clip(data, *clip_value, out=data)
else:
cmask = np.logical_or(data < clip_value[0], data > clip_value[1])
if clip_mode == 'mask':
mask = np.logical_or(np.ma.getmaskarray(data), cmask)
elif clip_mode == 'zero':
data[cmask] = 0
else:
raise ValueError('clip_mode must be one of clip, zeros, mask')
elif method == 'mute_envelope':
N = next_fast_len(len(data))
envelope = np.abs(hilbert(data, N))[:len(data)]
levels = [np.mean(d) for d in np.array_split(envelope, mute_parts)]
level = mute_factor * np.median(levels)
data[envelope > level] = 0
elif method == 'plugin':
from yam.util import _load_func
modulename, funcname = plugin.split(':')
func = _load_func(modulename.strip(), funcname.strip())
func(tr, **plugin_options)
data = tr.data
else:
msg = 'The method passed to time_norm is not known: %s.' % method
raise ValueError(msg)
tr.data = _fill_array(data, mask=mask, fill_value=0)
return tr
# http://azitech.wordpress.com/
# 2011/03/15/designing-a-butterworth-low-pass-filter-with-scipy/
def _filter_resp(freqmin, freqmax, corners=2, zerophase=False, sr=None,
N=None, whole=False):
"""
Complex frequency response of Butterworth-Bandpass Filter.
:param freqmin: Pass band low corner frequency.
:param freqmax: Pass band high corner frequency.
:param corners: Filter corners
:param zerophase: If True, apply filter once forwards and once backwards.
This results in twice the number of corners but zero phase shift in
the resulting filtered trace.
:param sr: Sampling rate in Hz.
:param N,whole: passed to scipy.signal.freqz
:return: frequencies and complex response
"""
df = sr
fe = 0.5 * df
low = freqmin / fe
high = freqmax / fe
# raise for some bad scenarios
if high > 1:
high = 1.0
msg = "Selected high corner frequency is above Nyquist. " + \
"Setting Nyquist as high corner."
log.warning(msg)
if low > 1:
msg = "Selected low corner frequency is above Nyquist."
raise ValueError(msg)
[b, a] = iirfilter(corners, [low, high], btype='band',
ftype='butter', output='ba')
freqs, values = freqz(b, a, N, whole=whole)
if zerophase:
values *= np.conjugate(values)
return freqs, values
[docs]def spectral_whitening(tr, smooth=None, filter=None,
waterlevel=1e-8, mask_again=True):
"""
Apply spectral whitening to data
Data is divided by its smoothed (Default: None) amplitude spectrum.
:param tr: trace to manipulate
:param smooth: length of smoothing window in Hz
(default None -> no smoothing)
:param filter: filter spectrum with bandpass after whitening
(tuple with min and max frequency)
:param waterlevel: waterlevel relative to mean of spectrum
:param mask_again: weather to mask array after this operation again and
set the corresponding data to 0
:return: whitened data
"""
sr = tr.stats.sampling_rate
data = tr.data
data = _fill_array(data, fill_value=0)
mask = np.ma.getmask(data)
nfft = next_fast_len(len(data))
spec = fft(data, nfft)
spec_ampl = np.abs(spec)
spec_ampl /= np.max(spec_ampl)
if smooth:
smooth = int(smooth * nfft / sr)
spec_ampl = ifftshift(smooth_func(fftshift(spec_ampl), smooth))
# save guard against division by 0
spec_ampl[spec_ampl < waterlevel] = waterlevel
spec /= spec_ampl
if filter is not None:
spec *= _filter_resp(*filter, sr=sr, N=len(spec), whole=True)[1]
ret = np.real(ifft(spec, nfft)[:len(data)])
if mask_again:
ret = _fill_array(ret, mask=mask, fill_value=0)
tr.data = ret
return tr
def __get_stations(inventory):
channels = inventory.get_contents()['channels']
stations = sorted({ch[:-1] + '?': ch[-1] for ch in channels})
return stations
def _iter_station_meta(inventory, components):
"""
Return iterator yielding metadata per station and day.
:param inventory: |Inventory| object with station and channel information
:param components: components to yield
"""
stations = __get_stations(inventory)
for seedid in stations:
for comp in components:
net, sta, loc, cha = seedid.split('.')
cha = cha[:2] + comp
meta = {'network': net, 'station': sta, 'location': loc,
'channel': cha}
yield meta
[docs]def get_data(smeta, data, data_format, day, overlap=0, edge=0,
trim_and_merge=False):
"""Return data of one day
:param smeta: dictionary with station metadata
:param data: string with expression of data day files or
function that returns the data (aka get_waveforms)
:param data_format: format of data
:param day: day as |UTC| object
:param overlap: overlap to next day in seconds
:param edge: additional time span requested from day before and after
in seconds
:param trim_and_merge: weather data is trimmed to day boundaries and merged
"""
next_day = day + 24 * 3600
if not isinstance(data, str):
try:
stream = data(starttime=day - edge,
endtime=next_day + overlap + edge, **smeta)
except Exception as ex:
log.debug('no data for %s %s: %s', day, smeta, str(ex))
return
else:
fname = data.format(t=day, **smeta)
try:
stream = obspy.read(fname, data_format)
except Exception:
return
t1 = stream[0].stats.starttime
t2 = stream[-1].stats.endtime
if t1 - day < 60:
fname = data.format(t=day - 1, **smeta)
try:
stream += obspy.read(fname, data_format, starttime=day - edge)
except Exception:
pass
if next_day - t2 < 60:
endtime = next_day + overlap + edge
fname = data.format(t=next_day, **smeta)
try:
stream += obspy.read(fname, data_format, endtime=endtime)
except Exception:
pass
if trim_and_merge:
stream.merge(method=1, interpolation_samples=10)
stream.trim(day, next_day + overlap)
return stream
def _shift(trace, shift):
"""Shift trace by given time and correct starttime => interpolation"""
msg = ('interpolate trace %s with starttime %s to shift by %.6fs '
'(Fourier method)')
log.debug(msg, trace.id, trace.stats.starttime, shift)
nfft = next_fast_len(len(trace))
spec = rfft(trace.data, nfft)
freq = rfftfreq(nfft, trace.stats.delta)
spec *= np.exp(-2j * np.pi * freq * shift)
trace.data = irfft(spec, nfft)[:len(trace)]
trace.stats.starttime -= shift
return trace
def _downsample_and_shift(trace, target_sr=None, tolerance_shift=None,
**interpolate_options):
"""Downsample and align samples at "good" times by shifting"""
sr = trace.stats.sampling_rate
if target_sr is None:
target_sr = sr
dt = 1 / target_sr
shift = (1e-6 * trace.stats.starttime.microsecond) % dt
if shift > 0.5 * dt:
shift = shift - dt
if tolerance_shift is None:
tolerance_shift = np.finfo(float).eps
must_shift = abs(shift) > tolerance_shift
if not must_shift: # anyway correct starttime
trace.stats.starttime -= shift
if sr % target_sr == 0:
if sr != target_sr:
trace.decimate(int(sr // target_sr))
if must_shift:
_shift(trace, shift)
else:
# anti-aliasing filter
if sr / target_sr > 16:
msg = ('Automatic filter design is unstable for decimation'
' factors above 16. '
'Manual decimation is necessary.')
raise ArithmeticError(msg)
trace.filter('lowpass_cheby_2', freq=0.5 * target_sr,
maxorder=12)
if must_shift:
starttime = trace.stats.starttime - shift
if starttime < trace.stats.starttime:
starttime += dt
msg = ('interpolate trace %s with starttime %s to downsample and '
'shift by %.6fs (Stream.interpolate() method)')
log.debug(msg, trace.id, trace.stats.starttime, shift)
else:
starttime = None
trace.interpolate(target_sr, starttime=starttime,
**interpolate_options)
return trace
def _prep1(target_sr, tolerance_shift, interpolate_options,
remove_response, inventory, remove_response_options, demean, filter,
tr):
"""Helper function for parallel preprocessing"""
tr.data = tr.data.astype('float64')
_downsample_and_shift(tr, target_sr=target_sr,
tolerance_shift=tolerance_shift,
interpolate_options=interpolate_options)
if remove_response:
tr.remove_response(inventory, **remove_response_options)
if demean:
tr.detrend('demean')
if filter is not None:
_filter(tr, filter)
return tr
def _prep2(normalization, time_norm_options, spectral_whitening_options,
decimate,
tr):
"""Helper function for parallel preprocessing"""
tr.data = _fill_array(tr.data, fill_value=0)
for norm in normalization:
if norm == 'spectral_whitening':
spectral_whitening(tr, **spectral_whitening_options)
else:
time_norm(tr, norm, **time_norm_options)
if decimate:
mask = np.ma.getmask(tr.data)
tr.decimate(decimate, no_filter=True)
if mask is not np.ma.nomask:
tr.data = np.ma.MaskedArray(tr.data, mask[::decimate],
fill_value=0)
return tr
[docs]def preprocess(stream, day=None, inventory=None,
overlap=0,
remove_response=False,
remove_response_options=None,
demean=True,
filter=None,
normalization=(),
time_norm_options=None,
spectral_whitening_options=None,
downsample=None, tolerance_shift=None, interpolate_options=None,
decimate=None,
njobs=0):
"""
Preprocess stream of 1 day
:param stream: |Stream| object
:param day: |UTC| object of day (for trimming)
:param inventory: |Inventory| object (for response removal)
:param bool remove_response: remove response
:param filter: min and max frequency of bandpass filter
:param normalizaton: ordered list of normalizations to apply,
``'sprectal_whitening'`` for `spectral_whitening` and/or
one or several of the time normalizations listed in `time_norm`
:param downsample: downsample before preprocessing,
target sampling rate
:param tolerance_shift: Samples are aligned at "good" times for the target
sampling rate. Specify tolerance in seconds. (default: no tolerance)
:param decimate: decimate further by given factor after preprocessing
(see Trace.decimate)
:param njobs: number of parallel workers
:param \*_options: dictionary of options passed to the corresponding
functions
"""
if time_norm_options is None:
time_norm_options = {}
if spectral_whitening_options is None:
spectral_whitening_options = {}
spectral_whitening_options.setdefault('filter', filter)
if remove_response_options is None:
remove_response_options = {}
if interpolate_options is None:
interpolate_options = {}
if isinstance(normalization, str):
normalization = [normalization]
stream.merge(1, interpolation_samples=10)
stream.traces = stream.split().traces
# discard traces with less than 10 samples
stream.traces = [tr for tr in stream if len(tr) >= 10]
if downsample is None:
downsample = min(tr.stats.sampling_rate for tr in stream)
# call _prep1 on all traces, merge stream and call _prep2 on all traces
do_work = partial(_prep1, downsample, tolerance_shift, interpolate_options,
remove_response, inventory, remove_response_options,
demean, filter)
stream.traces = start_parallel_jobs_inner_loop(stream, do_work, njobs)
len1 = len(stream)
stream.merge()
if len(stream) < len1:
log.debug('detected gaps in data')
do_work = partial(_prep2, normalization, time_norm_options,
spectral_whitening_options, decimate)
stream.traces = start_parallel_jobs_inner_loop(stream, do_work, njobs)
if day is not None:
next_day = day + 24 * 3600
stream.trim(day, next_day + overlap)
stream.sort()
assert len({tr.stats.sampling_rate for tr in stream}) == 1
return stream
[docs]def correlate_traces(tr1, tr2, maxshift=3600, demean=True):
"""
Return trace of cross-correlation of two input traces
:param tr1,tr2: two |Trace| objects
:param maxsift: maximal shift in correlation in seconds
"""
n1, s1, l1, c1 = tr1.id.split('.')
n2, s2, l2, c2 = tr2.id.split('.')
sr = tr1.stats.sampling_rate
xdata = obscorr(tr1.data, tr2.data, int(round(maxshift * sr)),
demean=demean)
header = {'network': s1, 'station': c1, 'location': s2, 'channel': c2,
'network1': n1, 'station1': s1, 'location1': l1, 'channel1': c1,
'network2': n2, 'station2': s2, 'location2': l2, 'channel2': c2,
'starttime': tr1.stats.starttime,
'sampling_rate': sr,
}
return obspy.Trace(data=xdata, header=header)
def _make_same_length(tr1, tr2):
"""Guarantee that tr1 and tr2 have the same length.
Even if tr1 and tr2 have the same sampling rate and are trimmed with
the same times, they could differ in length up to one sample.
This is handled here.
"""
dlen = len(tr2) - len(tr1)
dt = tr1.stats.delta
if dlen == -1:
tr1, tr2 = tr2, tr1
if abs(dlen) == 1:
# tr2 is too long
if tr1.stats.starttime - tr2.stats.starttime > dt / 2:
tr2.data = tr2.data[1:]
tr2.stats.starttime += dt
else:
tr2.data = tr2.data[:-1]
elif abs(dlen) > 1:
msg = 'This should not happen ;), traces have different length'
raise ValueError(msg)
def _slide_and_correlate_traces(day, next_day, length, overlap, discard,
max_lag, outkey, demean_window,
task):
"""Helper function for parallel correlating"""
tr1, tr2, dist, azi, baz = task
sr = tr1.stats.sampling_rate
sr2 = tr2.stats.sampling_rate
if sr != sr2:
msg = 'Traces have different sampling rate (%s != %s)' % (sr, sr2)
raise ValueError(msg)
xstream = obspy.Stream()
for t1 in IterTime(day, next_day - length + overlap, dt=length - overlap):
sub = obspy.Stream([tr1, tr2]).slice(t1, t1 + length)
if len(sub) < 2:
continue
st = [tr.stats.starttime for tr in sub]
et = [tr.stats.endtime for tr in sub]
if max(st) > min(et): # this should not happen
continue
sub.trim(max(st), min(et))
_make_same_length(sub[0], sub[1])
avail = min((tr.data.count() if hasattr(tr.data, 'count')
else len(tr)) / sr / length for tr in sub)
if discard is not None and avail < discard:
msg = ('discard trace combination %s-%s for time %s '
'(availability %.1f%% < %.1f%% desired)')
log.debug(msg, sub[0].id, sub[1].id, str(max(st))[:19],
100 * avail, 100 * discard)
continue
for tr in sub:
_fill_array(tr.data, fill_value=0)
tr.data = np.ma.getdata(tr.data)
xtr = correlate_traces(sub[0], sub[1], max_lag, demean=demean_window)
xtr.stats.starttime = t1
xtr.stats.key = outkey
xtr.stats.dist = dist
xtr.stats.azi = azi
xtr.stats.baz = baz
xtr.stats.avail = avail
xstream += xtr
return xstream
def _midtime(stats):
return stats.starttime + 0.5 * (stats.endtime - stats.starttime)
[docs]def correlate(io, day, outkey,
edge=60,
length=3600, overlap=1800,
demean_window=True,
discard=None,
only_auto_correlation=False,
station_combinations=None,
component_combinations=('ZZ',),
max_lag=100,
keep_correlations=False,
stack='1d',
njobs=0,
**preprocessing_kwargs):
"""
Correlate data of one day
:param io: io config dictionary
:param day: |UTC| object with day
:param outkey: the output key for the HDF5 index
:param edge: additional time span requested from day before and after
in seconds
:param length: length of correlation in seconds (string possible)
:param overlap: length of overlap in seconds (string possible)
:param demean_window: demean each window individually before correlating
:param discard: discard correlations with less data coverage
(float from interval [0, 1])
:param only_auto_correlations: Only correlate stations with itself
(different components possible)
:param station_combinations: specify station combinations
(e.g. ``'CX.PATCX-CX.PB01``, network code can be
omitted, e.g. ``'PATCX-PB01'``, default: all)
:param component_combinations: component combinations to calculate,
tuple of strings with length two, e.g. ``('ZZ', 'ZN', 'RR')``,
if ``'R'`` or ``'T'`` is specified, components will be rotated after
preprocessing, default: only ZZ components
:param max_lag: max time lag in correlations in seconds
:param keep_correlatons: write correlations into HDF5 file (dafault: False)
:param stack: stack correlations and write stacks into HDF5 file
(default: ``'1d'``, must be smaller than one day or one day)
.. note::
If you want to stack larger time spans
use the separate stack command on correlations or stacked
correlations.
:param njobs: number of jobs used. Some tasks will run parallel
(preprocessing and correlation).
:param \*\*preprocessing_kwargs: all other kwargs are passed to
`preprocess`
"""
inventory = io['inventory']
length = _time2sec(length)
overlap = _time2sec(overlap)
if not keep_correlations and stack is None:
msg = ('keep_correlation is False and stack is None -> correlations '
' would not be saved')
raise ValueError(msg)
components = set(''.join(component_combinations))
if 'R' in components or 'T' in components:
load_components = components - {'R', 'T'} | {'N', 'E'}
else:
load_components = components
if station_combinations is not None:
load_stations = set(sta for comb in station_combinations
for sta in comb.split('-'))
else:
load_stations = None
# load data
stream = obspy.Stream()
for smeta in _iter_station_meta(inventory, load_components):
if (load_stations is not None and smeta['station'] not in load_stations
and '.'.join((smeta['network'], smeta['station']))
not in load_stations):
continue
stream2 = get_data(smeta, io['data'], io['data_format'], day,
overlap=overlap, edge=edge)
if stream2:
stream += stream2
if len(stream) == 0:
log.warning('empty stream for day %s', str(day)[:10])
return
preprocess(stream, day, inventory, overlap=overlap, njobs=njobs,
**preprocessing_kwargs)
# collect trace pairs for correlation
next_day = day + 24 * 3600
stations = sorted({tr.id[:-1] for tr in stream})
tasks = []
for station1, station2 in itertools.combinations_with_replacement(
stations, 2):
if only_auto_correlation and station1 != station2:
continue
if station_combinations and not any(set(station_comb.split('-')) == (
{station1.rsplit('.', 2)[0], station2.rsplit('.', 2)[0]}
if '.' in (station_comb) else
{station1.split('.')[1], station2.split('.')[1]})
for station_comb in station_combinations):
continue
stream1 = Stream([tr for tr in stream if tr.id[:-1] == station1])
stream2 = Stream([tr for tr in stream if tr.id[:-1] == station2])
datetime1 = _midtime(stream1[0].stats)
datetime2 = _midtime(stream2[0].stats)
msg = 'Cannot get coordinates for channel %s datetime %s'
try:
c1 = inventory.get_coordinates(stream1[0].id, datetime=datetime1)
except Exception as ex:
raise RuntimeError(msg % (stream1[0].id, datetime1)) from ex
try:
c2 = inventory.get_coordinates(stream2[0].id, datetime=datetime2)
except Exception as ex:
raise RuntimeError(msg % (stream2[0].id, datetime2)) from ex
args = (c1['latitude'], c1['longitude'],
c2['latitude'], c2['longitude'])
dist, azi, baz = gps2dist_azimuth(*args)
if ('R' in components or 'T' in components) and station1 != station2:
stream1 = stream1.copy()
stream1b = stream1.copy().rotate('NE->RT', azi)
stream1.extend(stream1b.select(component='R'))
stream1.extend(stream1b.select(component='T'))
stream2 = stream2.copy()
stream2b = stream2.copy().rotate('NE->RT', azi)
stream2.extend(stream2b.select(component='R'))
stream2.extend(stream2b.select(component='T'))
it_ = (itertools.product(stream1, stream2) if station1 != station2 else
itertools.combinations_with_replacement(stream1, 2))
for tr1, tr2 in it_:
comps = tr1.stats.channel[-1] + tr2.stats.channel[-1]
if component_combinations and (
comps not in component_combinations and
comps[::-1] not in component_combinations):
continue
tasks.append((tr1, tr2, dist, azi, baz))
# start correlation
do_work = partial(_slide_and_correlate_traces, day, next_day, length,
overlap, discard, max_lag, outkey,
demean_window)
streams = start_parallel_jobs_inner_loop(tasks, do_work, njobs)
xstream = Stream()
xstream.traces = [tr for s_ in streams for tr in s_]
if len(xstream) > 0:
res = {}
if keep_correlations:
res['corr'] = xstream
if stack:
res['stack'] = yam.stack.stack(xstream, stack)
return res