Commit 42695c8a authored by Peter V. Saveliev's avatar Peter V. Saveliev

ndb.source: fix restart

parent d460c08c
......@@ -76,12 +76,9 @@ import struct
import threading
from pyroute2 import IPRoute
from pyroute2 import RemoteIPRoute
from pyroute2.ndb.events import (SchemaReadLock,
SchemaReadUnlock,
ShutdownException,
from pyroute2.ndb.events import (ShutdownException,
State)
from pyroute2.ndb.messages import (cmsg,
cmsg_event,
from pyroute2.ndb.messages import (cmsg_event,
cmsg_failed,
cmsg_sstart)
from pyroute2.netlink.nlsocket import NetlinkMixin
......@@ -133,7 +130,7 @@ class Source(dict):
self.shutdown = threading.Event()
self.started = threading.Event()
self.lock = threading.RLock()
self.shutdown_lock = threading.Lock()
self.shutdown_lock = threading.RLock()
self.started.clear()
self.log = ndb.log.channel('sources.%s' % self.target)
self.state = State(log=self.log)
......@@ -357,16 +354,19 @@ class Source(dict):
def restart(self, reason='unknown'):
with self.lock:
if not self.shutdown.is_set():
with self.shutdown_lock:
self.log.debug('restarting the source, reason <%s>' % (reason))
self.evq.put((cmsg(self.target, SchemaReadLock()), ))
self.started.clear()
self.ndb.schema.allow_read(False)
try:
self.close()
if self.th:
self.th.join()
self.shutdown.clear()
self.start()
finally:
self.evq.put((cmsg(self.target, SchemaReadUnlock()), ))
self.ndb.schema.allow_read(True)
self.started.wait()
def __enter__(self):
return self
......
......@@ -81,6 +81,40 @@ class TestMisc(object):
for source in ndb.sources:
assert ndb.sources[source].nl.closed
def test_source_localhost_restart(self):
require_user('root')
ifname = uifname()
with NDB() as ndb:
assert len(list(ndb.interfaces.dump()))
ndb.sources['localhost'].restart()
assert len(list(ndb.interfaces.dump()))
(ndb
.interfaces
.create(ifname=ifname, kind='dummy', state='up')
.commit())
assert ndb.interfaces[ifname]['state'] == 'up'
ndb.interfaces[ifname].remove().commit()
def test_source_netns_restart(self):
require_user('root')
ifname = uifname()
nsname = str(uuid.uuid4())
with NDB() as ndb:
ndb.sources.add(netns=nsname)
assert len(list(ndb.interfaces.dump().filter(target=nsname)))
ndb.sources[nsname].restart()
assert len(list(ndb.interfaces.dump().filter(target=nsname)))
(ndb
.interfaces
.create(target=nsname, ifname=ifname, kind='dummy', state='up')
.commit())
assert ndb.interfaces[{'target': nsname,
'ifname': ifname}]['state'] == 'up'
ndb.interfaces[{'target': nsname,
'ifname': ifname}].remove().commit()
class TestPreSet(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment