Source code for nest.tests.test_threads

# -*- coding: utf-8 -*-
#
# test_threads.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

"""
Unit tests for multithreaded PyNEST
"""

import unittest
import nest


@nest.check_stack
class ThreadTestCase(unittest.TestCase):
    """Tests for multi-threading.

    Every test first calls ``nest_multithreaded`` and is skipped when it
    reports that NEST has no thread support.
    """

    def nest_multithreaded(self):
        """Return True, if we have a thread-enabled NEST, False otherwise.

        Evaluates ``statusdict/threading :: (no) eq not`` through
        ``nest.sr`` and returns whatever ``nest.spp`` yields afterwards.
        """
        nest.sr("statusdict/threading :: (no) eq not")
        return nest.spp()

    def test_Threads(self):
        """Multiple threads: 8 neurons on 8 threads report vp 0 through 7."""
        if not self.nest_multithreaded():
            self.skipTest("NEST was compiled without multi-threading")

        nest.ResetKernel()
        # Right after a reset the kernel must report a single local thread.
        self.assertEqual(nest.GetKernelStatus()['local_num_threads'], 1)

        nest.SetKernelStatus({'local_num_threads': 8})
        n = nest.Create('iaf_psc_alpha', 8)

        # The 'vp' values are compared sorted, so the order in which they
        # are reported does not matter.
        st = sorted(nest.GetStatus(n, 'vp'))
        self.assertEqual(st, [0, 1, 2, 3, 4, 5, 6, 7])

    def test_ThreadsGetConnections(self):
        """GetConnections with threads returns every target exactly once."""
        if not self.nest_multithreaded():
            self.skipTest("NEST was compiled without multi-threading")

        nest.ResetKernel()
        nest.SetKernelStatus({'local_num_threads': 8})

        pre = nest.Create("iaf_psc_alpha")
        post = nest.Create("iaf_psc_alpha", 6)
        nest.Connect(pre, post)

        conn = nest.GetConnections(pre)

        # Because of threading, targets may be in a different order than
        # in post, so we sort the vector.
        targets = sorted(nest.GetStatus(conn, "target"))
        self.assertEqual(targets, list(post))

    def test_ThreadsGetEvents(self):
        """Gathering events across threads.

        Runs the same simulation with 1, 2, 4 and 8 local threads and
        checks that the number of events recorded by the spike detector
        and by the voltmeter does not depend on the thread count.
        """
        if not self.nest_multithreaded():
            self.skipTest("NEST was compiled without multi-threading")

        threads = (1, 2, 4, 8)

        n_events_sd = []
        n_events_vm = []

        N = 128
        Simtime = 1000.

        for t in threads:
            nest.ResetKernel()
            nest.SetKernelStatus({'local_num_threads': t})

            # force a lot of spike events
            n = nest.Create('iaf_psc_alpha', N, {'I_e': 2000.})
            sd = nest.Create('spike_detector')
            vm = nest.Create('voltmeter')

            nest.Connect(n, sd)
            nest.Connect(vm, n)

            nest.Simulate(Simtime)

            n_events_sd.append(nest.GetStatus(sd, 'n_events')[0])
            n_events_vm.append(nest.GetStatus(vm, 'n_events')[0])

        # Expected voltmeter count is fixed; the spike count of the first
        # (single-threaded) run serves as the reference for the others.
        ref_vm = N * (Simtime - 1)
        ref_sd = n_events_sd[0]

        # Plain loops rather than numpy helpers: we don't want to be
        # dependent on numpy et al. The thread count is included in the
        # message so a failure identifies the offending configuration.
        for t, n_vm in zip(threads, n_events_vm):
            self.assertEqual(
                n_vm, ref_vm,
                "voltmeter recorded %s events with %d thread(s), "
                "expected %s" % (n_vm, t, ref_vm))

        for t, n_sd in zip(threads, n_events_sd):
            self.assertEqual(
                n_sd, ref_sd,
                "spike detector recorded %s events with %d thread(s), "
                "expected %s" % (n_sd, t, ref_sd))
def suite():
    """Return a test suite holding all ``test*`` methods of ThreadTestCase.

    The suite is built with ``unittest.TestLoader`` rather than
    ``unittest.makeSuite``: the latter is deprecated since Python 3.11
    and removed in Python 3.13. The loader's default method prefix is
    ``'test'``, so the same test methods are collected.
    """
    return unittest.TestLoader().loadTestsFromTestCase(ThreadTestCase)
def run():
    """Run the suite returned by ``suite()`` with a verbose text runner."""
    unittest.TextTestRunner(verbosity=2).run(suite())
# Execute the thread tests only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    run()