mirror of
https://github.com/Mikaela/Limnoria.git
synced 2024-11-26 20:59:27 +01:00
STS: When persisting STS keys, use the actual port instead of the one from the policy
'Servers MAY send this key to securely connected clients, but it will be ignored.' -- https://ircv3.net/specs/extensions/sts\#the-port-key
This commit is contained in:
parent
74073b2736
commit
ee9f0dc1bf
@ -97,7 +97,8 @@ class ServersMixin(object):
|
|||||||
|
|
||||||
def _applyStsPolicy(self, server):
|
def _applyStsPolicy(self, server):
|
||||||
network = ircdb.networks.getNetwork(self.networkName)
|
network = ircdb.networks.getNetwork(self.networkName)
|
||||||
policy = network.stsPolicies.get(server.hostname)
|
(policy_port, policy) = network.stsPolicies.get(
|
||||||
|
server.hostname, (None, None))
|
||||||
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
|
lastDisconnect = network.lastDisconnectTimes.get(server.hostname)
|
||||||
|
|
||||||
if policy is None or lastDisconnect is None:
|
if policy is None or lastDisconnect is None:
|
||||||
@ -107,22 +108,22 @@ class ServersMixin(object):
|
|||||||
|
|
||||||
# The policy was stored, which means it was received on a secure
|
# The policy was stored, which means it was received on a secure
|
||||||
# connection.
|
# connection.
|
||||||
policy = ircutils.parseStsPolicy(log, policy, parseDuration=True)
|
policy = ircutils.parseStsPolicy(log, policy, secure_connection=True)
|
||||||
|
|
||||||
if lastDisconnect + policy['duration'] < time.time():
|
if lastDisconnect + policy['duration'] < time.time():
|
||||||
log.info('STS policy expired, removing.')
|
log.info('STS policy expired, removing.')
|
||||||
network.expireStsPolicy(server.hostname)
|
network.expireStsPolicy(server.hostname)
|
||||||
return server
|
return server
|
||||||
|
|
||||||
if server.port == policy['port']:
|
if server.port == policy_port:
|
||||||
log.info('Using STS policy, port %s', server.port)
|
log.info('Using STS policy, port %s', server.port)
|
||||||
else:
|
else:
|
||||||
log.info('Using STS policy: changing port from %s to %s.',
|
log.info('Using STS policy: changing port from %s to %s.',
|
||||||
server.port, policy['port'])
|
server.port, policy_port)
|
||||||
|
|
||||||
# Change the port, and force TLS verification, as required by the STS
|
# Change the port, and force TLS verification, as required by the STS
|
||||||
# specification.
|
# specification.
|
||||||
return Server(server.hostname, policy['port'], server.attempt,
|
return Server(server.hostname, policy_port, server.attempt,
|
||||||
force_tls_verification=True)
|
force_tls_verification=True)
|
||||||
|
|
||||||
def die(self):
|
def die(self):
|
||||||
|
21
src/ircdb.py
21
src/ircdb.py
@ -509,9 +509,10 @@ class IrcNetwork(object):
|
|||||||
(self.__class__.__name__, self.stsPolicies,
|
(self.__class__.__name__, self.stsPolicies,
|
||||||
self.lastDisconnectTimes)
|
self.lastDisconnectTimes)
|
||||||
|
|
||||||
def addStsPolicy(self, server, stsPolicy):
|
def addStsPolicy(self, server, port, stsPolicy):
|
||||||
assert isinstance(stsPolicy, str)
|
assert isinstance(port, int), repr(port)
|
||||||
self.stsPolicies[server] = stsPolicy
|
assert isinstance(stsPolicy, str), repr(stsPolicy)
|
||||||
|
self.stsPolicies[server] = (port, stsPolicy)
|
||||||
|
|
||||||
def expireStsPolicy(self, server):
|
def expireStsPolicy(self, server):
|
||||||
if server in self.stsPolicies:
|
if server in self.stsPolicies:
|
||||||
@ -526,8 +527,10 @@ class IrcNetwork(object):
|
|||||||
fd.write(s)
|
fd.write(s)
|
||||||
fd.write(os.linesep)
|
fd.write(os.linesep)
|
||||||
|
|
||||||
for (server, stsPolicy) in sorted(self.stsPolicies.items()):
|
for (server, (port, stsPolicy)) in sorted(self.stsPolicies.items()):
|
||||||
write('stsPolicy %s %s' % (server, stsPolicy))
|
assert isinstance(port, int), repr(port)
|
||||||
|
assert isinstance(stsPolicy, str), repr(stsPolicy)
|
||||||
|
write('stsPolicy %s %s %s' % (server, port, stsPolicy))
|
||||||
|
|
||||||
for (server, disconnectTime) in \
|
for (server, disconnectTime) in \
|
||||||
sorted(self.lastDisconnectTimes.items()):
|
sorted(self.lastDisconnectTimes.items()):
|
||||||
@ -667,8 +670,12 @@ class IrcNetworkCreator(Creator):
|
|||||||
IrcNetworkCreator.name = rest
|
IrcNetworkCreator.name = rest
|
||||||
|
|
||||||
def stspolicy(self, rest, lineno):
|
def stspolicy(self, rest, lineno):
|
||||||
(server, stsPolicy) = rest.split()
|
L = rest.split()
|
||||||
self.net.addStsPolicy(server, stsPolicy)
|
if len(L) == 2:
|
||||||
|
# Old policy missing a port. Discard it
|
||||||
|
return
|
||||||
|
(server, policyPort, stsPolicy) = L
|
||||||
|
self.net.addStsPolicy(server, int(policyPort), stsPolicy)
|
||||||
|
|
||||||
def lastdisconnecttime(self, rest, lineno):
|
def lastdisconnecttime(self, rest, lineno):
|
||||||
(server, when) = rest.split()
|
(server, when) = rest.split()
|
||||||
|
@ -2050,7 +2050,7 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
or (self.driver.ssl and self.driver.anyCertValidationEnabled())
|
||||||
|
|
||||||
parsed_policy = ircutils.parseStsPolicy(
|
parsed_policy = ircutils.parseStsPolicy(
|
||||||
log, policy, parseDuration=secure_connection)
|
log, policy, secure_connection=secure_connection)
|
||||||
if parsed_policy is None:
|
if parsed_policy is None:
|
||||||
# There was an error (and it was logged). Ignore it and proceed
|
# There was an error (and it was logged). Ignore it and proceed
|
||||||
# with the connection.
|
# with the connection.
|
||||||
@ -2065,9 +2065,14 @@ class Irc(IrcCommandDispatcher, log.Firewalled):
|
|||||||
# For future-proofing (because we don't want to write an invalid
|
# For future-proofing (because we don't want to write an invalid
|
||||||
# value), we write the raw policy received from the server instead
|
# value), we write the raw policy received from the server instead
|
||||||
# of the parsed one.
|
# of the parsed one.
|
||||||
log.debug('Storing STS policy: %s', policy)
|
log.debug('Storing STS policy for %s (TLS port %s): %s',
|
||||||
|
self.driver.currentServer.hostname,
|
||||||
|
self.driver.currentServer.port,
|
||||||
|
policy)
|
||||||
ircdb.networks.getNetwork(self.network).addStsPolicy(
|
ircdb.networks.getNetwork(self.network).addStsPolicy(
|
||||||
self.driver.currentServer.hostname, policy)
|
self.driver.currentServer.hostname,
|
||||||
|
self.driver.currentServer.port,
|
||||||
|
policy)
|
||||||
else:
|
else:
|
||||||
hostname = self.driver.currentServer.hostname
|
hostname = self.driver.currentServer.hostname
|
||||||
attempt = self.driver.currentServer.attempt
|
attempt = self.driver.currentServer.attempt
|
||||||
|
@ -1073,11 +1073,15 @@ def parseCapabilityKeyValue(s):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def parseStsPolicy(logger, policy, parseDuration):
|
def parseStsPolicy(logger, policy, secure_connection):
|
||||||
parsed_policy = parseCapabilityKeyValue(policy)
|
parsed_policy = parseCapabilityKeyValue(policy)
|
||||||
|
|
||||||
for key in ('port', 'duration'):
|
for key in ('port', 'duration'):
|
||||||
if key == 'duration' and not parseDuration:
|
if key == 'duration' and not secure_connection:
|
||||||
|
if key in parsed_policy:
|
||||||
|
del parsed_policy[key]
|
||||||
|
continue
|
||||||
|
elif key == 'port' and secure_connection:
|
||||||
if key in parsed_policy:
|
if key in parsed_policy:
|
||||||
del parsed_policy[key]
|
del parsed_policy[key]
|
||||||
continue
|
continue
|
||||||
|
@ -39,7 +39,7 @@ class DriversTestCase(SupyTestCase):
|
|||||||
def testValidStsPolicy(self):
|
def testValidStsPolicy(self):
|
||||||
irc = irclib.Irc('test')
|
irc = irclib.Irc('test')
|
||||||
net = ircdb.networks.getNetwork('test')
|
net = ircdb.networks.getNetwork('test')
|
||||||
net.addStsPolicy('example.com', 'duration=10,port=6697')
|
net.addStsPolicy('example.com', 6697, 'duration=10,port=12345')
|
||||||
net.addDisconnection('example.com')
|
net.addDisconnection('example.com')
|
||||||
|
|
||||||
with conf.supybot.networks.test.servers.context(
|
with conf.supybot.networks.test.servers.context(
|
||||||
@ -64,7 +64,7 @@ class DriversTestCase(SupyTestCase):
|
|||||||
def testExpiredStsPolicy(self):
|
def testExpiredStsPolicy(self):
|
||||||
irc = irclib.Irc('test')
|
irc = irclib.Irc('test')
|
||||||
net = ircdb.networks.getNetwork('test')
|
net = ircdb.networks.getNetwork('test')
|
||||||
net.addStsPolicy('example.com', 'duration=10,port=6697')
|
net.addStsPolicy('example.com', 6697, 'duration=10')
|
||||||
net.addDisconnection('example.com')
|
net.addDisconnection('example.com')
|
||||||
|
|
||||||
timeFastForward(16)
|
timeFastForward(16)
|
||||||
@ -81,7 +81,7 @@ class DriversTestCase(SupyTestCase):
|
|||||||
def testRescheduledStsPolicy(self):
|
def testRescheduledStsPolicy(self):
|
||||||
irc = irclib.Irc('test')
|
irc = irclib.Irc('test')
|
||||||
net = ircdb.networks.getNetwork('test')
|
net = ircdb.networks.getNetwork('test')
|
||||||
net.addStsPolicy('example.com', 'duration=10,port=6697')
|
net.addStsPolicy('example.com', 6697, 'duration=10')
|
||||||
net.addDisconnection('example.com')
|
net.addDisconnection('example.com')
|
||||||
|
|
||||||
with conf.supybot.networks.test.servers.context(
|
with conf.supybot.networks.test.servers.context(
|
||||||
|
@ -358,11 +358,11 @@ class IrcNetworkTestCase(IrcdbTestCase):
|
|||||||
|
|
||||||
def testStsPolicy(self):
|
def testStsPolicy(self):
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('foo', 'bar')
|
n.addStsPolicy('foo', 123, 'bar')
|
||||||
n.addStsPolicy('baz', 'qux')
|
n.addStsPolicy('baz', 456, 'qux')
|
||||||
self.assertEqual(n.stsPolicies, {
|
self.assertEqual(n.stsPolicies, {
|
||||||
'foo': 'bar',
|
'foo': (123, 'bar'),
|
||||||
'baz': 'qux',
|
'baz': (456, 'qux'),
|
||||||
})
|
})
|
||||||
|
|
||||||
def testAddDisconnection(self):
|
def testAddDisconnection(self):
|
||||||
@ -374,8 +374,8 @@ class IrcNetworkTestCase(IrcdbTestCase):
|
|||||||
|
|
||||||
def testPreserve(self):
|
def testPreserve(self):
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('foo', 'sts1')
|
n.addStsPolicy('foo', 123, 'sts1')
|
||||||
n.addStsPolicy('bar', 'sts2')
|
n.addStsPolicy('bar', 456,'sts2')
|
||||||
n.addDisconnection('foo')
|
n.addDisconnection('foo')
|
||||||
n.addDisconnection('baz')
|
n.addDisconnection('baz')
|
||||||
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
||||||
@ -384,8 +384,8 @@ class IrcNetworkTestCase(IrcdbTestCase):
|
|||||||
n.preserve(fd, indent=' ')
|
n.preserve(fd, indent=' ')
|
||||||
fd.seek(0)
|
fd.seek(0)
|
||||||
self.assertCountEqual(fd.read().split('\n'), [
|
self.assertCountEqual(fd.read().split('\n'), [
|
||||||
' stsPolicy foo sts1',
|
' stsPolicy foo 123 sts1',
|
||||||
' stsPolicy bar sts2',
|
' stsPolicy bar 456 sts2',
|
||||||
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
||||||
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
||||||
'',
|
'',
|
||||||
@ -467,8 +467,8 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
|||||||
|
|
||||||
def testPreserveOne(self):
|
def testPreserveOne(self):
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('foo', 'sts1')
|
n.addStsPolicy('foo', 123, 'sts1')
|
||||||
n.addStsPolicy('bar', 'sts2')
|
n.addStsPolicy('bar', 456, 'sts2')
|
||||||
n.addDisconnection('foo')
|
n.addDisconnection('foo')
|
||||||
n.addDisconnection('baz')
|
n.addDisconnection('baz')
|
||||||
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
disconnect_time_foo = n.lastDisconnectTimes['foo']
|
||||||
@ -486,8 +486,8 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
|||||||
lines = fd.getvalue().split('\n')
|
lines = fd.getvalue().split('\n')
|
||||||
self.assertEqual(lines.pop(0), 'network foonet')
|
self.assertEqual(lines.pop(0), 'network foonet')
|
||||||
self.assertCountEqual(lines, [
|
self.assertCountEqual(lines, [
|
||||||
' stsPolicy foo sts1',
|
' stsPolicy foo 123 sts1',
|
||||||
' stsPolicy bar sts2',
|
' stsPolicy bar 456 sts2',
|
||||||
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
' lastDisconnectTime foo %d' % disconnect_time_foo,
|
||||||
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
' lastDisconnectTime baz %d' % disconnect_time_baz,
|
||||||
'',
|
'',
|
||||||
@ -496,15 +496,15 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
|||||||
|
|
||||||
def testPreserveThree(self):
|
def testPreserveThree(self):
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('foo', 'sts1')
|
n.addStsPolicy('foo', 123, 'sts1')
|
||||||
self.networks.setNetwork('foonet', n)
|
self.networks.setNetwork('foonet', n)
|
||||||
|
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('bar', 'sts2')
|
n.addStsPolicy('bar', 456, 'sts2')
|
||||||
self.networks.setNetwork('barnet', n)
|
self.networks.setNetwork('barnet', n)
|
||||||
|
|
||||||
n = ircdb.IrcNetwork()
|
n = ircdb.IrcNetwork()
|
||||||
n.addStsPolicy('baz', 'sts3')
|
n.addStsPolicy('baz', 789, 'sts3')
|
||||||
self.networks.setNetwork('baznet', n)
|
self.networks.setNetwork('baznet', n)
|
||||||
|
|
||||||
fd = io.StringIO()
|
fd = io.StringIO()
|
||||||
@ -518,13 +518,13 @@ class NetworksDictionaryTestCase(IrcdbTestCase):
|
|||||||
fd.seek(0)
|
fd.seek(0)
|
||||||
self.assertEqual(fd.getvalue(),
|
self.assertEqual(fd.getvalue(),
|
||||||
'network barnet\n'
|
'network barnet\n'
|
||||||
' stsPolicy bar sts2\n'
|
' stsPolicy bar 456 sts2\n'
|
||||||
'\n'
|
'\n'
|
||||||
'network baznet\n'
|
'network baznet\n'
|
||||||
' stsPolicy baz sts3\n'
|
' stsPolicy baz 789 sts3\n'
|
||||||
'\n'
|
'\n'
|
||||||
'network foonet\n'
|
'network foonet\n'
|
||||||
' stsPolicy foo sts1\n'
|
' stsPolicy foo 123 sts1\n'
|
||||||
'\n'
|
'\n'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -759,16 +759,27 @@ class StsTestCase(SupyTestCase):
|
|||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
args=('*', 'LS', 'sts=duration=42,port=12345')))
|
||||||
|
|
||||||
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
|
||||||
'irc.test': 'duration=42,port=6697'})
|
'irc.test': (6697, 'duration=42,port=12345')})
|
||||||
|
self.irc.driver.reconnect.assert_not_called()
|
||||||
|
|
||||||
|
def testStsInSecureConnectionNoPort(self):
|
||||||
|
self.irc.driver.anyCertValidationEnabled.return_value = True
|
||||||
|
self.irc.driver.ssl = True
|
||||||
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
||||||
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
|
args=('*', 'LS', 'sts=duration=42')))
|
||||||
|
|
||||||
|
self.assertEqual(ircdb.networks.getNetwork('test').stsPolicies, {
|
||||||
|
'irc.test': (6697, 'duration=42')})
|
||||||
self.irc.driver.reconnect.assert_not_called()
|
self.irc.driver.reconnect.assert_not_called()
|
||||||
|
|
||||||
def testStsInInsecureTlsConnection(self):
|
def testStsInInsecureTlsConnection(self):
|
||||||
self.irc.driver.anyCertValidationEnabled.return_value = False
|
self.irc.driver.anyCertValidationEnabled.return_value = False
|
||||||
self.irc.driver.ssl = True
|
self.irc.driver.ssl = True
|
||||||
self.irc.driver.currentServer = drivers.Server('irc.test', 6697, None, False)
|
self.irc.driver.currentServer = drivers.Server('irc.test', 6667, None, False)
|
||||||
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
self.irc.feedMsg(ircmsgs.IrcMsg(command='CAP',
|
||||||
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
args=('*', 'LS', 'sts=duration=42,port=6697')))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user