Limnoria/src/utils.py

741 lines
23 KiB
Python
Raw Normal View History

#!/usr/bin/env python
###
# Copyright (c) 2002, Jeremiah Fincher
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions, and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions, and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the author of this software nor the name of
# contributors to this software may be used to endorse or promote products
# derived from this software without specific prior written consent.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
###
"""
Simple utility functions.
"""
2003-11-25 09:38:19 +01:00
__revision__ = "$Id$"
2004-07-24 07:18:26 +02:00
import supybot.fix as fix
import os
import re
2004-07-31 00:35:51 +02:00
import sys
2003-10-28 15:20:00 +01:00
import md5
2004-04-20 11:51:20 +02:00
import new
2003-10-28 15:20:00 +01:00
import sha
import sets
2004-07-31 01:39:59 +02:00
import time
import types
2004-07-31 01:39:59 +02:00
import random
2004-07-31 11:44:03 +02:00
import shutil
2004-01-16 18:33:51 +01:00
import socket
import string
import sgmllib
2003-11-07 20:40:03 +01:00
import compiler
import textwrap
import UserDict
import itertools
2004-07-31 00:35:51 +02:00
import traceback
import htmlentitydefs
2003-12-02 19:58:57 +01:00
from itertools import imap, ifilter
2004-07-24 07:18:26 +02:00
from supybot.structures import TwoWayDictionary
2004-04-20 11:44:58 +02:00
# curry(f, x) wraps the callable f as a method bound to x, so that calling
# the result passes x as the first positional argument; for example,
# curry(r.sub, replace)(s) calls r.sub(replace, s).
curry = new.instancemethod
2003-09-01 20:39:27 +02:00
def normalizeWhitespace(s):
    r"""Normalizes the whitespace in a string; \s+ becomes one space.

    Leading and trailing whitespace is removed entirely.
    """
    # The docstring is raw so that \s is not read as a string escape.
    return ' '.join(s.split())
class HtmlToText(sgmllib.SGMLParser):
    """SGML parser that collects the text of an HTML document.

    Every start and end tag is replaced by tagReplace.  Taken from some
    eff-bot code on c.l.p.
    """
    # The standard HTML entities, with the non-breaking space turned into a
    # plain space.
    entitydefs = dict(htmlentitydefs.entitydefs)
    entitydefs['nbsp'] = ' '

    def __init__(self, tagReplace=' '):
        self.tagReplace = tagReplace
        self.data = []
        sgmllib.SGMLParser.__init__(self)

    def unknown_starttag(self, tag, attr):
        self.data.append(self.tagReplace)

    def unknown_endtag(self, tag):
        self.data.append(self.tagReplace)

    def handle_data(self, data):
        self.data.append(data)

    def getText(self):
        """Returns the text collected so far, whitespace-normalized."""
        collected = ''.join(self.data)
        return normalizeWhitespace(collected.strip())
def htmlToText(s, tagReplace=' '):
    """Turns HTML into text.

    tagReplace is the string that every HTML tag is replaced with.
    """
    parser = HtmlToText(tagReplace)
    parser.feed(s)
    return parser.getText()
2003-03-31 07:14:21 +02:00
def eachSubstring(s):
    """Yields every non-empty prefix of s, from the shortest to s itself."""
    end = 0
    total = len(s)
    while end < total:
        end += 1
        yield s[:end]
def abbrev(strings, d=None):
    """Returns a dictionary mapping unambiguous abbreviations to full forms.

    Every non-empty prefix of each string is a candidate abbreviation.  A
    prefix shared by several different strings is ambiguous and is left out,
    unless the prefix is itself one of the strings, in which case it maps to
    itself regardless of the order the strings are given in.

    If d is given, the abbreviations are added to it and it is returned;
    entries of d whose value is None are removed.
    """
    if d is None:
        d = {}
    # Materialize so the membership tests below also work for iterators.
    strings = list(strings)
    for s in strings:
        for end in range(1, len(s) + 1):
            abbreviation = s[:end]
            if abbreviation in strings:
                # An exact match always wins over being a prefix.
                d[abbreviation] = abbreviation
            elif abbreviation not in d:
                d[abbreviation] = s
            elif d[abbreviation] != s:
                # Prefix of two different strings: mark as ambiguous.
                d[abbreviation] = None
    for key in list(d.keys()):
        if d[key] is None:
            del d[key]
    return d
def timeElapsed(elapsed, leadingZeroes=False, years=True, weeks=True,
                days=True, hours=True, minutes=True, seconds=True):
    """Given <elapsed> seconds, returns a string with an English description
    of how much time has passed.

    leadingZeroes determines whether 0 days, 0 hours, etc. will be printed;
    once a non-zero period has been printed, all smaller enabled periods are
    printed too.  The other flags determine which time periods are used; a
    disabled period is folded into the next smaller enabled one.

    Raises ValueError if nothing would be printed.
    """
    elapsed = int(elapsed)
    assert years or weeks or days or \
           hours or minutes or seconds, 'One flag must be True'
    # (enabled, unit name, unit length in seconds), largest first.
    periods = [(years, 'year', 31536000),
               (weeks, 'week', 604800),
               (days, 'day', 86400),
               (hours, 'hour', 3600),
               (minutes, 'minute', 60)]
    ret = []
    for (enabled, name, length) in periods:
        if not enabled:
            continue
        (count, elapsed) = divmod(elapsed, length)
        if leadingZeroes or count:
            if count:
                leadingZeroes = True
            ret.append(nItems(name, count))
    if seconds:
        # Whatever is left over; this includes whole minutes when the
        # minutes flag is off.
        ret.append(nItems('second', elapsed))
    if len(ret) == 0:
        raise ValueError('Time difference not great enough to be noted.')
    return commaAndify(ret)
2003-04-04 17:49:24 +02:00
def distance(s, t):
    """Returns the levenshtein edit distance between two strings."""
    n = len(s)
    m = len(t)
    if n == 0:
        return m
    if m == 0:
        return n
    # previous[j] is the distance between the first i-1 items of s and the
    # first j items of t; only two rows of the table are kept at a time.
    previous = list(range(m + 1))
    for i in range(1, n + 1):
        cs = s[i - 1]
        current = [i]
        for j in range(1, m + 1):
            cost = int(cs != t[j - 1])
            current.append(min(previous[j] + 1,
                               current[j - 1] + 1,
                               previous[j - 1] + cost))
        previous = current
    return previous[m]
# Translation table mapping each uppercase ASCII letter to its soundex digit.
_soundextrans = string.maketrans(string.ascii_uppercase,
                                 '01230120022455012623010202')
# string.ascii with the uppercase ASCII letters deleted.
# NOTE(review): string.ascii is not a standard attribute of the string
# module; presumably supybot.fix installs it as the identity translation
# table -- confirm.
_notUpper = string.ascii.translate(string.ascii, string.ascii_uppercase)
# Soundex digit for each uppercase ASCII letter.
_soundexDigits = dict(zip(string.ascii_uppercase,
                          '01230120022455012623010202'))

def soundex(s, length=4):
    """Returns the soundex hash of a given string.

    Only ASCII letters are considered; case is ignored.  The result is the
    first letter followed by digits, padded with zeroes and cut to <length>
    characters.  If length is 0 the full code is returned without padding.

    Raises ValueError if s contains no ASCII letters.
    """
    letters = [c for c in s.upper() if c in _soundexDigits]
    if not letters:
        raise ValueError('Invalid string for soundex: %r' % (s,))
    digits = ''.join([_soundexDigits[c] for c in letters])
    # Drop the leading run of the first letter's own digit.
    digits = digits.lstrip(digits[0])
    L = [letters[0]]
    for c in digits:
        if c != L[-1]:
            L.append(c)
    # Zeroes only separate repeated digits; they are not part of the code.
    L = [c for c in L if c != '0'] + ['0'] * (length - 1)
    code = ''.join(L)
    if length:
        return code[:length]
    return code.rstrip('0')
2003-04-12 14:50:20 +02:00
def dqrepr(s):
    """Returns a repr() of s guaranteed to be in double quotes."""
    # As of Python 2.3 repr() no longer reliably uses double quotes, so the
    # old approach, '"' + repr("'\x00" + s)[6:], no longer works.  Instead
    # escape s as a string literal, then backslash-escape its double quotes.
    escaped = s.encode('string_escape')
    return '"%s"' % escaped.replace('"', '\\"')
2003-04-12 14:50:20 +02:00
# Matches a slash that is not preceded by a backslash.
nonEscapedSlashes = re.compile(r'(?<!\\)/')

def perlReToPythonRe(s):
    """Converts a string representation of a Perl regular expression (i.e.,
    m/^foo$/i or /foo|bar/) to a Python regular expression.

    Returns the compiled regular expression.  Raises ValueError if s is not
    of that form, uses an unknown flag, or holds an invalid pattern.
    """
    parts = nonEscapedSlashes.split(s)
    if len(parts) != 3:
        raise ValueError('Must be of the form m/.../ or /.../')
    (kind, regexp, flags) = parts
    regexp = regexp.replace('\\/', '/')
    if kind not in ('', 'm'):
        raise ValueError('Invalid kind: must be in ("", "m")')
    flag = 0
    for c in flags.upper():
        # Each flag letter must name a flag constant of the re module.
        if not hasattr(re, c):
            raise ValueError('Invalid flag: %s' % c)
        flag |= getattr(re, c)
    try:
        return re.compile(regexp, flag)
    except re.error:
        # sys.exc_info() avoids the version-specific "except X, e" syntax.
        raise ValueError(str(sys.exc_info()[1]))
def perlReToReplacer(s):
    """Converts a string representation of a Perl regular expression (i.e.,
    s/foo/bar/g or s/foo/bar/i) to a Python function doing the equivalent
    replacement.

    The returned function takes the string to operate on.  Without the g
    flag only the first match is replaced.  Raises ValueError if s is not of
    the form s/.../.../.
    """
    parts = nonEscapedSlashes.split(s)
    if len(parts) != 4:
        raise ValueError('Must be of the form s/.../.../')
    (kind, regexp, replace, flags) = parts
    # Turn a literal backspace character back into the regexp escape \b.
    regexp = regexp.replace('\x08', r'\b')
    replace = replace.replace('\\/', '/')
    # Likewise turn the characters chr(0)..chr(9) back into \0..\9.
    for i in range(10):
        replace = replace.replace(chr(i), r'\%s' % i)
    if kind != 's':
        raise ValueError('Invalid kind: must be "s"')
    replaceAll = 'g' in flags
    # g is not a flag of the re module, so strip it before compiling.
    flags = flags.replace('g', '')
    r = perlReToPythonRe('/'.join(('', regexp, flags)))
    if replaceAll:
        return curry(r.sub, replace)
    return lambda s: r.sub(replace, s, 1)
def findBinaryInPath(s):
    """Return full path of a binary if it's in PATH, otherwise return None.

    The directories of PATH are searched in order and the first existing
    match wins.  None is also returned when PATH is unset or empty.
    """
    path = os.getenv('PATH')
    if not path:
        return None
    # os.pathsep is ':' on POSIX and ';' on Windows.
    for directory in path.split(os.pathsep):
        filename = os.path.join(directory, s)
        if os.path.exists(filename):
            return filename
    return None
2004-04-28 08:26:02 +02:00
def commaAndify(seq, comma=',', And='and'):
    """Given a sequence of strings, returns an English clause for it.
    I.e., given ['1', '2', '3'], returns '1, 2, and 3'

    comma is the separator and And the conjunction to use.  Raises TypeError
    if an item is not a string and there are fewer than three items.
    """
    items = list(seq)
    count = len(items)
    if count == 0:
        return ''
    if count == 1:
        # join() rather than items[0], so that a non-string raises TypeError.
        return ''.join(items)
    if count == 2:
        return ' '.join([items[0], And, items[1]])
    items[-1] = '%s %s' % (And, items[-1])
    return ('%s ' % comma).join(items)
2003-08-23 09:57:04 +02:00
# Matches '<anything>, the' (case-insensitively) at the end of a string.
_unCommaTheRe = re.compile(r'(.*),\s*(the)$', re.I)

def unCommaThe(s):
    """Takes a string of the form 'foo, the' and turns it into 'the foo'.

    Any other string is returned unchanged.
    """
    m = _unCommaTheRe.match(s)
    if m is None:
        return s
    (rest, article) = m.groups()
    return '%s %s' % (article, rest)
def wrapLines(s):
    """Word wraps several paragraphs in a string s.

    Every line of s is treated as one paragraph and wrapped on its own.
    """
    return '\n'.join([textwrap.fill(line) for line in s.splitlines()])
def ellipsisify(s, n):
    """Returns a shortened version of s.  Produces up to the first n chars at
    the nearest word boundary.

    If s is longer than n characters, the result is the first line of s
    wrapped to n-3 columns, followed by '...'.  A string that wraps to
    nothing (i.e., only whitespace) gives just '...'.  textwrap rejects a
    width that is not positive, so n must be greater than 3 in that case.
    """
    if len(s) <= n:
        return s
    lines = textwrap.wrap(s, n - 3)
    if not lines:
        # Nothing but whitespace: there is no first line to take.
        return '...'
    return lines[0] + '...'
2003-12-02 19:58:57 +01:00
# Exceptions to the general pluralization rules; pluralize() and
# depluralize() both look words up in it before applying those rules.
# NOTE(review): presumably maps singular <-> plural in both directions,
# since it is a TwoWayDictionary -- confirm.
plurals = TwoWayDictionary({})
def matchCase(s1, s2):
    """Returns s2 with the capitalization of s1 applied to it.

    If s1 is entirely uppercase, so is the result.  Otherwise each character
    of s2 is uppercased where the character of s1 at the same position is
    uppercase; the rest of s2 is left as it is.
    """
    if s1.isupper():
        return s2.upper()
    chars = list(s2)
    for (i, (model, char)) in enumerate(zip(s1, s2)):
        if model.isupper():
            chars[i] = char.upper()
    return ''.join(chars)
2004-01-28 22:42:46 +01:00
consonants = 'bcdfghjklmnpqrstvwxz'
# Matches a word ending in a consonant followed by 'y'.
_pluralizeRegex = re.compile('[%s]y$' % consonants)

def pluralize(s, i=2):
    """Returns the plural of s based on its number i.  Put any exceptions to
    the general English rule of appending 's' in the plurals dictionary.

    s is returned unchanged if i is 1.
    """
    if i == 1:
        return s
    lowered = s.lower()
    # Exception dictionary
    if lowered in plurals:
        return matchCase(s, plurals[lowered])
    # Words ending with 'ch', 'sh' or 'ss' such as 'punch(es)', 'fish(es)'
    # and 'miss(es)'
    if lowered[-2:] in ('ch', 'sh', 'ss'):
        return matchCase(s, s + 'es')
    # Words ending with a consonant followed by a 'y' such as
    # 'try (tries)' or 'spy (spies)'
    if _pluralizeRegex.search(lowered):
        return matchCase(s, s[:-1] + 'ies')
    # In all other cases, we simply add an 's' to the base word
    return matchCase(s, s + 's')
2004-01-28 22:42:46 +01:00
# Matches a word ending in a consonant followed by 'ies'.  The pattern is
# anchored so that 'ies' in the middle of a word (as in 'diesels') does not
# trigger the '-ies' rule.
_depluralizeRegex = re.compile('[%s]ies$' % consonants)

def depluralize(s):
    """Returns the singular of s.

    Words found in the plurals dictionary are looked up there; otherwise the
    endings 'ches'/'shes'/'sses', consonant + 'ies' and 's' are undone, in
    that order.  A word with none of these endings is returned unchanged.
    """
    lowered = s.lower()
    if lowered in plurals:
        return matchCase(s, plurals[lowered])
    if lowered[-4:] in ('ches', 'shes', 'sses'):
        return s[:-2]
    if _depluralizeRegex.search(lowered):
        return s[:-3] + 'y'
    if lowered.endswith('s'):
        return s[:-1] # Chop off 's'.
    return s # Don't know what to do.
2003-09-01 07:42:35 +02:00
def nItems(item, n, between=None):
    """Returns n followed by item, pluralized according to n, with <between>
    in the middle if it is given.  Works like this:

    >>> nItems('clock', 1)
    '1 clock'

    >>> nItems('clock', 10)
    '10 clocks'

    >>> nItems('clock', 10, between='grandfather')
    '10 grandfather clocks'
    """
    noun = pluralize(item, n)
    if between is not None:
        return '%s %s %s' % (n, between, noun)
    return '%s %s' % (n, noun)
2003-09-03 11:40:26 +02:00
2003-09-01 07:42:35 +02:00
def be(i):
    """Returns the form of the verb 'to be' based on the number i.

    That is 'is' for exactly 1 and 'are' for anything else.
    """
    if i != 1:
        return 'are'
    return 'is'
2004-08-11 19:10:20 +02:00
def has(i):
    """Returns the form of the verb 'to have' based on the number i.

    That is 'has' for exactly 1 and 'have' for anything else.
    """
    if i != 1:
        return 'have'
    return 'has'
def sortBy(f, L):
    """Uses the decorate-sort-undecorate pattern to sort L by function f.

    L is sorted in place and nothing is returned.  The sort is stable:
    elements with equal keys keep their order and are never compared with
    each other.  If f raises, the exception propagates and L is untouched.
    """
    # Decorate into a separate list so that a failure in f cannot leave L
    # half-converted to tuples.
    decorated = [(f(elt), i, elt) for (i, elt) in enumerate(L)]
    decorated.sort()
    L[:] = [entry[2] for entry in decorated]
2003-09-01 20:39:27 +02:00
2004-04-20 12:04:09 +02:00
def sorted(iterable, cmp=None, key=None, reversed=False, reverse=False):
    """Returns a new sorted list of the items in iterable.

    cmp is an optional comparison function and key an optional function
    computing a sort key from each item; only one of them may be given.
    The list is reversed after sorting if reversed (the historical name
    here) or reverse (the name the native builtin uses) is true.
    """
    L = list(iterable)
    if key is not None:
        assert cmp is None, 'Can\'t use both cmp and key.'
        # Decorate-sort-undecorate; the index keeps the sort stable and
        # keeps the items themselves from being compared.
        decorated = [(key(elt), i, elt) for (i, elt) in enumerate(L)]
        decorated.sort()
        L = [entry[2] for entry in decorated]
    elif cmp is not None:
        L.sort(cmp)
    else:
        L.sort()
    if reversed or reverse:
        L.reverse()
    return L

# Make sorted() available everywhere.  __builtins__ is a dictionary in an
# imported module but the builtin module itself when run as a script.
if isinstance(__builtins__, dict):
    __builtins__['sorted'] = sorted
else:
    __builtins__.sorted = sorted
def mktemp(suffix=''):
    """Gives a decent random string, suitable for a filename.

    The result is a SHA hex digest followed by suffix.
    """
    digest = md5.md5(suffix)
    rng = random.Random()
    rng.seed(time.time())
    state = str(rng.getstate())
    # Mix in a random number of counter values...
    stop = random.randrange(400)
    step = random.randrange(1, 5)
    for x in xrange(0, stop, step):
        digest.update(str(x))
    # ...then the generator state and the current time.
    digest.update(state)
    digest.update(str(time.time()))
    token = digest.hexdigest()
    return sha.sha(token + str(time.time())).hexdigest() + suffix
def itersplit(isSeparator, iterable, maxsplit=-1, yieldEmpty=False):
    """Splits an iterator based on a predicate isSeparator.

    Yields lists of the elements found between separators; the separators
    themselves are dropped.  At most maxsplit splits are done (no limit if
    it is negative); after that separators are kept as ordinary elements.
    Empty lists are only yielded if yieldEmpty is true.
    """
    chunk = []
    for element in iterable:
        if maxsplit != 0 and isSeparator(element):
            maxsplit -= 1
            if chunk or yieldEmpty:
                yield chunk
            chunk = []
            continue
        chunk.append(element)
    if chunk or yieldEmpty:
        yield chunk
try:
    _stringTypes = (str, unicode)
except NameError:
    # This interpreter has no separate unicode type.
    _stringTypes = (str,)

