Utility examples

Matplotlib plots and HTML generators

Note

This example makes use of the third-party library matplotlib. More information on the matplotlib module can be found at matplotlib.org.

#
# This script can be used for any purpose without limitation subject to the
# conditions at https://www.ccdc.cam.ac.uk/Community/Pages/Licences/v2.aspx
#
# This permission notice and the following statement of attribution must be
# included in all copies or substantial portions of this script.
#
# 2015-06-17: created by the Cambridge Crystallographic Data Centre
#

"""
Utilities used in the other example scripts.
"""

############################################################################

from contextlib import contextmanager
import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plot
import numpy
import sys

############################################################################


class Plot(object):
    '''thin wrapper around a matplotlib plot'''

    def __init__(self, title='', xlabel='', ylabel='', file_name='scatter.png'):
        '''initialise the plot'''
        self.fig = plot.figure()
        self.axes = self.fig.add_subplot(1, 1, 1)
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.file_name = file_name

    def __del__(self):
        '''Ensure memory is freed.'''
        plot.close(self.fig)

    def _axes_property(name, label):
        '''define a property'''
        return property(
            lambda self: getattr(self.axes, 'get_' + name)(),
            lambda self, val: getattr(self.axes, 'set_' + name)(val),
            label
        )

    title = _axes_property('title', 'The title of the plot')
    xlabel = _axes_property('xlabel', 'The label for the X axis')
    ylabel = _axes_property('ylabel', 'The label for the Y axis')

    def write(self):
        self.fig.savefig(self.file_name)


class Scatterplot(Plot):
    '''A simple scatter plot.'''

    def add_plot(self, xs, ys, color='green'):
        '''Add some scattered points.

        :param xs: x coordinates of the points.
        :param ys: y coordinates of the points.
        :param color: colour passed to ``Axes.scatter``.
        '''
        self.axes.scatter(xs, ys, color=color)

    def annotate(self, x, y, text, color='red'):
        '''Mark the point (x, y) and attach a text label with an arrow.

        The point is drawn in ``color``; the label is placed horizontally
        offset from it by a tenth of the current x-range, towards the
        horizontal centre of the plot.

        :param x: x coordinate of the point.
        :param y: y coordinate of the point.
        :param text: the label text.
        :param color: colour of the marker and of the arrow.
        '''
        self.add_plot([x], [y], color=color)
        xmin, xmax = self.axes.get_xlim()
        xoff = (xmax - xmin) / 10.
        ha = 'left'
        if x > (xmin + xmax) / 2.:
            # Point is in the right half: put the label on its left,
            # right-aligned, instead.
            xoff = -xoff
            ha = 'right'
        self.axes.annotate(
            text, (x, y), (x + xoff, y),
            ha=ha, va='center',
            arrowprops=dict(
                arrowstyle='simple',
                facecolor=color
            )
        )


class Histogram(Plot):
    '''A simple histogram.'''

    def add_plot(self, data, color='green', bins=40):
        '''Histogram the data and add the bars to the plot.

        :param data: the values to histogram.
        :param color: bar colour.
        :param bins: anything :func:`numpy.histogram` accepts as ``bins``
            (a bin count or a sequence of bin edges).
        '''
        counts, edges = numpy.histogram(data, bins=bins)
        # Take the widths from the edges numpy actually used rather than
        # (max - min) / bins. That formula gives zero-width bars when all
        # values are equal (numpy then widens the range to a unit interval),
        # fails on empty data, and is wrong when ``bins`` is a sequence.
        # align='edge' puts each bar's left side on its bin's lower edge;
        # matplotlib >= 2.0 centres bars on x by default.
        self.axes.bar(edges[:-1], counts, numpy.diff(edges),
                      color=color, align='edge')

    def add_histogram(self, data, low, high, color='green'):
        '''Plot bin counts that have already been histogrammed.

        :param data: sequence of bin counts; the bins are taken to be equally
            wide and to cover ``[low, high]`` exactly.
        :param low: lower edge of the first bin.
        :param high: upper edge of the last bin.
        :param color: bar colour.
        '''
        nbins = len(data)
        self.axes.set_xlim(low, high)
        if nbins == 0:
            # Nothing to draw; avoid dividing by zero below.
            return
        width = (high - low) / float(nbins)
        xs = [low + i * width for i in range(nbins)]
        # Left-align each bar on its bin so the bars fill [low, high] exactly
        # instead of being shifted half a bin left (and clipped by xlim).
        self.axes.bar(xs, data, width, color=color, align='edge')

    def annotate(self, x, text, color='red', yoff=0.8):
        '''Draw a vertical line at ``x`` and label it with ``text``.

        :param x: x position of the line.
        :param text: the label text.
        :param color: colour of the line, the label and the arrow.
        :param yoff: height of the label as a fraction of the y-axis range
            (0 is the bottom, 1 the top).
        '''
        self.axes.add_line(plot.Line2D((x, x), self.axes.get_ylim(), color=color))
        xmin, xmax = self.axes.get_xlim()
        xoff = (xmax - xmin) / 10.
        ha = 'left'
        if x > (xmin + xmax) / 2.:
            # Line is in the right half: put the label on its left.
            xoff = -xoff
            ha = 'right'
        # Limits are read again here because adding the line may have
        # rescaled the axes. Interpolating within [ymin, ymax] keeps ``yoff``
        # a true fraction of the range even when the axis does not start at 0.
        ymin, ymax = self.axes.get_ylim()
        y = ymin + yoff * (ymax - ymin)
        self.axes.annotate(
            text, (x, y), (x + xoff, y), color=color,
            ha=ha, va='center',
            arrowprops=dict(
                arrowstyle='simple',
                facecolor=color
            )
        )


class PolarScatterplot(Plot):
    '''Simple polar scatter plot.'''

    def __init__(self, **kw):
        '''Initialise the plot, replacing the default axes with polar axes.

        Accepts the same keyword arguments as :class:`Plot`.
        '''
        Plot.__init__(self, **kw)
        # Plot.__init__ applied the title and axis labels to the cartesian
        # axes that are about to be discarded; carry them over to the polar
        # axes so they are not lost.
        title, xlabel, ylabel = self.title, self.xlabel, self.ylabel
        self.fig.delaxes(self.axes)
        self.axes = self.fig.add_subplot(1, 1, 1, polar=True)
        self.title, self.xlabel, self.ylabel = title, xlabel, ylabel

    def add_plot(self, theta, r, color='green'):
        '''Add points at angles ``theta`` (in degrees) and radii ``r``.

        :param theta: angles in degrees; converted to radians for matplotlib.
        :param r: radii, one per angle.
        :param color: marker colour.
        '''
        # Build a real list: matplotlib's scatter does not accept a generator
        # (recent versions raise, older ones silently drop the first point).
        self.axes.scatter([math.radians(t) for t in theta], r, c=color)


class Lineplot(Plot):
    '''Simple line plot.'''

    def __init__(self, **kw):
        '''Initialise the plot; accepts the same keyword arguments as :class:`Plot`.'''
        # Plot.__init__ already creates an ordinary cartesian subplot with the
        # title and labels applied, which is exactly what a line plot needs,
        # so the axes do not have to be replaced.
        Plot.__init__(self, **kw)

    def add_plot(self, xs, ys, **kw):
        '''Add a line plot of ``ys`` against ``xs``.

        Keywords may be any arguments accepted by
        :meth:`matplotlib.axes.Axes.plot`, i.e. :class:`matplotlib.lines.Line2D`
        properties such as ``color``, ``linestyle`` or ``label``.
        '''
        self.axes.plot(xs, ys, **kw)

############################################################################


############################################################################

class Curry(object):
    """Partially apply a callable.

    Leading positional arguments and default keyword arguments are captured
    at construction and supplied on every call.
    """

    def __init__(self, f, *args, **kw):
        """Remember the callable ``f`` plus the positional ``args`` and keywords ``kw`` to bind."""
        self.f, self.args, self.kw = f, args, kw

    def __call__(self, *args, **kw):
        """
        Call ``f`` with the bound positionals followed by ``args``, and the
        bound keywords overridden by ``kw``; return whatever ``f`` returns.
        """
        positional = self.args + args
        keywords = dict(self.kw, **kw)
        return self.f(*positional, **keywords)


@contextmanager
def element(tag, stream=None, **attribs):
    """Write an HTML element with attributes to the given stream.

    The opening tag is written on entry, with each keyword in ``attribs``
    rendered as ``key="value"`` and the value HTML-escaped. The stream is
    then yielded so the body can write the element's content, and the
    closing tag is written when the body exits normally.

    :param tag: element name, e.g. ``'td'``.
    :param stream: writable text stream. Defaults to ``sys.stdout`` as it is
        at call time, so ``contextlib.redirect_stdout`` is honoured.
    :param attribs: attributes of the element.
    """
    from html import escape
    if stream is None:
        stream = sys.stdout
    # Escape values so a quote, '<' or '&' cannot break the tag.
    attributes = ''.join(
        ' %s="%s"' % (k, escape(str(v), quote=True))
        for k, v in attribs.items()
    )
    stream.write('<%s%s>' % (tag, attributes))
    yield stream
    stream.write('</%s>' % tag)


# Shorthand context managers, one per HTML table tag: each binds ``element``
# to its tag name, so e.g. ``html_row(stream=s)`` writes ``<tr ...>...</tr>``.
html_table, html_row, html_header, html_datum, html_theader, html_tbody = (
    Curry(element, tag)
    for tag in ('table', 'tr', 'th', 'td', 'thead', 'tbody')
)


def write_html_table(headers, data, stream=None, **attribs):
    """Write an HTML table with attributes to the given stream.

    :param headers: column headings, written as ``<th>`` cells in a
        ``<thead>`` row.
    :param data: iterable of rows, each an iterable of cell values. Each
        value is converted with ``str`` and written as a ``<td>`` cell inside
        ``<tbody>``.
    :param stream: writable text stream. Defaults to ``sys.stdout`` as it is
        at call time, not at import time.
    :param attribs: attributes for the ``<table>`` element.

    Header and cell text is written verbatim, each followed by a newline.
    NOTE(review): the text is not HTML-escaped, so callers must not pass
    untrusted text containing markup characters.
    """
    if stream is None:
        stream = sys.stdout
    with html_table(stream=stream, **attribs) as s:
        with html_theader(stream=s):
            with html_row(stream=s):
                for h in headers:
                    with html_header(stream=s):
                        print(h, file=s)
        with html_tbody(stream=s):
            for row in data:
                with html_row(stream=s):
                    for d in row:
                        with html_datum(stream=s):
                            print(str(d), file=s)


def sanitise_file_name(string):
    """Return ``string`` with every non-alphanumeric character replaced by ``'_'``.

    This is useful for making sure a molecule identifier can be used in
    Windows file names. ``str.isalnum`` is Unicode-aware, so non-ASCII
    letters and digits are kept as they are.
    """
    def _keep_or_underscore(char):
        # Alphanumerics pass through; anything else becomes an underscore.
        return char if char.isalnum() else '_'

    return ''.join(map(_keep_or_underscore, string))


def output_file(file_name):
    """Open ``file_name`` for writing UTF-8 text suitable for the ``csv`` module.

    ``newline=''`` disables newline translation: the csv writer emits its own
    line terminators, and translating them again produces blank lines
    between rows on Windows.
    """
    return open(file_name, mode='w', newline='', encoding='utf-8')

def row_to_utf8(row):
    """Return the cells of ``row`` as a new list, unchanged.

    This once encoded unicode cells as UTF-8 bytes for the Python 2 csv
    module. Under Python 3 the csv module writes ``str`` directly (the
    encoding is set on the file object), so it is now a plain copy, kept for
    callers that still use it.
    """
    return list(row)

Timing Python code

This example shows some ways of recording timing information for Python code.

#!/usr/bin/env python
#
# This script can be used for any purpose without limitation subject to the
# conditions at https://www.ccdc.cam.ac.uk/Community/Pages/Licences/v2.aspx
#
# This permission notice and the following statement of attribution must be
# included in all copies or substantial portions of this script.
#
# 2017-08-24: created by the Cambridge Crystallographic Data Centre
#
'''
    timer.py    -   an example of using a timer.
    This will count all the phenyl rings in the CSD.
    For brevity, the first 10000 structures will be searched.  If you wish to know the full number, remove the lines annotated below.

    This produces output like:

  1000 (  0%)... 0:00:14 (expected 3:36:48) 1744 found so far
  2000 (  0%)... 0:00:29 (expected 3:40:51) 3462 found so far
  3000 (  0%)... 0:00:43 (expected 3:34:28) 4762 found so far
  4000 (  0%)... 0:00:57 (expected 3:32:42) 6231 found so far
  5000 (  0%)... 0:01:12 (expected 3:34:14) 7827 found so far
  6000 (  0%)... 0:01:33 (expected 3:49:21) 9937 found so far
  7000 (  0%)... 0:01:50 (expected 3:52:43) 11653 found so far
  8000 (  0%)... 0:02:04 (expected 3:50:00) 13323 found so far
  9000 (  1%)... 0:02:20 (expected 3:49:27) 15037 found so far
 10000 (  1%)... 0:02:40 (expected 3:55:53) 16918 found so far
There are 16918 phenyl rings in the CSD, averaging 1.69 per structure.
            All:   1: 0:02:40
         Search: 7653: 0:00:59
   Assign bonds: 7654: 0:00:44
Create molecule: 10000: 0:00:41
  Add hydrogens: 7654: 0:00:13

'''
###########################################################################

import time
from ccdc import search, io
from ccdc.utilities import Timer

# One Timer shared by the whole script: it times and counts each labelled
# section, and timer.report() prints the per-label summary at the end.
timer = Timer()


@timer.decorate('Search')
def find_phenyl_rings(mol):
    """Return the number of hits for the phenyl SMARTS ``c1ccccc1`` in ``mol``.

    The ``timer.decorate('Search')`` wrapper causes every call to be timed
    and counted under the label 'Search'.
    """
    # NOTE(review): the query and searcher are rebuilt on every call; they
    # could probably be built once outside the function.
    ring_query = search.SMARTSSubstructure('c1ccccc1')
    ring_searcher = search.SubstructureSearch()
    ring_searcher.add_substructure(ring_query)
    return len(ring_searcher.search(mol))

csd = io.EntryReader('CSD')
total = len(csd)
start = time.time()
ct = 0      # phenyl rings found so far
n_read = 0  # structures read, including those skipped below

# This and other 'with timer()' lines will cause the managed block to be timed and counted
with timer('All'):
    for i, e in enumerate(csd):
        if i and i % 1000 == 0:
            # This will cause a message about the elapsed time and predicted time to completion to be reported
            Timer.progress(start, i, total, '%d found so far' % ct)
        # Remove the following two lines to perform the full scan (takes around four hours)
        if i == 10000:
            break
        n_read += 1
        with timer('Create molecule'):
            mol = e.molecule
            if not mol.all_atoms_have_sites:
                continue
        with timer('Assign bonds'):
            try:
                mol.assign_bond_types()
                mol.standardise_aromatic_bonds()
            except RuntimeError:
                continue
        with timer('Add hydrogens'):
            try:
                mol.add_hydrogens()
            except RuntimeError:
                continue
        ct += find_phenyl_rings(mol)

# Average over the number of structures actually read. The last loop index
# is one too small after a full scan (no break) and is undefined if the
# reader yields nothing.
average = ct / float(n_read) if n_read else 0.0
print('There are %d phenyl rings in the CSD, averaging %.2f per structure.' % (ct, average))
# This will print a summary of the timings recorded
timer.report()

###########################################################################