Commit a1e8efc0 authored by Peter V. Saveliev's avatar Peter V. Saveliev

iproute: improve `flush_routes()`

Delete routes as reading messages; do not cache them.

Bug-Url: https://github.com/svinota/pyroute2/issues/316
parent 09cd0fc5
......@@ -175,6 +175,7 @@ from socket import AF_UNSPEC
from socket import AF_BRIDGE
from types import FunctionType
from types import MethodType
from pyroute2.netlink import NLMSG_DONE
from pyroute2.netlink import NLMSG_ERROR
from pyroute2.netlink import NLM_F_ATOMIC
from pyroute2.netlink import NLM_F_ROOT
......@@ -500,6 +501,7 @@ class IPRouteMixin(object):
msg_flags = NLM_F_DUMP | NLM_F_REQUEST
nkw = {}
nkw['callback'] = kwarg.pop('callback', None)
# get a particular route?
if isinstance(kwarg.get('dst'), basestring):
......@@ -598,13 +600,23 @@ class IPRouteMixin(object):
routine. Actually, this routine implements a pipe from
`get_routes()` to `nlm_request()`.
'''
flags = NLM_F_ACK | NLM_F_CREATE | NLM_F_EXCL | NLM_F_REQUEST
ret = []
def callback(msg):
if msg['header']['type'] == NLMSG_DONE:
# this message will pass to the get()
return False
# all other messages are filtered
table = msg.get_attr('RTA_TABLE') or msg.get('table', None)
if table == kwarg.get('table', DEFAULT_TABLE):
# delete matching routes
self.put(msg, msg_type=RTM_DELROUTE, msg_flags=NLM_F_REQUEST)
# ignore others
return True
kwarg['table'] = kwarg.get('table', DEFAULT_TABLE)
for route in self.get_routes(*argv, **kwarg):
ret.append(self.nlm_request(route,
msg_type=RTM_DELROUTE,
msg_flags=flags))
kwarg['callback'] = callback
self.get_routes(*argv, **kwarg)
return ret
def flush_addr(self, *argv, **kwarg):
......@@ -1608,6 +1620,7 @@ class IPRouteMixin(object):
match = kwarg
else:
match = kwarg.pop('match', None)
callback = kwarg.pop('callback', None)
commands = {'add': (RTM_NEWROUTE, flags_make),
'set': (RTM_NEWROUTE, flags_replace),
......@@ -1660,7 +1673,10 @@ class IPRouteMixin(object):
attr[1].find(':') >= 0 else AF_INET
break
ret = self.nlm_request(msg, msg_type=command, msg_flags=flags)
ret = self.nlm_request(msg,
msg_type=command,
msg_flags=flags,
callback=callback)
if match:
return self._match(match, ret)
else:
......
......@@ -139,7 +139,7 @@ class Marshal(object):
self.msg_map = self.msg_map or {}
self.defragmentation = {}
def parse(self, data):
def parse(self, data, seq=None, callback=None):
'''
Parse string data.
......@@ -178,6 +178,10 @@ class Marshal(object):
enc = enc_class(data, offset=offset+20)
enc.decode()
msg['header']['errmsg'] = enc
if callback and seq == msg['header']['sequence_number']:
if callback(msg):
offset += msg.length
continue
except NetlinkHeaderDecodeError as e:
# in the case of header decoding error,
# create an empty message
......@@ -565,7 +569,10 @@ class NetlinkMixin(object):
def sendto_gate(self, msg, addr):
raise NotImplementedError()
def get(self, bufsize=DEFAULT_RCVBUF, msg_seq=0, terminate=None):
def get(self, bufsize=DEFAULT_RCVBUF,
msg_seq=0,
terminate=None,
callback=None):
'''
Get parsed messages list. If `msg_seq` is given, return
only messages with that `msg['header']['sequence_number']`,
......@@ -700,7 +707,7 @@ class NetlinkMixin(object):
# locks, except the read lock must be released
data = self.recv_ft(bufsize)
# Parse data
msgs = self.marshal.parse(data)
msgs = self.marshal.parse(data, msg_seq, callback)
# Reset ctime -- timeout should be measured
# for every turn separately
ctime = time.time()
......@@ -763,6 +770,7 @@ class NetlinkMixin(object):
def nlm_request(self, msg, msg_type,
msg_flags=NLM_F_REQUEST | NLM_F_DUMP,
terminate=None,
callback=None,
exception_catch=Exception,
exception_handler=None):
......@@ -771,7 +779,9 @@ class NetlinkMixin(object):
with self.lock[msg_seq]:
try:
self.put(msg, msg_type, msg_flags, msg_seq=msg_seq)
ret = self.get(msg_seq=msg_seq, terminate=terminate)
ret = self.get(msg_seq=msg_seq,
terminate=terminate,
callback=callback)
return ret
except Exception:
raise
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment