Source code for yam.util

# Copyright 2017-2023 Tom Eulenfeld, MIT license
"""Utility functions"""

from importlib import import_module
import logging
import numbers
import os
from pkg_resources import resource_filename
import shutil
import sys
import tempfile

import numpy as np
from obspy import UTCDateTime as UTC
import scipy.signal
import tqdm


log = logging.getLogger('yam.util')


class YamError(Exception):
    pass


class ParseError(YamError):
    pass


class ConfigError(YamError):
    pass

def _analyze_key(key):
    # reduce each underscore-separated part of the key to its first letter,
    # e.g. 'c1_s1d_t1' -> 'cst'
    if key.startswith('/'):
        key = key[1:]
    if '/' in key:
        key = key.split('/', 1)[0]
    return ''.join([k[0] for k in key.split('_')])


def _get_fname(io, key):
    # select the file belonging to the key: stretch, stack or corr file
    fname = (io['stretch'] if 't' in _analyze_key(key) else
             io['stack'] if 's' in _analyze_key(key) else
             io['corr'])
    return fname


def _corr_id(trace):
    # build the correlation id from the stats of the two correlated channels
    st = trace.stats
    try:
        id_ = (st.network1, st.station1, st.location1, st.channel1,
               st.network2, st.station2, st.location2, st.channel2)
    except AttributeError:
        log.error('metadata error in read file, trace: %s, stats: %s',
                  trace, st)
        return
    return '%s.%s.%s.%s-%s.%s.%s.%s' % id_


def _filter(stream, filter):
    # filter[0] and filter[1] are the lower and upper corner frequencies,
    # one of them may be None
    if filter[0] is None:
        stream.filter("lowpass", freq=filter[1])
    elif filter[1] is None:
        stream.filter("highpass", freq=filter[0])
    else:
        stream.filter("bandpass", freqmin=filter[0], freqmax=filter[1])


def _load_func(modulename, funcname):
    """Load and return function from Python module"""
    sys.path.append(os.path.curdir)
    module = import_module(modulename)
    sys.path.pop(-1)
    func = getattr(module, funcname)
    return func


def _seedid2meta(seedid):
    net, sta, loc, cha = seedid.split('.')
    smeta = {'network': net, 'station': sta, 'location': loc,
             'channel': cha}
    return smeta


def _time2sec(time):
    """Convert string (e.g. 1d or 0.5h) to seconds"""
    if not isinstance(time, numbers.Number):
        time, unit = float(time[:-1]), time[-1]
        assert unit in 'dh'
        time *= 24 * 3600 if unit == 'd' else 3600
    return time


def _trim(tr, time_interval):
    """Trim a trace to the given time interval relative to its mid sample"""
    starttime = tr.stats.starttime
    mid = starttime + (tr.stats.endtime - starttime) / 2
    if time_interval is not None:
        start, end = time_interval
        if start is not None:
            start = mid + start
        if end is not None:
            end = mid + end
        tr.trim(start, end)
    times = tr.times(reftime=mid)
    # restore the original starttime changed by trimming
    tr.stats.starttime = starttime
    return times


def _trim_time_period(stream, time_period):
    """Restrict traces of stream to given time period"""
    if time_period is None:
        return
    starttime, endtime = time_period
    traces = [tr for tr in stream if
              (starttime is None or tr.stats.starttime >= UTC(starttime)) and
              (endtime is None or tr.stats.starttime < UTC(endtime))]
    stream.traces = traces
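
# Illustrative doctest-style examples for the helpers above (not part of the
# original module); the keys and file names are hypothetical:
# >>> _analyze_key('c1_s1d_t1')   # first letter of each key part
# 'cst'
# >>> _get_fname({'corr': 'c.h5', 'stack': 's.h5', 'stretch': 't.h5'},
# ...            'c1_s1d')        # 's' in 'cs' -> stack file
# 's.h5'
# >>> _time2sec('0.5h')
# 1800.0
# >>> _time2sec('2d')
# 172800.0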

def create_config(conf='conf.json', tutorial=False, less_data=False):
    """Create JSON config file and download tutorial data if requested"""
    shutil.copyfile(resource_filename('yam', 'conf_example.json'), conf)
    temp_dir = os.path.join(tempfile.gettempdir(), 'yam_example_data')
    template = os.path.join(temp_dir, 'example_data')
    station_template = os.path.join(temp_dir, 'example_inventory')
    try:
        num_files = (len(os.listdir(template)),
                     len(os.listdir(station_template)))
    except FileNotFoundError:
        num_files = (0, 0)
    if tutorial and (num_files[0] < (9 if less_data else 54) or
                     num_files[1] < 3):
        print('Download example data from Geofon')
        from obspy import UTCDateTime as UTC
        from obspy.clients.fdsn.mass_downloader import (
            GlobalDomain, Restrictions, MassDownloader)
        domain = GlobalDomain()
        restrictions = Restrictions(
            starttime=UTC('2010-02-04' if less_data else '2010-02-01'),
            endtime=UTC('2010-02-06' if less_data else '2010-02-15'),
            network='CX', station='PATCX', location=None,
            channel_priorities=["BH[ZN]"], chunklength_in_sec=86400,
            reject_channels_with_gaps=False, minimum_length=0.5)
        mdl = MassDownloader(providers=['GFZ'])
        kw = dict(threads_per_client=1, download_chunk_size_in_mb=200)
        mdl.download(domain, restrictions, template, station_template, **kw)
        restrictions.station = 'PB06'
        if not less_data:
            restrictions.endtime = UTC('2010-02-12')
        mdl.download(domain, restrictions, template, station_template, **kw)
        restrictions.station = 'PB01'
        restrictions.endtime = UTC('2010-02-04 08:00:00')
        restrictions.channel_priorities = ["BHZ"]
        mdl.download(domain, restrictions, template, station_template, **kw)
        if not less_data:
            restrictions.starttime = UTC('2010-02-08 00:00:00')
            restrictions.endtime = UTC('2010-02-09 23:55:00')
            restrictions.channel_priorities = ["BHZ"]
            mdl.download(domain, restrictions, template, station_template,
                         **kw)
    if tutorial:
        dest_dir = os.path.dirname(conf)
        dest_dir_data = os.path.join(dest_dir, 'example_data')
        dest_dir_inv = os.path.join(dest_dir, 'example_inventory')
        if not os.path.exists(dest_dir_data):
            if less_data:
                ignore = shutil.ignore_patterns(
                    '*2010020[123]T000000Z__*', '*2010020[6-9]T000000Z__*',
                    '*2010021?T000000Z__*')
            else:
                ignore = None
            shutil.copytree(template, dest_dir_data, ignore=ignore)
        if not os.path.exists(dest_dir_inv):
            shutil.copytree(station_template, dest_dir_inv)
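
# A minimal usage sketch (not part of the original module): write the example
# configuration to the current directory and fetch the smaller tutorial
# dataset. Note that the first call triggers a download.
# >>> create_config('conf.json', tutorial=True, less_data=True)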

# def _stream2matrix(stream):
#     """
#     Return array with data in time window (start, end) around relative.
#
#     'time' can stand for UTCDateTime, list of UTCDateTimes, header entry out
#     of ('ponset', 'sonset', 'startime', 'endtime') or 'middle'
#     :param stream: Stream object with data
#     :param trim: 2 time or float (seconds) relative to param=relative
#     :param relative: time, is needed if start or end in seconds (float)
#     :return: np.array of shape (N_stream, N_data)
#     """
#     if len(stream) == 0:
#         raise ValueError('Stream has length 0')
#     samp = [tr.stats.sampling_rate for tr in stream]
#     npts = [len(tr) for tr in stream]
#     if min(samp) != max(samp):
#         for tr in stream:
#             tr.decimate(int(tr.stats.sampling_rate) // min(samp))
#         log.warning('Downsampling stream because of differing '
#                     'sampling rate.')
#     if min(npts) != max(npts):
#         log.warning('Traces in stream have different NPTS. '
#                     'Difference: %d samples' % (max(npts) - min(npts)))
#     data = np.zeros((len(stream), max(npts)))
#     for i, trace in enumerate(stream):
#         data[i, :len(trace.data)] = trace.data
#     return data

def smooth(x, window_len=None, window='flat', method='zeros'):
    """
    Smooth the data using a window with requested size.

    This method is based on the convolution of a scaled window with the
    signal.

    :param x: the input signal (numpy array)
    :param window_len: the dimension of the smoothing window; should be an
        odd integer
    :param window: the type of window from 'flat', 'hanning', 'hamming',
        'bartlett', 'blackman'; a flat window will produce a moving average
        smoothing
    :param method: handling of border effects\n
        'zeros': zero padding on both ends (len(smooth(x)) = len(x))\n
        'reflect': pad reflected signal on both ends (same)\n
        'clip': pad signal on both ends with the last valid value (same)\n
        None: no handling of border effects
        (len(smooth(x)) = len(x) - window_len + 1)
    """
    if window_len is None:
        return x
    if x.ndim != 1:
        raise ValueError("smooth only accepts 1 dimension arrays.")
    if x.size < window_len:
        raise ValueError("Input vector needs to be bigger than window size.")
    if window_len < 3:
        return x
    if window not in ['flat', 'hanning', 'hamming', 'bartlett', 'blackman']:
        raise ValueError("Window is one of 'flat', 'hanning', 'hamming', "
                         "'bartlett', 'blackman'")
    if method == 'zeros':
        s = np.r_[np.zeros((window_len - 1) // 2), x,
                  np.zeros(window_len // 2)]
    elif method == 'reflect':
        s = np.r_[x[(window_len - 1) // 2:0:-1], x,
                  x[-1:-(window_len + 1) // 2:-1]]
    elif method == 'clip':
        s = np.r_[x[0] * np.ones((window_len - 1) // 2), x,
                  x[-1] * np.ones(window_len // 2)]
    else:
        s = x
    if window == 'flat':  # moving average
        w = np.ones(window_len, 'd')
    else:
        w = getattr(np, window)(window_len)
    return scipy.signal.fftconvolve(w / w.sum(), s, mode='valid')
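
# Illustrative example (not part of the original module): a 3-point moving
# average; method='zeros' pads with zeros, so the output keeps the input
# length and the edge values are damped.
# >>> x = np.array([0., 1., 2., 3., 4.])
# >>> np.round(smooth(x, 3), 3)
# array([0.333, 1.   , 2.   , 3.   , 2.333])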

class IterTime():
    """Iterator yielding UTCDateTime objects between start- and endtime"""

    def __init__(self, startdate, enddate, dt=24 * 3600):
        self.startdate = startdate
        self.enddate = enddate
        self.dt = dt

    def __len__(self):
        # + 1 because __iter__ yields both start- and endtime
        return int((self.enddate - self.startdate) / self.dt) + 1

    def __iter__(self):
        t = self.startdate
        while t <= self.enddate:
            yield t
            t += self.dt
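
# Illustrative example (not part of the original module): iterate in daily
# steps over a two-day interval, both endpoints included.
# >>> times = list(IterTime(UTC('2010-02-01'), UTC('2010-02-03')))
# >>> len(times)
# 3
# >>> times[0], times[-1]
# (UTCDateTime(2010, 2, 1, 0, 0), UTCDateTime(2010, 2, 3, 0, 0))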

# https://stackoverflow.com/a/38739634
# not working yet for parallel processing
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that writes messages via tqdm.write"""

    def __init__(self, level=logging.NOTSET):
        super().__init__(level)

    def emit(self, record):
        try:
            msg = self.format(record)
            tqdm.tqdm.write(msg)
            self.flush()
        except Exception:
            self.handleError(record)
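
# A minimal sketch (not part of the original module): attach the handler so
# that log messages are routed through tqdm.write and do not garble active
# progress bars.
# >>> logger = logging.getLogger('yam')
# >>> logger.addHandler(TqdmLoggingHandler())
# >>> logger.warning('this message will not break a running tqdm bar')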

LOGLEVELS = {0: 'CRITICAL', 1: 'WARNING', 2: 'INFO', 3: 'DEBUG'}

LOGGING_DEFAULT_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'capture_warnings': True,
    'formatters': {
        'file': {
            'format': ('%(asctime)s %(module)-10s%(process)-6d%(levelname)-8s'
                       '%(message)s')
        },
        'console': {
            'format': '%(levelname)-8s%(message)s'
        },
    },
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'console',
            'level': None,
        },
        'console_tqdm': {
            'class': 'yam.util.TqdmLoggingHandler',
            'formatter': 'console',
            'level': None,
        },
        'file': {
            'class': 'logging.FileHandler',
            'formatter': 'file',
            'level': None,
            'filename': None,
        },
    },
    'loggers': {
        'yam': {
            'handlers': ['console_tqdm', 'file'],
            'level': 'DEBUG',
            'propagate': False,
        },
        'py.warnings': {
            'handlers': ['console_tqdm', 'file'],
            'level': 'DEBUG',
            'propagate': False,
        }
    }
}
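
# A minimal sketch (not part of the original module) of how this template
# might be activated: dictConfig rejects None levels, so the placeholders
# have to be filled first. The chosen levels and the file name 'yam.log'
# are hypothetical.
# >>> import copy
# >>> import logging.config
# >>> config = copy.deepcopy(LOGGING_DEFAULT_CONFIG)
# >>> config['handlers']['console']['level'] = LOGLEVELS[1]       # WARNING
# >>> config['handlers']['console_tqdm']['level'] = LOGLEVELS[2]  # INFO
# >>> config['handlers']['file']['level'] = LOGLEVELS[3]          # DEBUG
# >>> config['handlers']['file']['filename'] = 'yam.log'
# >>> logging.config.dictConfig(config)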