Commit 42695c8a authored by Peter V. Saveliev's avatar Peter V. Saveliev

ndb.source: fix restart

parent d460c08c
...@@ -76,12 +76,9 @@ import struct ...@@ -76,12 +76,9 @@ import struct
import threading import threading
from pyroute2 import IPRoute from pyroute2 import IPRoute
from pyroute2 import RemoteIPRoute from pyroute2 import RemoteIPRoute
from pyroute2.ndb.events import (SchemaReadLock, from pyroute2.ndb.events import (ShutdownException,
SchemaReadUnlock,
ShutdownException,
State) State)
from pyroute2.ndb.messages import (cmsg, from pyroute2.ndb.messages import (cmsg_event,
cmsg_event,
cmsg_failed, cmsg_failed,
cmsg_sstart) cmsg_sstart)
from pyroute2.netlink.nlsocket import NetlinkMixin from pyroute2.netlink.nlsocket import NetlinkMixin
...@@ -133,7 +130,7 @@ class Source(dict): ...@@ -133,7 +130,7 @@ class Source(dict):
self.shutdown = threading.Event() self.shutdown = threading.Event()
self.started = threading.Event() self.started = threading.Event()
self.lock = threading.RLock() self.lock = threading.RLock()
self.shutdown_lock = threading.Lock() self.shutdown_lock = threading.RLock()
self.started.clear() self.started.clear()
self.log = ndb.log.channel('sources.%s' % self.target) self.log = ndb.log.channel('sources.%s' % self.target)
self.state = State(log=self.log) self.state = State(log=self.log)
...@@ -357,16 +354,19 @@ class Source(dict): ...@@ -357,16 +354,19 @@ class Source(dict):
def restart(self, reason='unknown'): def restart(self, reason='unknown'):
with self.lock: with self.lock:
if not self.shutdown.is_set(): with self.shutdown_lock:
self.log.debug('restarting the source, reason <%s>' % (reason)) self.log.debug('restarting the source, reason <%s>' % (reason))
self.evq.put((cmsg(self.target, SchemaReadLock()), )) self.started.clear()
self.ndb.schema.allow_read(False)
try: try:
self.close() self.close()
if self.th: if self.th:
self.th.join() self.th.join()
self.shutdown.clear()
self.start() self.start()
finally: finally:
self.evq.put((cmsg(self.target, SchemaReadUnlock()), )) self.ndb.schema.allow_read(True)
self.started.wait()
def __enter__(self): def __enter__(self):
return self return self
......
...@@ -81,6 +81,40 @@ class TestMisc(object): ...@@ -81,6 +81,40 @@ class TestMisc(object):
for source in ndb.sources: for source in ndb.sources:
assert ndb.sources[source].nl.closed assert ndb.sources[source].nl.closed
def test_source_localhost_restart(self):
require_user('root')
ifname = uifname()
with NDB() as ndb:
assert len(list(ndb.interfaces.dump()))
ndb.sources['localhost'].restart()
assert len(list(ndb.interfaces.dump()))
(ndb
.interfaces
.create(ifname=ifname, kind='dummy', state='up')
.commit())
assert ndb.interfaces[ifname]['state'] == 'up'
ndb.interfaces[ifname].remove().commit()
def test_source_netns_restart(self):
require_user('root')
ifname = uifname()
nsname = str(uuid.uuid4())
with NDB() as ndb:
ndb.sources.add(netns=nsname)
assert len(list(ndb.interfaces.dump().filter(target=nsname)))
ndb.sources[nsname].restart()
assert len(list(ndb.interfaces.dump().filter(target=nsname)))
(ndb
.interfaces
.create(target=nsname, ifname=ifname, kind='dummy', state='up')
.commit())
assert ndb.interfaces[{'target': nsname,
'ifname': ifname}]['state'] == 'up'
ndb.interfaces[{'target': nsname,
'ifname': ifname}].remove().commit()
class TestPreSet(object): class TestPreSet(object):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment