Source code for trident.plotting

"""
Spectrum plotting functions.

"""

#-----------------------------------------------------------------------------
# Copyright (c) 2016, Trident Development Team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
#-----------------------------------------------------------------------------

from yt.funcs import \
    mylog
import numpy as np
import matplotlib.figure
from matplotlib.backends.backend_agg import \
    FigureCanvasAgg

_xlabels = {'angstrom': 'Wavelength [$\\rm\\AA$]',
            'km/s': 'Velocity Offset [km/s]'}

[docs]def plot_spectrum(wavelength, flux, filename="spectrum.png", lambda_limits=None, flux_limits=None, title=None, label=None, figsize=None, step=False, stagger=0.2, features=None, axis_labels=None): """ Plot a spectrum or a collection of spectra and save to disk. This function wraps some Matplotlib plotting functionality for plotting spectra generated with the :class:`~trident.SpectrumGenerator`. In its simplest form, it accepts a wavelength array consisting of wavelength values and a corresponding flux array consisting of relative flux values, and it plots them and saves to disk. In addition, it can plot several spectra on the same axes simultaneously by passing a list of arrays to the ``wavelength``, ``flux`` arguments (and optionally to the ``label`` and ``step`` keywords). Returns the Matplotlib Figure object for further processing. **Parameters** :wavelength: array of floats or list of arrays of floats Wavelength values in angstroms. Either as an array of floats in the case of plotting a single spectrum, or as a list of arrays of floats in the case of plotting several spectra on the same axes. :flux: array of floats or list of arrays of floats Relative flux values (from 0 to 1) corresponding to wavelength array. Either as an array of floats in the case of plotting a single spectrum, or as a list of arrays of floats in the case of plotting several spectra on the same axes. :filename: string, optional Output filename of the plotted spectrum. Will be a png file. Default: 'spectrum.png' :lambda_limits: tuple or list of floats, optional The minimum and maximum of the wavelength range (x-axis) for the plot in angstroms. If specified as None, will use whole lambda range of spectrum. Example: (1200, 1400) for 1200-1400 Angstroms Default: None :flux_limits: tuple or list of floats, optional The minimum and maximum of the flux range (y-axis) for the plot. If specified as None, limits are automatically from [0, 1.1*max(flux)]. Example: (0, 1) for normal flux range before postprocessing. Default: None :step: boolean or list of booleans, optional Plot the spectrum as a series of step functions. Appropriate for plotting processed and noisy data. Use a list of booleans when plotting multiple spectra, where each boolean corresponds to the entry in the ``wavelength`` and ``flux`` lists. :title: string, optional Optional title for plot Default: None :label: string or list of strings, optional Label for each spectrum to be plotted. Useful if plotting multiple spectra simultaneously. Will automatically trigger a legend to be generated. Default: None :stagger: float, optional If plotting multiple spectra on the same axes, do we offset them in the y direction? If set to None, no. If set to a float, stagger them by the flux value specified by this parameter. :features: dict, optional Include vertical lines with labels to represent certain spectral features. Each entry in the dictionary consists of a key string to be overplot and the value float as to where in wavelength space it will be plot as a vertical line with the corresponding label. Example: features={'Ly a' : 1216, 'Ly b' : 1026} Default: None :axis_labels: tuple of strings, optional Optionally set the axis labels directly. If set to None, defaults to ('Wavelength [$\\rm\\AA$]', 'Relative Flux'). Default: None **Returns** Matplotlib Figure object for further processing **Example** Plot a flat spectrum >>> import numpy as np >>> import trident >>> wavelength = np.arange(1200, 1400) >>> flux = np.ones(len(wavelength)) >>> trident.plot_spectrum(wavelength, flux) Generate a one-zone ray, create a Lyman alpha spectrum from it, and add gaussian noise to it. Plot both the raw spectrum and the noisy spectrum on top of each other. >>> import trident >>> ray = trident.make_onezone_ray(column_densities={'H_p0_number_density':1e21}) >>> sg_final = trident.SpectrumGenerator(lambda_min=1200, lambda_max=1300, dlambda=0.5) >>> sg_final.make_spectrum(ray, lines=['Ly a']) >>> sg_final.save_spectrum('spec_raw.h5') >>> sg_final.add_gaussian_noise(10) >>> sg_raw = trident.load_spectrum('spec_raw.h5') >>> trident.plot_spectrum([sg_raw.lambda_field, sg_final.lambda_field], ... [sg_raw.flux_field, sg_final.flux_field], stagger=0, step=[False, True], ... label=['Raw', 'Noisy'], filename='raw_and_noise.png') """ # number of rows and columns n_rows = 1 n_columns = 1 # blank space between edge of figure and active plot area top_buffer = 0.07 bottom_buffer = 0.15 left_buffer = 0.06 right_buffer = 0.03 # blank space between plots hor_buffer = 0.05 vert_buffer = 0.05 # calculate the height and width of each panel panel_width = ((1.0 - left_buffer - right_buffer - ((n_columns-1)*hor_buffer)) / n_columns) panel_height = ((1.0 - top_buffer - bottom_buffer - ((n_rows-1)*vert_buffer)) / n_rows) # create a figure (figsize is in inches) if figsize is None: figsize = (12, 4) figure = matplotlib.figure.Figure(figsize=figsize, frameon=True) # get the row and column number my_row = 0 my_column = 0 # calculate the position of the bottom, left corner of this plot left_side = left_buffer + (my_column * panel_width) + \ my_column * hor_buffer top_side = 1.0 - (top_buffer + (my_row * panel_height) + \ my_row * vert_buffer) bottom_side = top_side - panel_height # create an axes object on which we will make the plot my_axes = figure.add_axes((left_side, bottom_side, panel_width, panel_height)) # Are we overplotting several spectra? or just one? if isinstance(flux, list): fluxs = flux else: fluxs = [flux] if isinstance(wavelength, list): wavelengths = wavelength else: wavelengths = [wavelength]*len(fluxs) if isinstance(step, list): steps = step else: steps = [step]*len(fluxs) if isinstance(label, list): labels = label else: labels = [label]*len(fluxs) # A running maximum of flux for use in ylim scaling in final plot max_flux = 0. if isinstance(wavelength, list): key = wavelength[0] else: key = wavelength xlabel = _xlabels.get(str(key.units)) for i, (wavelength, flux) in enumerate(zip(wavelengths, fluxs)): # Do we stagger the fluxes? if stagger is not None: flux -= stagger * i # Do we include labels and a legend? if steps[i]: my_axes.step(wavelength, flux, label=labels[i]) else: my_axes.plot(wavelength, flux, label=labels[i]) new_max_flux = np.max(flux) if new_max_flux > max_flux: max_flux = new_max_flux # Return the fluxes to their normal values # if they've been staggered if stagger is not None: flux += stagger * i # Do we include a title? if title is not None: my_axes.set_title(title) if lambda_limits is None: lambda_limits = (wavelength.min(), wavelength.max()) my_axes.set_xlim(lambda_limits[0], lambda_limits[1]) if flux_limits is None: flux_limits = (0, 1.1*max_flux) my_axes.set_ylim(flux_limits[0], flux_limits[1]) if axis_labels is None: axis_labels = (xlabel, 'Relative Flux') my_axes.xaxis.set_label_text(axis_labels[0]) my_axes.yaxis.set_label_text(axis_labels[1]) # Don't let the x-axis switch to offset values for tick labels my_axes.get_xaxis().get_major_formatter().set_useOffset(False) if label is not None: my_axes.legend() # Overplot the relevant features on the plot if features is not None: for feature in features: label = feature wavelength = features[feature] # Draw line my_axes.plot([wavelength, wavelength], flux_limits, '--', color='k') # Write text text_location = flux_limits[1] - 0.05*(flux_limits[1] - flux_limits[0]) my_axes.text(wavelength, text_location, label, horizontalalignment='right', verticalalignment='top', rotation='vertical') mylog.info("Writing spectrum plot to png file: %s" % filename) canvas = FigureCanvasAgg(figure) canvas.print_figure(filename) return figure