def flatten(seq, strings=False):
    """Flattens a list of lists into a single list.  See the test for
    examples.

    Nested iterables are flattened recursively; anything that cannot be
    iterated over is yielded as it is.  Strings (byte and unicode alike) are
    yielded whole unless strings is true, in which case they are flattened
    into their characters at every nesting level.
    """
    for elt in seq:
        if isinstance(elt, _stringTypes):
            if strings:
                for char in elt:
                    yield char
            else:
                yield elt
            continue
        # Only the "is it iterable?" check is guarded, so that a TypeError
        # raised while iterating is not mistaken for a non-iterable element.
        try:
            children = iter(elt)
        except TypeError:
            yield elt
            continue
        for x in flatten(children, strings):
            yield x
2003-10-28 15:20:00 +01:00
def saltHash(password, salt=None, hash='sha'):
    """Returns 'salt|digest', where digest is the hex digest of salt+password.

    hash selects the algorithm and must be 'sha' or 'md5'; anything else
    raises ValueError.  If no salt is given, a random 8-character one is
    generated.
    """
    if hash == 'sha':
        hasher = sha.sha
    elif hash == 'md5':
        hasher = md5.md5
    else:
        raise ValueError('Invalid hash: %r (must be "sha" or "md5")' % (hash,))
    if salt is None:
        salt = mktemp()[:8]
    return '|'.join([salt, hasher(salt + password).hexdigest()])
2003-10-28 15:20:00 +01:00
2003-11-07 20:40:03 +01:00
def safeEval(s, namespace={'True': True, 'False': False, 'None': None}):
    """Evaluates s, safely.  Useful for turning strings into tuples/lists/etc.
    without unsafely using eval().

    Only constants, lists, tuples, dictionaries and the names found in
    namespace are allowed.  Raises ValueError for anything else.
    """
    module = compiler.parse(s)
    nodes = module.node.nodes
    if not nodes:
        # A lone string literal is parsed as the module's docstring.
        if module.__class__ is compiler.ast.Module:
            return module.doc
        else:
            raise ValueError('Unsafe string: %r' % (s,))
    node = nodes[0]
    if node.__class__ is not compiler.ast.Discard:
        raise ValueError('Invalid expression: %r' % (s,))
    node = node.getChildNodes()[0]
    containers = (compiler.ast.List, compiler.ast.Tuple, compiler.ast.Dict)
    def checkNode(node):
        cls = node.__class__
        if cls is compiler.ast.Const:
            return True
        if cls in containers:
            for child in node.getChildNodes():
                if not checkNode(child):
                    return False
            return True
        if cls is compiler.ast.Name:
            return node.name in namespace
        return False
    if not checkNode(node):
        raise ValueError('Unsafe string: %r' % (s,))
    # eval() adds a '__builtins__' entry to a globals dictionary lacking
    # one.  Evaluate against a copy, so that the entry never ends up in the
    # caller's (or the shared default) namespace, where it would make the
    # name '__builtins__' pass the check above on later calls.
    env = dict(namespace)
    return eval(s, env, env)
def exnToString(e):
    """Turns a simple exception instance into a string (better than str(e)).

    The result is the name of the exception's class, a colon and str(e).
    """
    name = e.__class__.__name__
    return '%s: %s' % (name, e)
2003-11-07 20:40:03 +01:00
class IterableMap(object):
    """Define .iteritems() in a class and subclass this to get the other
    iters, their list counterparts, len() and truth testing.
    """
    def iteritems(self):
        raise NotImplementedError

    def iterkeys(self):
        for (k, _) in self.iteritems():
            yield k
    __iter__ = iterkeys

    def itervalues(self):
        for (_, v) in self.iteritems():
            yield v

    def items(self):
        return list(self.iteritems())

    def keys(self):
        return list(self.iterkeys())

    def values(self):
        return list(self.itervalues())

    def __len__(self):
        return len(list(self.iteritems()))

    def __nonzero__(self):
        # True as soon as there is at least one item.
        for _ in self.iteritems():
            return True
        return False
def nonCommentLines(fd):
    """Yields the lines of fd that do not start with '#'."""
    for line in fd:
        if line.startswith('#'):
            continue
        yield line
def nonEmptyLines(fd):
    """Yields the lines of fd that contain more than just whitespace."""
    # line.strip() rather than str.strip(line), so that unicode lines work
    # as well.
    for line in fd:
        if line.strip():
            yield line
def nonCommentNonEmptyLines(fd):
    """Returns an iterator over the lines of fd that neither start with '#'
    nor consist of whitespace only.
    """
    uncommented = nonCommentLines(fd)
    return nonEmptyLines(uncommented)
def changeFunctionName(f, name, doc=None):
    """Returns a copy of the function f that is named <name>.

    The copy shares f's code, globals, defaults and closure.  Its docstring
    is doc, or f's own docstring if doc is None.
    """
    renamed = types.FunctionType(f.func_code, f.func_globals, name,
                                 f.func_defaults, f.func_closure)
    if doc is None:
        doc = f.__doc__
    renamed.__doc__ = doc
    return renamed
def getSocket(host):
    """Returns a socket of the correct AF_INET type (v4 or v6) in order to
    communicate with host.

    host may be a hostname or an IPv4 or IPv6 address.  Raises socket.error
    if it cannot be resolved.
    """
    # gethostbyname() only knows about IPv4, so an IPv6 address has to be
    # recognized before resolving.  Hostnames never contain a colon.
    if ':' in host and isIPV6(host):
        return socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
    host = socket.gethostbyname(host)
    if isIP(host):
        return socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    raise socket.error('Something wonky happened.')
2004-01-16 18:33:51 +01:00
def isIP(s):
    """Returns whether or not a given string is an IPV4 address.

    Whatever socket.inet_aton accepts counts as an address.

    >>> isIP('255.255.255.255')
    True

    >>> isIP('abc.abc.abc.abc')
    False
    """
    try:
        socket.inet_aton(s)
    except socket.error:
        return False
    return True
def _looksLikeIPV6(s):
    """Roughly checks the syntax of an IPV6 address without socket support."""
    if ':' not in s or s.count('::') > 1:
        return False
    groups = s.split(':')
    if len(groups) > 8:
        return False
    for group in groups:
        # Empty groups come from '::'; the others are up to 4 hex digits.
        if len(group) > 4:
            return False
        for char in group:
            if char not in string.hexdigits:
                return False
    return True

def isIPV6(s):
    """Returns whether or not a given string is an IPV6 address."""
    try:
        return bool(socket.inet_pton(socket.AF_INET6, s))
    except (socket.error, AttributeError):
        # Either s is invalid, or inet_pton is missing or unusable here.
        pass
    try:
        socket.inet_pton(socket.AF_INET6, '::')
    except (socket.error, AttributeError):
        # inet_pton cannot even parse '::'.  We gotta fake it.
        return _looksLikeIPV6(s)
    # inet_pton works, so s really is not an IPV6 address.
    return False
class InsensitivePreservingDict(UserDict.DictMixin, object):
    """Dictionary that looks keys up by a normalized form (lowercase by
    default) but remembers each key as it was last set.
    """
    # Normalizes a key for lookup; can be overridden per instance.
    key = staticmethod(str.lower)

    def __init__(self, dict=None, key=None):
        if key is not None:
            self.key = key
        # Maps each normalized key to an (original key, value) pair.
        self.data = {}
        if dict is not None:
            self.update(dict)

    def fromkeys(cls, keys, s=None, dict=None, key=None):
        d = cls(dict=dict, key=key)
        for k in keys:
            d[k] = s
        return d
    fromkeys = classmethod(fromkeys)

    def __getitem__(self, k):
        (_, value) = self.data[self.key(k)]
        return value

    def __setitem__(self, k, v):
        self.data[self.key(k)] = (k, v)

    def __delitem__(self, k):
        del self.data[self.key(k)]

    def iteritems(self):
        return self.data.itervalues()

    def keys(self):
        return [original for (original, _) in self.iteritems()]

    def __reduce__(self):
        contents = dict(self.data.values())
        return (self.__class__, (contents,))
class NormalizingSet(sets.Set):
    """Set that passes every element through normalize() before storing it
    or looking it up.  Subclasses override normalize(); by default elements
    are left as they are.
    """
    def __init__(self, iterable=()):
        normalized = [self.normalize(x) for x in iterable]
        super(NormalizingSet, self).__init__(normalized)

    def normalize(self, x):
        return x

    def add(self, x):
        x = self.normalize(x)
        return super(NormalizingSet, self).add(x)

    def remove(self, x):
        x = self.normalize(x)
        return super(NormalizingSet, self).remove(x)

    def discard(self, x):
        x = self.normalize(x)
        return super(NormalizingSet, self).discard(x)

    def __contains__(self, x):
        x = self.normalize(x)
        return super(NormalizingSet, self).__contains__(x)
    has_key = __contains__
2004-04-30 20:24:35 +02:00
def mungeEmailForWeb(s):
    """Returns s with every '@' replaced by ' AT ' and every '.' replaced by
    ' DOT '.
    """
    return s.replace('@', ' AT ').replace('.', ' DOT ')
2004-07-31 00:35:51 +02:00
def stackTrace():
    """Prints the current call stack, as traceback.print_stack does."""
    frame = sys._getframe()
    traceback.print_stack(frame)
2004-04-30 20:24:35 +02:00
2004-07-31 11:44:03 +02:00
class AtomicFile(file):
    """Used for files that need to be atomically written -- i.e., if there's a
    failure, the original file remains, unmodified.  mode must be 'w' or 'wb'

    Everything is written to a temporary file, which replaces <filename>
    when close() is called.  An empty temporary file only replaces an
    existing file if allowEmptyOverwrite is true.  rollback() discards what
    was written instead.
    """
    def __init__(self, filename, mode='w',
                 allowEmptyOverwrite=False, tmpDir=None):
        if mode not in ('w', 'wb'):
            raise ValueError('Invalid mode: %r' % (mode,))
        self.rolledback = False
        self.allowEmptyOverwrite = allowEmptyOverwrite
        self.filename = filename
        if tmpDir is None:
            # If not given a tmpDir, we'll just put a random token on the end
            # of our filename and put it in the same directory.
            self.tempFilename = '%s.%s' % (self.filename, mktemp())
        else:
            # If given a tmpDir, we'll get the basename (just the filename, no
            # directory), put our random token on the end, and put it in tmpDir
            tempFilename = '%s.%s' % (os.path.basename(self.filename),
                                      mktemp())
            self.tempFilename = os.path.join(tmpDir, tempFilename)
        super(AtomicFile, self).__init__(self.tempFilename, mode)

    def rollback(self):
        """Closes the file and removes the temporary file, if it exists."""
        if not self.closed:
            super(AtomicFile, self).close()
        if os.path.exists(self.tempFilename):
            os.remove(self.tempFilename)
        self.rolledback = True

    def close(self):
        """Closes the file and moves the temporary file into place.

        Does nothing if the file was already closed or rolled back.
        """
        if self.rolledback or self.closed:
            return
        super(AtomicFile, self).close()
        # We don't mind writing an empty file if the file we're overwriting
        # doesn't exist.
        size = os.path.getsize(self.tempFilename)
        originalExists = os.path.exists(self.filename)
        if size or self.allowEmptyOverwrite or not originalExists:
            # We use shutil.move here instead of os.rename because
            # the latter doesn't work on Windows when self.filename
            # (the target) already exists.  shutil.move handles those
            # intricacies for us.
            shutil.move(self.tempFilename, self.filename)
        else:
            # The original is kept, so get rid of the temporary file now.
            self.rollback()

    def __del__(self):
        # We rollback because if we're deleted without being explicitly closed,
        # that's bad.  We really should log this here, but as of yet we've got
        # no logging facility in utils.  I've got some ideas for this, though.
        # If __init__ failed early there is nothing to roll back.
        if hasattr(self, 'tempFilename'):
            self.rollback()
2004-07-31 11:44:03 +02:00
def transactionalFile(*args, **kwargs):
    """Returns an AtomicFile built from the given arguments."""
    # This exists so it can be replaced by a function that provides the tmpDir.
    # We do that replacement in conf.py.
    return AtomicFile(*args, **kwargs)
if __name__ == '__main__':
    # When run as a script, run the doctests found in this module's
    # docstrings.
    import doctest
    doctest.testmod(sys.modules['__main__'])
# vim:set shiftwidth=4 tabstop=8 expandtab textwidth=78: