mirror of https://github.com/arsenetar/dupeguru.git synced 2026-01-22 06:37:17 +00:00

Format files with black

- Format all files with black
- Update tox.ini flake8 arguments to be compatible
- Add black to requirements-extra.txt
- Reduce ignored flake8 rules and fix a few violations
2019-12-31 20:16:27 -06:00
parent 359d6498f7
commit 7ba8aa3514
141 changed files with 5241 additions and 3648 deletions
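
The tox.ini and requirements-extra.txt hunks are not reproduced in the excerpt below, so the following is only an illustrative sketch of what a black-compatible flake8 section usually looks like; the line length and ignore list are assumptions, not the commit's exact values:

[flake8]
# illustrative values; the actual tox.ini in this commit may differ
# black wraps lines at 88 columns and produces E203/W503-style whitespace,
# so those checks are typically relaxed once black owns formatting
max-line-length = 88
extend-ignore = E203, W503

Running black over the whole tree (for example `black .` from the repository root) is what produces the quote-style and line-wrapping changes visible in the file hunks that follow.
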

View File

@@ -26,7 +26,8 @@ import modulefinder
from setuptools import setup, Extension
from .plat import ISWINDOWS
from .util import modified_after, find_in_path, ensure_folder, delete_files_with_pattern
from .util import ensure_folder, delete_files_with_pattern
def print_and_do(cmd):
"""Prints ``cmd`` and executes it in the shell.
@@ -35,6 +36,7 @@ def print_and_do(cmd):
p = Popen(cmd, shell=True)
return p.wait()
def _perform(src, dst, action, actionname):
if not op.lexists(src):
print("Copying %s failed: it doesn't exist." % src)
@@ -44,26 +46,32 @@ def _perform(src, dst, action, actionname):
shutil.rmtree(dst)
else:
os.remove(dst)
print('%s %s --> %s' % (actionname, src, dst))
print("%s %s --> %s" % (actionname, src, dst))
action(src, dst)
def copy_file_or_folder(src, dst):
if op.isdir(src):
shutil.copytree(src, dst, symlinks=True)
else:
shutil.copy(src, dst)
def move(src, dst):
_perform(src, dst, os.rename, 'Moving')
_perform(src, dst, os.rename, "Moving")
def copy(src, dst):
_perform(src, dst, copy_file_or_folder, 'Copying')
_perform(src, dst, copy_file_or_folder, "Copying")
def symlink(src, dst):
_perform(src, dst, os.symlink, 'Symlinking')
_perform(src, dst, os.symlink, "Symlinking")
def hardlink(src, dst):
_perform(src, dst, os.link, 'Hardlinking')
_perform(src, dst, os.link, "Hardlinking")
def _perform_on_all(pattern, dst, action):
# pattern is a glob pattern, example "folder/foo*". The file is moved directly in dst, no folder
@@ -73,12 +81,15 @@ def _perform_on_all(pattern, dst, action):
destpath = op.join(dst, op.basename(fn))
action(fn, destpath)
def move_all(pattern, dst):
_perform_on_all(pattern, dst, move)
def copy_all(pattern, dst):
_perform_on_all(pattern, dst, copy)
def ensure_empty_folder(path):
"""Make sure that the path exists and that it's an empty folder.
"""
@@ -86,43 +97,54 @@ def ensure_empty_folder(path):
shutil.rmtree(path)
os.mkdir(path)
def filereplace(filename, outfilename=None, **kwargs):
"""Reads `filename`, replaces all {variables} in kwargs, and writes the result to `outfilename`.
"""
if outfilename is None:
outfilename = filename
fp = open(filename, 'rt', encoding='utf-8')
fp = open(filename, "rt", encoding="utf-8")
contents = fp.read()
fp.close()
# We can't use str.format() because in some files, there might be {} characters that mess with it.
for key, item in kwargs.items():
contents = contents.replace('{{{}}}'.format(key), item)
fp = open(outfilename, 'wt', encoding='utf-8')
contents = contents.replace("{{{}}}".format(key), item)
fp = open(outfilename, "wt", encoding="utf-8")
fp.write(contents)
fp.close()
def get_module_version(modulename):
mod = importlib.import_module(modulename)
return mod.__version__
def setup_package_argparser(parser):
parser.add_argument(
'--sign', dest='sign_identity',
help="Sign app under specified identity before packaging (OS X only)"
"--sign",
dest="sign_identity",
help="Sign app under specified identity before packaging (OS X only)",
)
parser.add_argument(
'--nosign', action='store_true', dest='nosign',
help="Don't sign the packaged app (OS X only)"
"--nosign",
action="store_true",
dest="nosign",
help="Don't sign the packaged app (OS X only)",
)
parser.add_argument(
'--src-pkg', action='store_true', dest='src_pkg',
help="Build a tar.gz of the current source."
"--src-pkg",
action="store_true",
dest="src_pkg",
help="Build a tar.gz of the current source.",
)
parser.add_argument(
'--arch-pkg', action='store_true', dest='arch_pkg',
help="Force Arch Linux packaging type, regardless of distro name."
"--arch-pkg",
action="store_true",
dest="arch_pkg",
help="Force Arch Linux packaging type, regardless of distro name.",
)
# `args` come from an ArgumentParser updated with setup_package_argparser()
def package_cocoa_app_in_dmg(app_path, destfolder, args):
# Rather than signing our app in XCode during the build phase, we sign it during the package
@@ -130,7 +152,9 @@ def package_cocoa_app_in_dmg(app_path, destfolder, args):
# a valid signature.
if args.sign_identity:
sign_identity = "Developer ID Application: {}".format(args.sign_identity)
result = print_and_do('codesign --force --deep --sign "{}" "{}"'.format(sign_identity, app_path))
result = print_and_do(
'codesign --force --deep --sign "{}" "{}"'.format(sign_identity, app_path)
)
if result != 0:
print("ERROR: Signing failed. Aborting packaging.")
return
@@ -139,23 +163,31 @@ def package_cocoa_app_in_dmg(app_path, destfolder, args):
return
build_dmg(app_path, destfolder)
def build_dmg(app_path, destfolder):
"""Builds a DMG volume with application at ``app_path`` and puts it in ``dest_path``.
The name of the resulting DMG volume is determined by the app's name and version.
"""
print(repr(op.join(app_path, 'Contents', 'Info.plist')))
plist = plistlib.readPlist(op.join(app_path, 'Contents', 'Info.plist'))
print(repr(op.join(app_path, "Contents", "Info.plist")))
plist = plistlib.readPlist(op.join(app_path, "Contents", "Info.plist"))
workpath = tempfile.mkdtemp()
dmgpath = op.join(workpath, plist['CFBundleName'])
dmgpath = op.join(workpath, plist["CFBundleName"])
os.mkdir(dmgpath)
print_and_do('cp -R "%s" "%s"' % (app_path, dmgpath))
print_and_do('ln -s /Applications "%s"' % op.join(dmgpath, 'Applications'))
dmgname = '%s_osx_%s.dmg' % (plist['CFBundleName'].lower().replace(' ', '_'), plist['CFBundleVersion'].replace('.', '_'))
print('Building %s' % dmgname)
print_and_do('ln -s /Applications "%s"' % op.join(dmgpath, "Applications"))
dmgname = "%s_osx_%s.dmg" % (
plist["CFBundleName"].lower().replace(" ", "_"),
plist["CFBundleVersion"].replace(".", "_"),
)
print("Building %s" % dmgname)
# UDBZ = bzip compression. UDZO (zip compression) was used before, but it compresses much less.
print_and_do('hdiutil create "%s" -format UDBZ -nocrossdev -srcdir "%s"' % (op.join(destfolder, dmgname), dmgpath))
print('Build Complete')
print_and_do(
'hdiutil create "%s" -format UDBZ -nocrossdev -srcdir "%s"'
% (op.join(destfolder, dmgname), dmgpath)
)
print("Build Complete")
def copy_sysconfig_files_for_embed(destpath):
# This normally shouldn't be needed for Python 3.3+.
@@ -163,24 +195,28 @@ def copy_sysconfig_files_for_embed(destpath):
configh = sysconfig.get_config_h_filename()
shutil.copy(makefile, destpath)
shutil.copy(configh, destpath)
with open(op.join(destpath, 'site.py'), 'w') as fp:
fp.write("""
with open(op.join(destpath, "site.py"), "w") as fp:
fp.write(
"""
import os.path as op
from distutils import sysconfig
sysconfig.get_makefile_filename = lambda: op.join(op.dirname(__file__), 'Makefile')
sysconfig.get_config_h_filename = lambda: op.join(op.dirname(__file__), 'pyconfig.h')
""")
"""
)
def add_to_pythonpath(path):
"""Adds ``path`` to both ``PYTHONPATH`` env and ``sys.path``.
"""
abspath = op.abspath(path)
pythonpath = os.environ.get('PYTHONPATH', '')
pathsep = ';' if ISWINDOWS else ':'
pythonpath = os.environ.get("PYTHONPATH", "")
pathsep = ";" if ISWINDOWS else ":"
pythonpath = pathsep.join([abspath, pythonpath]) if pythonpath else abspath
os.environ['PYTHONPATH'] = pythonpath
os.environ["PYTHONPATH"] = pythonpath
sys.path.insert(1, abspath)
# This is a method to hack around those freakingly tricky data inclusion/exlusion rules
# in setuptools. We copy the packages *without data* in a build folder and then build the plugin
# from there.
@@ -195,14 +231,16 @@ def copy_packages(packages_names, dest, create_links=False, extra_ignores=None):
create_links = False
if not extra_ignores:
extra_ignores = []
ignore = shutil.ignore_patterns('.hg*', 'tests', 'testdata', 'modules', 'docs', 'locale', *extra_ignores)
ignore = shutil.ignore_patterns(
".hg*", "tests", "testdata", "modules", "docs", "locale", *extra_ignores
)
for package_name in packages_names:
if op.exists(package_name):
source_path = package_name
else:
mod = __import__(package_name)
source_path = mod.__file__
if mod.__file__.endswith('__init__.py'):
if mod.__file__.endswith("__init__.py"):
source_path = op.dirname(source_path)
dest_name = op.basename(source_path)
dest_path = op.join(dest, dest_name)
@@ -220,58 +258,81 @@ def copy_packages(packages_names, dest, create_links=False, extra_ignores=None):
else:
shutil.copy(source_path, dest_path)
def copy_qt_plugins(folder_names, dest): # This is only for Windows
def copy_qt_plugins(folder_names, dest): # This is only for Windows
from PyQt5.QtCore import QLibraryInfo
qt_plugin_dir = QLibraryInfo.location(QLibraryInfo.PluginsPath)
def ignore(path, names):
if path == qt_plugin_dir:
return [n for n in names if n not in folder_names]
else:
return [n for n in names if not n.endswith('.dll')]
return [n for n in names if not n.endswith(".dll")]
shutil.copytree(qt_plugin_dir, dest, ignore=ignore)
def build_debian_changelog(changelogpath, destfile, pkgname, from_version=None,
distribution='precise', fix_version=None):
def build_debian_changelog(
changelogpath,
destfile,
pkgname,
from_version=None,
distribution="precise",
fix_version=None,
):
"""Builds a debian changelog out of a YAML changelog.
Use fix_version to patch the top changelog to that version (if, for example, there was a
packaging error and you need to quickly fix it)
"""
def desc2list(desc):
# We take each item, enumerated with the '*' character, and transform it into a list.
desc = desc.replace('\n', ' ')
desc = desc.replace(' ', ' ')
result = desc.split('*')
desc = desc.replace("\n", " ")
desc = desc.replace(" ", " ")
result = desc.split("*")
return [s.strip() for s in result if s.strip()]
ENTRY_MODEL = "{pkg} ({version}-1) {distribution}; urgency=low\n\n{changes}\n -- Virgil Dupras <hsoft@hardcoded.net> {date}\n\n"
ENTRY_MODEL = (
"{pkg} ({version}-1) {distribution}; urgency=low\n\n{changes}\n "
"-- Virgil Dupras <hsoft@hardcoded.net> {date}\n\n"
)
CHANGE_MODEL = " * {description}\n"
changelogs = read_changelog_file(changelogpath)
if from_version:
# We only want logs from a particular version
for index, log in enumerate(changelogs):
if log['version'] == from_version:
changelogs = changelogs[:index+1]
if log["version"] == from_version:
changelogs = changelogs[: index + 1]
break
if fix_version:
changelogs[0]['version'] = fix_version
changelogs[0]["version"] = fix_version
rendered_logs = []
for log in changelogs:
version = log['version']
logdate = log['date']
desc = log['description']
rendered_date = logdate.strftime('%a, %d %b %Y 00:00:00 +0000')
version = log["version"]
logdate = log["date"]
desc = log["description"]
rendered_date = logdate.strftime("%a, %d %b %Y 00:00:00 +0000")
rendered_descs = [CHANGE_MODEL.format(description=d) for d in desc2list(desc)]
changes = ''.join(rendered_descs)
rendered_log = ENTRY_MODEL.format(pkg=pkgname, version=version, changes=changes,
date=rendered_date, distribution=distribution)
changes = "".join(rendered_descs)
rendered_log = ENTRY_MODEL.format(
pkg=pkgname,
version=version,
changes=changes,
date=rendered_date,
distribution=distribution,
)
rendered_logs.append(rendered_log)
result = ''.join(rendered_logs)
fp = open(destfile, 'w')
result = "".join(rendered_logs)
fp = open(destfile, "w")
fp.write(result)
fp.close()
re_changelog_header = re.compile(r'=== ([\d.b]*) \(([\d\-]*)\)')
re_changelog_header = re.compile(r"=== ([\d.b]*) \(([\d\-]*)\)")
def read_changelog_file(filename):
def iter_by_three(it):
while True:
@@ -283,25 +344,31 @@ def read_changelog_file(filename):
return
yield version, date, description
with open(filename, 'rt', encoding='utf-8') as fp:
with open(filename, "rt", encoding="utf-8") as fp:
contents = fp.read()
splitted = re_changelog_header.split(contents)[1:] # the first item is empty
splitted = re_changelog_header.split(contents)[1:] # the first item is empty
# splitted = [version1, date1, desc1, version2, date2, ...]
result = []
for version, date_str, description in iter_by_three(iter(splitted)):
date = datetime.strptime(date_str, '%Y-%m-%d').date()
d = {'date': date, 'date_str': date_str, 'version': version, 'description': description.strip()}
date = datetime.strptime(date_str, "%Y-%m-%d").date()
d = {
"date": date,
"date_str": date_str,
"version": version,
"description": description.strip(),
}
result.append(d)
return result
class OSXAppStructure:
def __init__(self, dest):
self.dest = dest
self.contents = op.join(dest, 'Contents')
self.macos = op.join(self.contents, 'MacOS')
self.resources = op.join(self.contents, 'Resources')
self.frameworks = op.join(self.contents, 'Frameworks')
self.infoplist = op.join(self.contents, 'Info.plist')
self.contents = op.join(dest, "Contents")
self.macos = op.join(self.contents, "MacOS")
self.resources = op.join(self.contents, "Resources")
self.frameworks = op.join(self.contents, "Frameworks")
self.infoplist = op.join(self.contents, "Info.plist")
def create(self, infoplist):
ensure_empty_folder(self.dest)
@@ -309,11 +376,11 @@ class OSXAppStructure:
os.mkdir(self.resources)
os.mkdir(self.frameworks)
copy(infoplist, self.infoplist)
open(op.join(self.contents, 'PkgInfo'), 'wt').write("APPLxxxx")
open(op.join(self.contents, "PkgInfo"), "wt").write("APPLxxxx")
def copy_executable(self, executable):
info = plistlib.readPlist(self.infoplist)
self.executablename = info['CFBundleExecutable']
self.executablename = info["CFBundleExecutable"]
self.executablepath = op.join(self.macos, self.executablename)
copy(executable, self.executablepath)
@@ -329,8 +396,14 @@ class OSXAppStructure:
copy(path, framework_dest)
def create_osx_app_structure(dest, executable, infoplist, resources=None, frameworks=None,
symlink_resources=False):
def create_osx_app_structure(
dest,
executable,
infoplist,
resources=None,
frameworks=None,
symlink_resources=False,
):
# `dest`: A path to the destination .app folder
# `executable`: the path of the executable file that goes in "MacOS"
# `infoplist`: The path to your Info.plist file.
@@ -343,13 +416,14 @@ def create_osx_app_structure(dest, executable, infoplist, resources=None, framew
app.copy_resources(*resources, use_symlinks=symlink_resources)
app.copy_frameworks(*frameworks)
class OSXFrameworkStructure:
def __init__(self, dest):
self.dest = dest
self.contents = op.join(dest, 'Versions', 'A')
self.resources = op.join(self.contents, 'Resources')
self.headers = op.join(self.contents, 'Headers')
self.infoplist = op.join(self.resources, 'Info.plist')
self.contents = op.join(dest, "Versions", "A")
self.resources = op.join(self.contents, "Resources")
self.headers = op.join(self.contents, "Headers")
self.infoplist = op.join(self.resources, "Info.plist")
self._update_executable_path()
def _update_executable_path(self):
@@ -357,7 +431,7 @@ class OSXFrameworkStructure:
self.executablename = self.executablepath = None
return
info = plistlib.readPlist(self.infoplist)
self.executablename = info['CFBundleExecutable']
self.executablename = info["CFBundleExecutable"]
self.executablepath = op.join(self.contents, self.executablename)
def create(self, infoplist):
@@ -371,10 +445,10 @@ class OSXFrameworkStructure:
def create_symlinks(self):
# Only call this after create() and copy_executable()
rel = lambda path: op.relpath(path, self.dest)
os.symlink('A', op.join(self.dest, 'Versions', 'Current'))
os.symlink("A", op.join(self.dest, "Versions", "Current"))
os.symlink(rel(self.executablepath), op.join(self.dest, self.executablename))
os.symlink(rel(self.headers), op.join(self.dest, 'Headers'))
os.symlink(rel(self.resources), op.join(self.dest, 'Resources'))
os.symlink(rel(self.headers), op.join(self.dest, "Headers"))
os.symlink(rel(self.resources), op.join(self.dest, "Resources"))
def copy_executable(self, executable):
copy(executable, self.executablepath)
@@ -393,23 +467,28 @@ class OSXFrameworkStructure:
def copy_embeddable_python_dylib(dst):
runtime = op.join(sysconfig.get_config_var('PYTHONFRAMEWORKPREFIX'), sysconfig.get_config_var('LDLIBRARY'))
filedest = op.join(dst, 'Python')
runtime = op.join(
sysconfig.get_config_var("PYTHONFRAMEWORKPREFIX"),
sysconfig.get_config_var("LDLIBRARY"),
)
filedest = op.join(dst, "Python")
shutil.copy(runtime, filedest)
os.chmod(filedest, 0o774) # We need write permission to use install_name_tool
cmd = 'install_name_tool -id @rpath/Python %s' % filedest
os.chmod(filedest, 0o774) # We need write permission to use install_name_tool
cmd = "install_name_tool -id @rpath/Python %s" % filedest
print_and_do(cmd)
def collect_stdlib_dependencies(script, dest_folder, extra_deps=None):
sysprefix = sys.prefix # could be a virtualenv
real_lib_prefix = sysconfig.get_config_var('LIBDEST')
sysprefix = sys.prefix # could be a virtualenv
real_lib_prefix = sysconfig.get_config_var("LIBDEST")
def is_stdlib_path(path):
# A module path is only a stdlib path if it's in either sys.prefix or
# sysconfig.get_config_var('prefix') (the 2 are different if we are in a virtualenv) and if
# there's no "site-package in the path.
if not path:
return False
if 'site-package' in path:
if "site-package" in path:
return False
if not (path.startswith(sysprefix) or path.startswith(real_lib_prefix)):
return False
@@ -425,13 +504,17 @@ def collect_stdlib_dependencies(script, dest_folder, extra_deps=None):
relpath = op.relpath(p, real_lib_prefix)
elif p.startswith(sysprefix):
relpath = op.relpath(p, sysprefix)
assert relpath.startswith('lib/python3.') # we want to get rid of that lib/python3.x part
relpath = relpath[len('lib/python3.X/'):]
assert relpath.startswith(
"lib/python3."
) # we want to get rid of that lib/python3.x part
relpath = relpath[len("lib/python3.X/") :]
else:
raise AssertionError()
if relpath.startswith('lib-dynload'): # We copy .so files in lib-dynload directly in our dest
relpath = relpath[len('lib-dynload/'):]
if relpath.startswith('encodings') or relpath.startswith('distutils'):
if relpath.startswith(
"lib-dynload"
): # We copy .so files in lib-dynload directly in our dest
relpath = relpath[len("lib-dynload/") :]
if relpath.startswith("encodings") or relpath.startswith("distutils"):
# We force their inclusion later.
continue
dest_path = op.join(dest_folder, relpath)
@@ -440,34 +523,47 @@ def collect_stdlib_dependencies(script, dest_folder, extra_deps=None):
# stringprep is used by encodings.
# We use real_lib_prefix with distutils because virtualenv messes with it and we need to refer
# to the original distutils folder.
FORCED_INCLUSION = ['encodings', 'stringprep', op.join(real_lib_prefix, 'distutils')]
FORCED_INCLUSION = [
"encodings",
"stringprep",
op.join(real_lib_prefix, "distutils"),
]
if extra_deps:
FORCED_INCLUSION += extra_deps
copy_packages(FORCED_INCLUSION, dest_folder)
# There's a couple of rather big exe files in the distutils folder that we absolutely don't
# need. Remove them.
delete_files_with_pattern(op.join(dest_folder, 'distutils'), '*.exe')
delete_files_with_pattern(op.join(dest_folder, "distutils"), "*.exe")
# And, finally, create an empty "site.py" that Python needs around on startup.
open(op.join(dest_folder, 'site.py'), 'w').close()
open(op.join(dest_folder, "site.py"), "w").close()
def fix_qt_resource_file(path):
# pyrcc5 under Windows, if the locale is non-english, can produce a source file with a date
# containing accented characters. If it does, the encoding is wrong and it prevents the file
# from being correctly frozen by cx_freeze. To work around that, we open the file, strip all
# comments, and save.
with open(path, 'rb') as fp:
with open(path, "rb") as fp:
contents = fp.read()
lines = contents.split(b'\n')
lines = [l for l in lines if not l.startswith(b'#')]
with open(path, 'wb') as fp:
fp.write(b'\n'.join(lines))
lines = contents.split(b"\n")
lines = [l for l in lines if not l.startswith(b"#")]
with open(path, "wb") as fp:
fp.write(b"\n".join(lines))
def build_cocoa_ext(extname, dest, source_files, extra_frameworks=(), extra_includes=()):
def build_cocoa_ext(
extname, dest, source_files, extra_frameworks=(), extra_includes=()
):
extra_link_args = ["-framework", "CoreFoundation", "-framework", "Foundation"]
for extra in extra_frameworks:
extra_link_args += ['-framework', extra]
ext = Extension(extname, source_files, extra_link_args=extra_link_args, include_dirs=extra_includes)
setup(script_args=['build_ext', '--inplace'], ext_modules=[ext])
extra_link_args += ["-framework", extra]
ext = Extension(
extname,
source_files,
extra_link_args=extra_link_args,
include_dirs=extra_includes,
)
setup(script_args=["build_ext", "--inplace"], ext_modules=[ext])
# Our problem here is to get the fully qualified filename of the resulting .so but I couldn't
# find a documented way to do so. The only thing I could find is this below :(
fn = ext._file_name

View File

@@ -8,26 +8,24 @@ import argparse
from setuptools import setup, Extension
def get_parser():
parser = argparse.ArgumentParser(description="Build an arbitrary Python extension.")
parser.add_argument(
'source_files', nargs='+',
help="List of source files to compile"
)
parser.add_argument(
'name', nargs=1,
help="Name of the resulting extension"
"source_files", nargs="+", help="List of source files to compile"
)
parser.add_argument("name", nargs=1, help="Name of the resulting extension")
return parser
def main():
args = get_parser().parse_args()
print("Building {}...".format(args.name[0]))
ext = Extension(args.name[0], args.source_files)
setup(
script_args=['build_ext', '--inplace'],
ext_modules=[ext],
script_args=["build_ext", "--inplace"], ext_modules=[ext],
)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@@ -2,8 +2,8 @@
# Created On: 2008-01-08
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
"""When you have to deal with names that have to be unique and can conflict together, you can use
@@ -16,14 +16,15 @@ import shutil
from .path import Path, pathify
#This matches [123], but not [12] (3 digits being the minimum).
#It also matches [1234] [12345] etc..
#And only at the start of the string
re_conflict = re.compile(r'^\[\d{3}\d*\] ')
# This matches [123], but not [12] (3 digits being the minimum).
# It also matches [1234] [12345] etc..
# And only at the start of the string
re_conflict = re.compile(r"^\[\d{3}\d*\] ")
def get_conflicted_name(other_names, name):
"""Returns name with a ``[000]`` number in front of it.
The number between brackets depends on how many conlicted filenames
there already are in other_names.
"""
@@ -32,23 +33,26 @@ def get_conflicted_name(other_names, name):
return name
i = 0
while True:
newname = '[%03d] %s' % (i, name)
newname = "[%03d] %s" % (i, name)
if newname not in other_names:
return newname
i += 1
def get_unconflicted_name(name):
"""Returns ``name`` without ``[]`` brackets.
Brackets which, of course, might have been added by func:`get_conflicted_name`.
"""
return re_conflict.sub('',name,1)
return re_conflict.sub("", name, 1)
def is_conflicted(name):
"""Returns whether ``name`` is prepended with a bracketed number.
"""
return re_conflict.match(name) is not None
@pathify
def _smart_move_or_copy(operation, source_path: Path, dest_path: Path):
"""Use move() or copy() to move and copy file with the conflict management.
@@ -61,19 +65,24 @@ def _smart_move_or_copy(operation, source_path: Path, dest_path: Path):
newname = get_conflicted_name(os.listdir(str(dest_dir_path)), filename)
dest_path = dest_dir_path[newname]
operation(str(source_path), str(dest_path))
def smart_move(source_path, dest_path):
"""Same as :func:`smart_copy`, but it moves files instead.
"""
_smart_move_or_copy(shutil.move, source_path, dest_path)
def smart_copy(source_path, dest_path):
"""Copies ``source_path`` to ``dest_path``, recursively and with conflict resolution.
"""
try:
_smart_move_or_copy(shutil.copy, source_path, dest_path)
except IOError as e:
if e.errno in {21, 13}: # it's a directory, code is 21 on OS X / Linux and 13 on Windows
if e.errno in {
21,
13,
}: # it's a directory, code is 21 on OS X / Linux and 13 on Windows
_smart_move_or_copy(shutil.copytree, source_path, dest_path)
else:
raise
raise

View File

@@ -1,14 +1,15 @@
# Created By: Virgil Dupras
# Created On: 2011-04-19
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import sys
import traceback
# Taken from http://bzimmer.ziclix.com/2008/12/17/python-thread-dumps/
def stacktraces():
code = []
@@ -18,5 +19,5 @@ def stacktraces():
code.append('File: "%s", line %d, in %s' % (filename, lineno, name))
if line:
code.append(" %s" % (line.strip()))
return "\n".join(code)
return "\n".join(code)

View File

@@ -9,25 +9,30 @@
import os.path as op
import logging
class SpecialFolder:
AppData = 1
Cache = 2
def open_url(url):
"""Open ``url`` with the default browser.
"""
_open_url(url)
def open_path(path):
"""Open ``path`` with its associated application.
"""
_open_path(str(path))
def reveal_path(path):
"""Open the folder containing ``path`` with the default file browser.
"""
_reveal_path(str(path))
def special_folder_path(special_folder, appname=None):
"""Returns the path of ``special_folder``.
@@ -38,12 +43,14 @@ def special_folder_path(special_folder, appname=None):
"""
return _special_folder_path(special_folder, appname)
try:
# Normally, we would simply do "from cocoa import proxy", but due to a bug in pytest (currently
# at v2.4.2), our test suite is broken when we do that. This below is a workaround until that
# bug is fixed.
import cocoa
if not hasattr(cocoa, 'proxy'):
if not hasattr(cocoa, "proxy"):
raise ImportError()
proxy = cocoa.proxy
_open_url = proxy.openURL_
@@ -56,13 +63,15 @@ try:
else:
base = proxy.getAppdataPath()
if not appname:
appname = proxy.bundleInfo_('CFBundleName')
appname = proxy.bundleInfo_("CFBundleName")
return op.join(base, appname)
except ImportError:
try:
from PyQt5.QtCore import QUrl, QStandardPaths
from PyQt5.QtGui import QDesktopServices
def _open_url(url):
QDesktopServices.openUrl(QUrl(url))
@@ -79,10 +88,12 @@ except ImportError:
else:
qtfolder = QStandardPaths.DataLocation
return QStandardPaths.standardLocations(qtfolder)[0]
except ImportError:
# We're either running tests, and these functions don't matter much or we're in a really
# weird situation. Let's just have dummy fallbacks.
logging.warning("Can't setup desktop functions!")
def _open_path(path):
pass
@@ -90,4 +101,4 @@ except ImportError:
pass
def _special_folder_path(special_folder, appname=None):
return '/tmp'
return "/tmp"

View File

@@ -1,9 +1,9 @@
# Created By: Virgil Dupras
# Created On: 2011-08-05
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from sys import maxsize as INF
@@ -11,73 +11,74 @@ from math import sqrt
VERY_SMALL = 0.0000001
class Point:
def __init__(self, x, y):
self.x = x
self.y = y
def __repr__(self):
return '<Point {:2.2f}, {:2.2f}>'.format(*self)
return "<Point {:2.2f}, {:2.2f}>".format(*self)
def __iter__(self):
yield self.x
yield self.y
def distance_to(self, other):
return Line(self, other).length()
class Line:
def __init__(self, p1, p2):
self.p1 = p1
self.p2 = p2
def __repr__(self):
return '<Line {}, {}>'.format(*self)
return "<Line {}, {}>".format(*self)
def __iter__(self):
yield self.p1
yield self.p2
def dx(self):
return self.p2.x - self.p1.x
def dy(self):
return self.p2.y - self.p1.y
def length(self):
return sqrt(self.dx() ** 2 + self.dy() ** 2)
def slope(self):
if self.dx() == 0:
return INF if self.dy() > 0 else -INF
else:
return self.dy() / self.dx()
def intersection_point(self, other):
# with help from http://paulbourke.net/geometry/lineline2d/
if abs(self.slope() - other.slope()) < VERY_SMALL:
# parallel. Even if coincident, we return nothing
return None
A, B = self
C, D = other
denom = (D.y-C.y) * (B.x-A.x) - (D.x-C.x) * (B.y-A.y)
denom = (D.y - C.y) * (B.x - A.x) - (D.x - C.x) * (B.y - A.y)
if denom == 0:
return None
numera = (D.x-C.x) * (A.y-C.y) - (D.y-C.y) * (A.x-C.x)
numerb = (B.x-A.x) * (A.y-C.y) - (B.y-A.y) * (A.x-C.x)
mua = numera / denom;
mub = numerb / denom;
numera = (D.x - C.x) * (A.y - C.y) - (D.y - C.y) * (A.x - C.x)
numerb = (B.x - A.x) * (A.y - C.y) - (B.y - A.y) * (A.x - C.x)
mua = numera / denom
mub = numerb / denom
if (0 <= mua <= 1) and (0 <= mub <= 1):
x = A.x + mua * (B.x - A.x)
y = A.y + mua * (B.y - A.y)
return Point(x, y)
else:
return None
class Rect:
def __init__(self, x, y, w, h):
@@ -85,43 +86,43 @@ class Rect:
self.y = y
self.w = w
self.h = h
def __iter__(self):
yield self.x
yield self.y
yield self.w
yield self.h
def __repr__(self):
return '<Rect {:2.2f}, {:2.2f}, {:2.2f}, {:2.2f}>'.format(*self)
return "<Rect {:2.2f}, {:2.2f}, {:2.2f}, {:2.2f}>".format(*self)
@classmethod
def from_center(cls, center, width, height):
x = center.x - width / 2
y = center.y - height / 2
return cls(x, y, width, height)
@classmethod
def from_corners(cls, pt1, pt2):
x1, y1 = pt1
x2, y2 = pt2
return cls(min(x1, x2), min(y1, y2), abs(x1-x2), abs(y1-y2))
return cls(min(x1, x2), min(y1, y2), abs(x1 - x2), abs(y1 - y2))
def center(self):
return Point(self.x + self.w/2, self.y + self.h/2)
return Point(self.x + self.w / 2, self.y + self.h / 2)
def contains_point(self, point):
x, y = point
(x1, y1), (x2, y2) = self.corners()
return (x1 <= x <= x2) and (y1 <= y <= y2)
def contains_rect(self, rect):
pt1, pt2 = rect.corners()
return self.contains_point(pt1) and self.contains_point(pt2)
def corners(self):
return Point(self.x, self.y), Point(self.x+self.w, self.y+self.h)
return Point(self.x, self.y), Point(self.x + self.w, self.y + self.h)
def intersects(self, other):
r1pt1, r1pt2 = self.corners()
r2pt1, r2pt2 = other.corners()
@@ -136,7 +137,7 @@ class Rect:
else:
yinter = r2pt2.y >= r1pt1.y
return yinter
def lines(self):
pt1, pt4 = self.corners()
pt2 = Point(pt4.x, pt1.y)
@@ -146,7 +147,7 @@ class Rect:
l3 = Line(pt3, pt4)
l4 = Line(pt1, pt3)
return l1, l2, l3, l4
def scaled_rect(self, dx, dy):
"""Returns a rect that has the same borders at self, but grown/shrunk by dx/dy on each side.
"""
@@ -156,7 +157,7 @@ class Rect:
w += dx * 2
h += dy * 2
return Rect(x, y, w, h)
def united(self, other):
"""Returns the bounding rectangle of this rectangle and `other`.
"""
@@ -166,53 +167,52 @@ class Rect:
corner1 = Point(min(ulcorner1.x, ulcorner2.x), min(ulcorner1.y, ulcorner2.y))
corner2 = Point(max(lrcorner1.x, lrcorner2.x), max(lrcorner1.y, lrcorner2.y))
return Rect.from_corners(corner1, corner2)
#--- Properties
# --- Properties
@property
def top(self):
return self.y
@top.setter
def top(self, value):
self.y = value
@property
def bottom(self):
return self.y + self.h
@bottom.setter
def bottom(self, value):
self.y = value - self.h
@property
def left(self):
return self.x
@left.setter
def left(self, value):
self.x = value
@property
def right(self):
return self.x + self.w
@right.setter
def right(self, value):
self.x = value - self.w
@property
def width(self):
return self.w
@width.setter
def width(self, value):
self.w = value
@property
def height(self):
return self.h
@height.setter
def height(self, value):
self.h = value

View File

@@ -4,13 +4,16 @@
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
def noop(*args, **kwargs):
pass
class NoopGUI:
def __getattr__(self, func_name):
return noop
class GUIObject:
"""Cross-toolkit "model" representation of a GUI layer object.
@@ -32,6 +35,7 @@ class GUIObject:
However, sometimes you want to be able to re-bind another view. In this case, set the
``multibind`` flag to ``True`` and the safeguard will be disabled.
"""
def __init__(self, multibind=False):
self._view = None
self._multibind = multibind
@@ -77,4 +81,3 @@ class GUIObject:
# Instead of None, we put a NoopGUI() there to avoid rogue view callback raising an
# exception.
self._view = NoopGUI()

View File

@@ -1,21 +1,23 @@
# Created By: Virgil Dupras
# Created On: 2010-07-25
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import copy
from .base import GUIObject
class Column:
"""Holds column attributes such as its name, width, visibility, etc.
These attributes are then used to correctly configure the column on the "view" side.
"""
def __init__(self, name, display='', visible=True, optional=False):
def __init__(self, name, display="", visible=True, optional=False):
#: "programmatical" (not for display) name. Used as a reference in a couple of place, such
#: as :meth:`Columns.column_by_name`.
self.name = name
@@ -39,52 +41,57 @@ class Column:
self.default_visible = visible
#: Whether the column can have :attr:`visible` set to false.
self.optional = optional
class ColumnsView:
"""Expected interface for :class:`Columns`'s view.
*Not actually used in the code. For documentation purposes only.*
Our view, the columns controller of a table or outline, is expected to properly respond to
callbacks.
"""
def restore_columns(self):
"""Update all columns according to the model.
When this is called, our view has to update the columns title, order and visibility of all
columns.
"""
def set_column_visible(self, colname, visible):
"""Update visibility of column ``colname``.
Called when the user toggles the visibility of a column, we must update the column
``colname``'s visibility status to ``visible``.
"""
class PrefAccessInterface:
"""Expected interface for :class:`Columns`'s prefaccess.
*Not actually used in the code. For documentation purposes only.*
"""
def get_default(self, key, fallback_value):
"""Retrieve the value for ``key`` in the currently running app's preference store.
If the key doesn't exist, return ``fallback_value``.
"""
def set_default(self, key, value):
"""Set the value ``value`` for ``key`` in the currently running app's preference store.
"""
class Columns(GUIObject):
"""Cross-toolkit GUI-enabled column set for tables or outlines.
Manages a column set's order, visibility and width. We also manage the persistence of these
attributes so that we can restore them on the next run.
Subclasses :class:`.GUIObject`. Expected view: :class:`ColumnsView`.
:param table: The table the columns belong to. It's from there that we retrieve our column
configuration and it must have a ``COLUMNS`` attribute which is a list of
:class:`Column`. We also call :meth:`~.GUITable.save_edits` on it from time to
@@ -97,6 +104,7 @@ class Columns(GUIObject):
a prefix. Preferences are saved under more than one name, but they will all
have that same prefix.
"""
def __init__(self, table, prefaccess=None, savename=None):
GUIObject.__init__(self)
self.table = table
@@ -108,84 +116,88 @@ class Columns(GUIObject):
column.logical_index = i
column.ordered_index = i
self.coldata = {col.name: col for col in self.column_list}
#--- Private
# --- Private
def _get_colname_attr(self, colname, attrname, default):
try:
return getattr(self.coldata[colname], attrname)
except KeyError:
return default
def _set_colname_attr(self, colname, attrname, value):
try:
col = self.coldata[colname]
setattr(col, attrname, value)
except KeyError:
pass
def _optional_columns(self):
return [c for c in self.column_list if c.optional]
#--- Override
# --- Override
def _view_updated(self):
self.restore_columns()
#--- Public
# --- Public
def column_by_index(self, index):
"""Return the :class:`Column` having the :attr:`~Column.logical_index` ``index``.
"""
return self.column_list[index]
def column_by_name(self, name):
"""Return the :class:`Column` having the :attr:`~Column.name` ``name``.
"""
return self.coldata[name]
def columns_count(self):
"""Returns the number of columns in our set.
"""
return len(self.column_list)
def column_display(self, colname):
"""Returns display name for column named ``colname``, or ``''`` if there's none.
"""
return self._get_colname_attr(colname, 'display', '')
return self._get_colname_attr(colname, "display", "")
def column_is_visible(self, colname):
"""Returns visibility for column named ``colname``, or ``True`` if there's none.
"""
return self._get_colname_attr(colname, 'visible', True)
return self._get_colname_attr(colname, "visible", True)
def column_width(self, colname):
"""Returns width for column named ``colname``, or ``0`` if there's none.
"""
return self._get_colname_attr(colname, 'width', 0)
return self._get_colname_attr(colname, "width", 0)
def columns_to_right(self, colname):
"""Returns the list of all columns to the right of ``colname``.
"right" meaning "having a higher :attr:`Column.ordered_index`" in our left-to-right
civilization.
"""
column = self.coldata[colname]
index = column.ordered_index
return [col.name for col in self.column_list if (col.visible and col.ordered_index > index)]
return [
col.name
for col in self.column_list
if (col.visible and col.ordered_index > index)
]
def menu_items(self):
"""Returns a list of items convenient for quick visibility menu generation.
Returns a list of ``(display_name, is_marked)`` items for each optional column in the
current view (``is_marked`` means that it's visible).
You can use this to generate a menu to let the user toggle the visibility of an optional
column. That is why we only show optional column, because the visibility of mandatory
columns can't be toggled.
"""
return [(c.display, c.visible) for c in self._optional_columns()]
def move_column(self, colname, index):
"""Moves column ``colname`` to ``index``.
The column will be placed just in front of the column currently having that index, or to the
end of the list if there's none.
"""
@@ -193,7 +205,7 @@ class Columns(GUIObject):
colnames.remove(colname)
colnames.insert(index, colname)
self.set_column_order(colnames)
def reset_to_defaults(self):
"""Reset all columns' width and visibility to their default values.
"""
@@ -202,12 +214,12 @@ class Columns(GUIObject):
col.visible = col.default_visible
col.width = col.default_width
self.view.restore_columns()
def resize_column(self, colname, newwidth):
"""Set column ``colname``'s width to ``newwidth``.
"""
self._set_colname_attr(colname, 'width', newwidth)
self._set_colname_attr(colname, "width", newwidth)
def restore_columns(self):
"""Restore's column persistent attributes from the last :meth:`save_columns`.
"""
@@ -218,72 +230,73 @@ class Columns(GUIObject):
self.view.restore_columns()
return
for col in self.column_list:
pref_name = '{}.Columns.{}'.format(self.savename, col.name)
pref_name = "{}.Columns.{}".format(self.savename, col.name)
coldata = self.prefaccess.get_default(pref_name, fallback_value={})
if 'index' in coldata:
col.ordered_index = coldata['index']
if 'width' in coldata:
col.width = coldata['width']
if col.optional and 'visible' in coldata:
col.visible = coldata['visible']
if "index" in coldata:
col.ordered_index = coldata["index"]
if "width" in coldata:
col.width = coldata["width"]
if col.optional and "visible" in coldata:
col.visible = coldata["visible"]
self.view.restore_columns()
def save_columns(self):
"""Save column attributes in persistent storage for restoration in :meth:`restore_columns`.
"""
if not (self.prefaccess and self.savename and self.coldata):
return
for col in self.column_list:
pref_name = '{}.Columns.{}'.format(self.savename, col.name)
coldata = {'index': col.ordered_index, 'width': col.width}
pref_name = "{}.Columns.{}".format(self.savename, col.name)
coldata = {"index": col.ordered_index, "width": col.width}
if col.optional:
coldata['visible'] = col.visible
coldata["visible"] = col.visible
self.prefaccess.set_default(pref_name, coldata)
def set_column_order(self, colnames):
"""Change the columns order so it matches the order in ``colnames``.
:param colnames: A list of column names in the desired order.
"""
colnames = (name for name in colnames if name in self.coldata)
for i, colname in enumerate(colnames):
col = self.coldata[colname]
col.ordered_index = i
def set_column_visible(self, colname, visible):
"""Set the visibility of column ``colname``.
"""
self.table.save_edits() # the table on the GUI side will stop editing when the columns change
self._set_colname_attr(colname, 'visible', visible)
self.table.save_edits() # the table on the GUI side will stop editing when the columns change
self._set_colname_attr(colname, "visible", visible)
self.view.set_column_visible(colname, visible)
def set_default_width(self, colname, width):
"""Set the default width or column ``colname``.
"""
self._set_colname_attr(colname, 'default_width', width)
self._set_colname_attr(colname, "default_width", width)
def toggle_menu_item(self, index):
"""Toggles the visibility of an optional column.
You know, that optional column menu you've generated in :meth:`menu_items`? Well, ``index``
is the index of them menu item in *that* menu that the user has clicked on to toggle it.
Returns whether the column in question ends up being visible or not.
"""
col = self._optional_columns()[index]
self.set_column_visible(col.name, not col.visible)
return col.visible
#--- Properties
# --- Properties
@property
def ordered_columns(self):
"""List of :class:`Column` in visible order.
"""
return [col for col in sorted(self.column_list, key=lambda col: col.ordered_index)]
return [
col for col in sorted(self.column_list, key=lambda col: col.ordered_index)
]
@property
def colnames(self):
"""List of column names in visible order.
"""
return [col.name for col in self.ordered_columns]

View File

@@ -8,6 +8,7 @@ from ..jobprogress.performer import ThreadedJobPerformer
from .base import GUIObject
from .text_field import TextField
class ProgressWindowView:
"""Expected interface for :class:`ProgressWindow`'s view.
@@ -18,6 +19,7 @@ class ProgressWindowView:
It's also expected to call :meth:`ProgressWindow.cancel` when the cancel button is clicked.
"""
def show(self):
"""Show the dialog.
"""
@@ -36,6 +38,7 @@ class ProgressWindowView:
:param int progress: a value between ``0`` and ``100``.
"""
class ProgressWindow(GUIObject, ThreadedJobPerformer):
"""Cross-toolkit GUI-enabled progress window.
@@ -58,6 +61,7 @@ class ProgressWindow(GUIObject, ThreadedJobPerformer):
if you want to. If the function returns ``True``, ``finish_func()`` will be
called as if the job terminated normally.
"""
def __init__(self, finish_func, error_func=None):
# finish_func(jobid) is the function that is called when a job is completed.
GUIObject.__init__(self)
@@ -124,10 +128,9 @@ class ProgressWindow(GUIObject, ThreadedJobPerformer):
# target is a function with its first argument being a Job. It can then be followed by other
# arguments which are passed as `args`.
self.jobid = jobid
self.progressdesc_textfield.text = ''
self.progressdesc_textfield.text = ""
j = self.create_job()
args = tuple([j] + list(args))
self.run_threaded(target, args)
self.jobdesc_textfield.text = title
self.view.show()

View File

@@ -1,92 +1,96 @@
# Created By: Virgil Dupras
# Created On: 2011-09-06
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from collections import Sequence, MutableSequence
from .base import GUIObject
class Selectable(Sequence):
"""Mix-in for a ``Sequence`` that manages its selection status.
When mixed in with a ``Sequence``, we enable it to manage its selection status. The selection
is held as a list of ``int`` indexes. Multiple selection is supported.
"""
def __init__(self):
self._selected_indexes = []
#--- Private
# --- Private
def _check_selection_range(self):
if not self:
self._selected_indexes = []
if not self._selected_indexes:
return
self._selected_indexes = [index for index in self._selected_indexes if index < len(self)]
self._selected_indexes = [
index for index in self._selected_indexes if index < len(self)
]
if not self._selected_indexes:
self._selected_indexes = [len(self) - 1]
#--- Virtual
# --- Virtual
def _update_selection(self):
"""(Virtual) Updates the model's selection appropriately.
Called after selection has been updated. Takes the table's selection and does appropriates
updates on the view and/or model. Common sense would dictate that when the selection doesn't
change, we don't update anything (and thus don't call ``_update_selection()`` at all), but
there are cases where it's false. For example, if our list updates its items but doesn't
change its selection, we probably want to update the model's selection.
By default, does nothing.
Important note: This is only called on :meth:`select`, not on changes to
:attr:`selected_indexes`.
"""
# A redesign of how this whole thing works is probably in order, but not now, there's too
# much breakage at once involved.
#--- Public
# --- Public
def select(self, indexes):
"""Update selection to ``indexes``.
:meth:`_update_selection` is called afterwards.
:param list indexes: List of ``int`` that is to become the new selection.
"""
if isinstance(indexes, int):
indexes = [indexes]
self.selected_indexes = indexes
self._update_selection()
#--- Properties
# --- Properties
@property
def selected_index(self):
"""Points to the first selected index.
*int*. *get/set*.
*int*. *get/set*.
Thin wrapper around :attr:`selected_indexes`. ``None`` if selection is empty. Using this
property only makes sense if your selectable sequence supports single selection only.
"""
return self._selected_indexes[0] if self._selected_indexes else None
@selected_index.setter
def selected_index(self, value):
self.selected_indexes = [value]
@property
def selected_indexes(self):
"""List of selected indexes.
*list of int*. *get/set*.
When setting the value, automatically removes out-of-bounds indexes. The list is kept
sorted.
"""
return self._selected_indexes
@selected_indexes.setter
def selected_indexes(self, value):
self._selected_indexes = value
@@ -96,53 +100,54 @@ class Selectable(Sequence):
class SelectableList(MutableSequence, Selectable):
"""A list that can manage selection of its items.
Subclasses :class:`Selectable`. Behaves like a ``list``.
"""
def __init__(self, items=None):
Selectable.__init__(self)
if items:
self._items = list(items)
else:
self._items = []
def __delitem__(self, key):
self._items.__delitem__(key)
self._check_selection_range()
self._on_change()
def __getitem__(self, key):
return self._items.__getitem__(key)
def __len__(self):
return len(self._items)
def __setitem__(self, key, value):
self._items.__setitem__(key, value)
self._on_change()
#--- Override
# --- Override
def append(self, item):
self._items.append(item)
self._on_change()
def insert(self, index, item):
self._items.insert(index, item)
self._on_change()
def remove(self, row):
self._items.remove(row)
self._check_selection_range()
self._on_change()
#--- Virtual
# --- Virtual
def _on_change(self):
"""(Virtual) Called whenever the contents of the list changes.
By default, does nothing.
"""
#--- Public
# --- Public
def search_by_prefix(self, prefix):
# XXX Why the heck is this method here?
prefix = prefix.lower()
@@ -150,59 +155,62 @@ class SelectableList(MutableSequence, Selectable):
if s.lower().startswith(prefix):
return index
return -1
class GUISelectableListView:
"""Expected interface for :class:`GUISelectableList`'s view.
*Not actually used in the code. For documentation purposes only.*
Our view, some kind of list view or combobox, is expected to sync with the list's contents by
appropriately behave to all callbacks in this interface.
"""
def refresh(self):
"""Refreshes the contents of the list widget.
Ensures that the contents of the list widget is synced with the model.
"""
def update_selection(self):
"""Update selection status.
Ensures that the list widget's selection is in sync with the model.
"""
class GUISelectableList(SelectableList, GUIObject):
"""Cross-toolkit GUI-enabled list view.
Represents a UI element presenting the user with a selectable list of items.
Subclasses :class:`SelectableList` and :class:`.GUIObject`. Expected view:
:class:`GUISelectableListView`.
:param iterable items: If specified, items to fill the list with initially.
"""
def __init__(self, items=None):
SelectableList.__init__(self, items)
GUIObject.__init__(self)
def _view_updated(self):
"""Refreshes the view contents with :meth:`GUISelectableListView.refresh`.
Overrides :meth:`~hscommon.gui.base.GUIObject._view_updated`.
"""
self.view.refresh()
def _update_selection(self):
"""Refreshes the view selection with :meth:`GUISelectableListView.update_selection`.
Overrides :meth:`Selectable._update_selection`.
"""
self.view.update_selection()
def _on_change(self):
"""Refreshes the view contents with :meth:`GUISelectableListView.refresh`.
Overrides :meth:`SelectableList._on_change`.
"""
self.view.refresh()

View File

@@ -11,6 +11,7 @@ from collections import MutableSequence, namedtuple
from .base import GUIObject
from .selectable_list import Selectable
# We used to directly subclass list, but it caused problems at some point with deepcopy
class Table(MutableSequence, Selectable):
"""Sortable and selectable sequence of :class:`Row`.
@@ -24,6 +25,7 @@ class Table(MutableSequence, Selectable):
Subclasses :class:`.Selectable`.
"""
def __init__(self):
Selectable.__init__(self)
self._rows = []
@@ -101,7 +103,7 @@ class Table(MutableSequence, Selectable):
if self._footer is not None:
self._rows.append(self._footer)
#--- Properties
# --- Properties
@property
def footer(self):
"""If set, a row that always stay at the bottom of the table.
@@ -216,6 +218,7 @@ class GUITableView:
Whenever the user changes the selection, we expect the view to call :meth:`Table.select`.
"""
def refresh(self):
"""Refreshes the contents of the table widget.
@@ -238,7 +241,9 @@ class GUITableView:
"""
SortDescriptor = namedtuple('SortDescriptor', 'column desc')
SortDescriptor = namedtuple("SortDescriptor", "column desc")
class GUITable(Table, GUIObject):
"""Cross-toolkit GUI-enabled table view.
@@ -254,6 +259,7 @@ class GUITable(Table, GUIObject):
Subclasses :class:`Table` and :class:`.GUIObject`. Expected view:
:class:`GUITableView`.
"""
def __init__(self):
GUIObject.__init__(self)
Table.__init__(self)
@@ -261,7 +267,7 @@ class GUITable(Table, GUIObject):
self.edited = None
self._sort_descriptor = None
#--- Virtual
# --- Virtual
def _do_add(self):
"""(Virtual) Creates a new row, adds it in the table.
@@ -309,7 +315,7 @@ class GUITable(Table, GUIObject):
else:
self.select([len(self) - 1])
#--- Public
# --- Public
def add(self):
"""Add a new row in edit mode.
@@ -444,6 +450,7 @@ class Row:
Of course, this is only default behavior. This can be overriden.
"""
def __init__(self, table):
super(Row, self).__init__()
self.table = table
@@ -454,7 +461,7 @@ class Row:
assert self.table.edited is None
self.table.edited = self
#--- Virtual
# --- Virtual
def can_edit(self):
"""(Virtual) Whether the whole row can be edited.
@@ -489,11 +496,11 @@ class Row:
there's none, raises ``AttributeError``.
"""
try:
return getattr(self, '_' + column_name)
return getattr(self, "_" + column_name)
except AttributeError:
return getattr(self, column_name)
#--- Public
# --- Public
def can_edit_cell(self, column_name):
"""Returns whether cell for column ``column_name`` can be edited.
@@ -511,18 +518,18 @@ class Row:
return False
# '_' is in case column is a python keyword
if not hasattr(self, column_name):
if hasattr(self, column_name + '_'):
column_name = column_name + '_'
if hasattr(self, column_name + "_"):
column_name = column_name + "_"
else:
return False
if hasattr(self, 'can_edit_' + column_name):
return getattr(self, 'can_edit_' + column_name)
if hasattr(self, "can_edit_" + column_name):
return getattr(self, "can_edit_" + column_name)
# If the row has a settable property, we can edit the cell
rowclass = self.__class__
prop = getattr(rowclass, column_name, None)
if prop is None:
return False
return bool(getattr(prop, 'fset', None))
return bool(getattr(prop, "fset", None))
def get_cell_value(self, attrname):
"""Get cell value for ``attrname``.
@@ -530,8 +537,8 @@ class Row:
By default, does a simple ``getattr()``, but it is used to allow subclasses to have
alternative value storage mechanisms.
"""
if attrname == 'from':
attrname = 'from_'
if attrname == "from":
attrname = "from_"
return getattr(self, attrname)
def set_cell_value(self, attrname, value):
@@ -540,7 +547,6 @@ class Row:
By default, does a simple ``setattr()``, but it is used to allow subclasses to have
alternative value storage mechanisms.
"""
if attrname == 'from':
attrname = 'from_'
if attrname == "from":
attrname = "from_"
setattr(self, attrname, value)

View File

@@ -1,102 +1,106 @@
# Created On: 2012/01/23
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from .base import GUIObject
from ..util import nonone
class TextFieldView:
"""Expected interface for :class:`TextField`'s view.
*Not actually used in the code. For documentation purposes only.*
Our view is expected to sync with :attr:`TextField.text` "both ways", that is, update the
model's text when the user types something, but also update the text field when :meth:`refresh`
is called.
"""
def refresh(self):
"""Refreshes the contents of the input widget.
Ensures that the contents of the input widget is actually :attr:`TextField.text`.
"""
class TextField(GUIObject):
"""Cross-toolkit text field.
Represents a UI element allowing the user to input a text value. Its main attribute is
:attr:`text` which acts as the store of the said value.
When our model value isn't a string, we have a built-in parsing/formatting mechanism allowing
us to directly retrieve/set our non-string value through :attr:`value`.
Subclasses :class:`.GUIObject`. Expected view: :class:`TextFieldView`.
"""
def __init__(self):
GUIObject.__init__(self)
self._text = ''
self._text = ""
self._value = None
#--- Virtual
# --- Virtual
def _parse(self, text):
"""(Virtual) Parses ``text`` to put into :attr:`value`.
Returns the parsed version of ``text``. Called whenever :attr:`text` changes.
"""
return text
def _format(self, value):
"""(Virtual) Formats ``value`` to put into :attr:`text`.
Returns the formatted version of ``value``. Called whenever :attr:`value` changes.
"""
return value
def _update(self, newvalue):
"""(Virtual) Called whenever we have a new value.
Whenever our text/value store changes to a new value (different from the old one), this
method is called. By default, it does nothing but you can override it if you want.
"""
#--- Override
# --- Override
def _view_updated(self):
self.view.refresh()
#--- Public
# --- Public
def refresh(self):
"""Triggers a view :meth:`~TextFieldView.refresh`.
"""
self.view.refresh()
@property
def text(self):
"""The text that is currently displayed in the widget.
*str*. *get/set*.
This property can be set. When it is, :meth:`refresh` is called and the view is synced with
our value. Always in sync with :attr:`value`.
"""
return self._text
@text.setter
def text(self, newtext):
self.value = self._parse(nonone(newtext, ''))
self.value = self._parse(nonone(newtext, ""))
@property
def value(self):
"""The "parsed" representation of :attr:`text`.
*arbitrary type*. *get/set*.
By default, it's a mirror of :attr:`text`, but a subclass can override :meth:`_parse` and
:meth:`_format` to have anything else. Always in sync with :attr:`text`.
"""
return self._value
@value.setter
def value(self, newvalue):
if newvalue == self._value:
@@ -105,4 +109,3 @@ class TextField(GUIObject):
self._text = self._format(newvalue)
self._update(self._value)
self.refresh()
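# A minimal usage sketch of the TextField above; IntField and FakeView are
# illustrative names, and we assume the view attribute accepts any object
# exposing a refresh() method.
class IntField(TextField):
    def _parse(self, text):
        try:
            return int(text)
        except ValueError:
            return 0

    def _format(self, value):
        return str(value)


class FakeView:
    def refresh(self):
        pass


field = IntField()
field.view = FakeView()
field.text = "42"
assert field.value == 42  # "42" went through _parse()
field.value = 7
assert field.text == "7"  # 7 went through _format()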

View File

@@ -1,16 +1,17 @@
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from collections import MutableSequence
from .base import GUIObject
class Node(MutableSequence):
"""Pretty bland node implementation to be used in a :class:`Tree`.
It has a :attr:`parent`, behaves like a list, its content being its children. Link integrity
is somewhat enforced (adding a child to a node will set the child's :attr:`parent`), but that's
pretty much as far as we go, integrity-wise. Nodes don't tend to move around much in a GUI
@@ -19,57 +20,58 @@ class Node(MutableSequence):
Nodes are designed to be subclassed and given meaningful attributes (those you'll want to
display in your tree view), but they all have a :attr:`name`, which is given on initialization.
"""
def __init__(self, name):
self._name = name
self._parent = None
self._path = None
self._children = []
def __repr__(self):
return '<Node %r>' % self.name
#--- MutableSequence overrides
return "<Node %r>" % self.name
# --- MutableSequence overrides
def __delitem__(self, key):
self._children.__delitem__(key)
def __getitem__(self, key):
return self._children.__getitem__(key)
def __len__(self):
return len(self._children)
def __setitem__(self, key, value):
self._children.__setitem__(key, value)
def append(self, node):
self._children.append(node)
node._parent = self
node._path = None
def insert(self, index, node):
self._children.insert(index, node)
node._parent = self
node._path = None
#--- Public
# --- Public
def clear(self):
"""Clears the node of all its children.
"""
del self[:]
def find(self, predicate, include_self=True):
"""Return the first child to match ``predicate``.
See :meth:`findall`.
"""
try:
return next(self.findall(predicate, include_self=include_self))
except StopIteration:
return None
def findall(self, predicate, include_self=True):
"""Yield all children matching ``predicate``.
:param predicate: ``f(node) --> bool``
:param include_self: Whether we can return ``self`` or we return only children.
"""
@@ -78,10 +80,10 @@ class Node(MutableSequence):
for child in self:
for found in child.findall(predicate, include_self=True):
yield found
def get_node(self, index_path):
"""Returns the node at ``index_path``.
:param index_path: a list of int indexes leading to our node. See :attr:`path`.
"""
result = self
@@ -89,40 +91,40 @@ class Node(MutableSequence):
for index in index_path:
result = result[index]
return result
def get_path(self, target_node):
"""Returns the :attr:`path` of ``target_node``.
If ``target_node`` is ``None``, returns ``None``.
"""
if target_node is None:
return None
return target_node.path
@property
def children_count(self):
"""Same as ``len(self)``.
"""
return len(self)
@property
def name(self):
"""Name for the node, supplied on init.
"""
return self._name
@property
def parent(self):
"""Parent of the node.
If ``None``, we have a root node.
"""
return self._parent
@property
def path(self):
"""A list of node indexes leading from the root node to ``self``.
The path of a node is always related to its :attr:`root`. It's the sequence of indexes that
we have to take to get to our node, starting from the root. For example, if
``node.path == [1, 2, 3, 4]``, it means that ``node.root[1][2][3][4] is node``.
@@ -133,112 +135,113 @@ class Node(MutableSequence):
else:
self._path = self._parent.path + [self._parent.index(self)]
return self._path
@property
def root(self):
"""Root node of current node.
To get it, we recursively follow our :attr:`parent` chain until we have ``None``.
"""
if self._parent is None:
return self
else:
return self._parent.root
class Tree(Node, GUIObject):
"""Cross-toolkit GUI-enabled tree view.
This class is a bit too thin to be used as a tree view controller out of the box and HS apps
that subclass it each add quite a bit of logic to it to make it workable. Making this more
usable out of the box is a work in progress.
This class is here (in addition to being a :class:`Node`) mostly to handle selection.
Subclasses :class:`Node` (it is the root node of all its children) and :class:`.GUIObject`.
"""
def __init__(self):
Node.__init__(self, '')
Node.__init__(self, "")
GUIObject.__init__(self)
#: Where we store selected nodes (as a list of :class:`Node`)
self._selected_nodes = []
#--- Virtual
# --- Virtual
def _select_nodes(self, nodes):
"""(Virtual) Customize node selection behavior.
By default, simply set :attr:`_selected_nodes`.
"""
self._selected_nodes = nodes
#--- Override
# --- Override
def _view_updated(self):
self.view.refresh()
def clear(self):
self._selected_nodes = []
Node.clear(self)
#--- Public
# --- Public
@property
def selected_node(self):
"""Currently selected node.
*:class:`Node`*. *get/set*.
First of :attr:`selected_nodes`. ``None`` if empty.
"""
return self._selected_nodes[0] if self._selected_nodes else None
@selected_node.setter
def selected_node(self, node):
if node is not None:
self._select_nodes([node])
else:
self._select_nodes([])
@property
def selected_nodes(self):
"""List of selected nodes in the tree.
*List of :class:`Node`*. *get/set*.
We use nodes instead of indexes to store selection because it's simpler when it's time to
manage selection of multiple node levels.
"""
return self._selected_nodes
@selected_nodes.setter
def selected_nodes(self, nodes):
self._select_nodes(nodes)
@property
def selected_path(self):
"""Currently selected path.
*:attr:`Node.path`*. *get/set*.
First of :attr:`selected_paths`. ``None`` if empty.
"""
return self.get_path(self.selected_node)
@selected_path.setter
def selected_path(self, index_path):
if index_path is not None:
self.selected_paths = [index_path]
else:
self._select_nodes([])
@property
def selected_paths(self):
"""List of selected paths in the tree.
*List of :attr:`Node.path`*. *get/set*
Computed from :attr:`selected_nodes`.
"""
return list(map(self.get_path, self._selected_nodes))
@selected_paths.setter
def selected_paths(self, index_paths):
nodes = []
@@ -248,4 +251,3 @@ class Tree(Node, GUIObject):
except IndexError:
pass
self._select_nodes(nodes)
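# A small illustrative sketch of the Node API above; the node names are made up
# and we assume a root node's path is the empty list.
root = Node("root")
child = Node("child")
root.append(child)
grandchild = Node("grandchild")
child.append(grandchild)

assert grandchild.parent is child
assert grandchild.root is root
assert grandchild.path == [0, 0]  # root[0][0] is grandchild
assert root.get_node([0, 0]) is grandchild
assert root.find(lambda n: n.name == "child") is child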

View File

@@ -6,15 +6,19 @@
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
class JobCancelled(Exception):
"The user has cancelled the job"
class JobInProgressError(Exception):
"A job is already being performed, you can't perform more than one at the same time."
class JobCountError(Exception):
"The number of jobs started have exceeded the number of jobs allowed"
class Job:
"""Manages a job's progression and return it's progression through a callback.
@@ -30,14 +34,15 @@ class Job:
Another one is that nothing stops you from calling add_progress right after
SkipJob.
"""
#---Magic functions
# ---Magic functions
def __init__(self, job_proportions, callback):
"""Initialize the Job with 'jobcount' jobs. Start every job with
start_job(). Every time the job progress is updated, 'callback' is called.
'callback' takes a 'progress' int param, and an optional 'desc'
parameter. Callback must return false if the job must be cancelled.
"""
if not hasattr(callback, '__call__'):
if not hasattr(callback, "__call__"):
raise TypeError("'callback' MUST be set when creating a Job")
if isinstance(job_proportions, int):
job_proportions = [1] * job_proportions
@@ -49,12 +54,12 @@ class Job:
self._progress = 0
self._currmax = 1
#---Private
def _subjob_callback(self, progress, desc=''):
# ---Private
def _subjob_callback(self, progress, desc=""):
"""This is the callback passed to children jobs.
"""
self.set_progress(progress, desc)
return True #if JobCancelled has to be raised, it will be at the highest level
return True # if JobCancelled has to be raised, it will be at the highest level
def _do_update(self, desc):
"""Calls the callback function with a % progress as a parameter.
@@ -67,18 +72,18 @@ class Job:
total_progress = self._jobcount * self._currmax
progress = ((passed_progress + current_progress) * 100) // total_progress
else:
progress = -1 # indeterminate
progress = -1 # indeterminate
# It's possible that callback doesn't support a desc arg
result = self._callback(progress, desc) if desc else self._callback(progress)
if not result:
raise JobCancelled()
#---Public
def add_progress(self, progress=1, desc=''):
# ---Public
def add_progress(self, progress=1, desc=""):
self.set_progress(self._progress + progress, desc)
def check_if_cancelled(self):
self._do_update('')
self._do_update("")
def iter_with_progress(self, iterable, desc_format=None, every=1, count=None):
"""Iterate through ``iterable`` while automatically adding progress.
@@ -89,7 +94,7 @@ class Job:
"""
if count is None:
count = len(iterable)
desc = ''
desc = ""
if desc_format:
desc = desc_format % (0, count)
self.start_job(count, desc)
@@ -103,7 +108,7 @@ class Job:
desc = desc_format % (count, count)
self.set_progress(100, desc)
def start_job(self, max_progress=100, desc=''):
def start_job(self, max_progress=100, desc=""):
"""Begin work on the next job. You must not call start_job more than
'jobcount' (in __init__) times.
'max_progress' is the number of job units to perform.
@@ -118,7 +123,7 @@ class Job:
self._currmax = max(1, max_progress)
self._do_update(desc)
def start_subjob(self, job_proportions, desc=''):
def start_subjob(self, job_proportions, desc=""):
"""Starts a sub job. Use this when you want to split a job into
multiple smaller jobs. Pretty handy when starting a process where you
know how many subjobs you will have, but don't know the work unit count
@@ -128,7 +133,7 @@ class Job:
self.start_job(100, desc)
return Job(job_proportions, self._subjob_callback)
def set_progress(self, progress, desc=''):
def set_progress(self, progress, desc=""):
"""Sets the progress of the current job to 'progress', and call the
callback
"""

View File

@@ -1,9 +1,9 @@
# Created By: Virgil Dupras
# Created On: 2010-11-19
# Copyright 2011 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from threading import Thread
@@ -11,29 +11,31 @@ import sys
from .job import Job, JobInProgressError, JobCancelled
class ThreadedJobPerformer:
"""Run threaded jobs and track progress.
To run a threaded job, first create a job with create_job(), then call run_threaded(), with
To run a threaded job, first create a job with create_job(), then call run_threaded(), with
your work function as a parameter.
Example:
j = self.create_job()
self.run_threaded(self.some_work_func, (arg1, arg2, j))
"""
_job_running = False
last_error = None
#--- Protected
# --- Protected
def create_job(self):
if self._job_running:
raise JobInProgressError()
self.last_progress = -1
self.last_desc = ''
self.last_desc = ""
self.job_cancelled = False
return Job(1, self._update_progress)
def _async_run(self, *args):
target = args[0]
args = tuple(args[1:])
@@ -49,24 +51,23 @@ class ThreadedJobPerformer:
finally:
self._job_running = False
self.last_progress = None
def reraise_if_error(self):
"""Reraises the error that happened in the thread if any.
Call this after the caller of run_threaded has detected that self._job_running went back to False
"""
if self.last_error is not None:
raise self.last_error.with_traceback(self.last_traceback)
def _update_progress(self, newprogress, newdesc=''):
def _update_progress(self, newprogress, newdesc=""):
self.last_progress = newprogress
if newdesc:
self.last_desc = newdesc
return not self.job_cancelled
def run_threaded(self, target, args=()):
if self._job_running:
raise JobInProgressError()
args = (target, ) + args
args = (target,) + args
Thread(target=self._async_run, args=args).start()
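# A hedged sketch of using ThreadedJobPerformer outside a GUI; some_work and the
# polling loop are illustrative (a real app would poll last_progress from a timer).
import time

def some_work(job):
    for _ in job.iter_with_progress(range(100)):
        time.sleep(0.01)

performer = ThreadedJobPerformer()
j = performer.create_job()
performer.run_threaded(some_work, (j,))
time.sleep(0.1)  # give the worker thread time to start
while performer._job_running:
    time.sleep(0.1)
performer.reraise_if_error()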

View File

@@ -11,17 +11,18 @@ from PyQt5.QtWidgets import QProgressDialog
from . import performer
class Progress(QProgressDialog, performer.ThreadedJobPerformer):
finished = pyqtSignal(['QString'])
finished = pyqtSignal(["QString"])
def __init__(self, parent):
flags = Qt.CustomizeWindowHint | Qt.WindowTitleHint | Qt.WindowSystemMenuHint
QProgressDialog.__init__(self, '', "Cancel", 0, 100, parent, flags)
QProgressDialog.__init__(self, "", "Cancel", 0, 100, parent, flags)
self.setModal(True)
self.setAutoReset(False)
self.setAutoClose(False)
self._timer = QTimer()
self._jobid = ''
self._jobid = ""
self._timer.timeout.connect(self.updateProgress)
def updateProgress(self):
@@ -44,9 +45,8 @@ class Progress(QProgressDialog, performer.ThreadedJobPerformer):
def run(self, jobid, title, target, args=()):
self._jobid = jobid
self.reset()
self.setLabelText('')
self.setLabelText("")
self.run_threaded(target, args)
self.setWindowTitle(title)
self.show()
self._timer.start(500)

View File

@@ -7,26 +7,29 @@ import tempfile
import polib
from . import pygettext
from .util import modified_after, dedupe, ensure_folder, ensure_file
from .build import print_and_do, ensure_empty_folder, copy
from .util import modified_after, dedupe, ensure_folder
from .build import print_and_do, ensure_empty_folder
LC_MESSAGES = 'LC_MESSAGES'
LC_MESSAGES = "LC_MESSAGES"
# There isn't a 1-on-1 exact fit between .po language codes and cocoa ones
PO2COCOA = {
'pl_PL': 'pl',
'pt_BR': 'pt-BR',
'zh_CN': 'zh-Hans',
"pl_PL": "pl",
"pt_BR": "pt-BR",
"zh_CN": "zh-Hans",
}
COCOA2PO = {v: k for k, v in PO2COCOA.items()}
def get_langs(folder):
return [name for name in os.listdir(folder) if op.isdir(op.join(folder, name))]
def files_with_ext(folder, ext):
return [op.join(folder, fn) for fn in os.listdir(folder) if fn.endswith(ext)]
def generate_pot(folders, outpath, keywords, merge=False):
if merge and not op.exists(outpath):
merge = False
@@ -37,21 +40,23 @@ def generate_pot(folders, outpath, keywords, merge=False):
pyfiles = []
for folder in folders:
for root, dirs, filenames in os.walk(folder):
keep = [fn for fn in filenames if fn.endswith('.py')]
keep = [fn for fn in filenames if fn.endswith(".py")]
pyfiles += [op.join(root, fn) for fn in keep]
pygettext.main(pyfiles, outpath=genpath, keywords=keywords)
if merge:
merge_po_and_preserve(genpath, outpath)
os.remove(genpath)
def compile_all_po(base_folder):
langs = get_langs(base_folder)
for lang in langs:
pofolder = op.join(base_folder, lang, LC_MESSAGES)
pofiles = files_with_ext(pofolder, '.po')
pofiles = files_with_ext(pofolder, ".po")
for pofile in pofiles:
p = polib.pofile(pofile)
p.save_as_mofile(pofile[:-3] + '.mo')
p.save_as_mofile(pofile[:-3] + ".mo")
def merge_locale_dir(target, mergeinto):
langs = get_langs(target)
@@ -59,22 +64,24 @@ def merge_locale_dir(target, mergeinto):
if not op.exists(op.join(mergeinto, lang)):
continue
mofolder = op.join(target, lang, LC_MESSAGES)
mofiles = files_with_ext(mofolder, '.mo')
mofiles = files_with_ext(mofolder, ".mo")
for mofile in mofiles:
shutil.copy(mofile, op.join(mergeinto, lang, LC_MESSAGES))
def merge_pots_into_pos(folder):
# We're going to take all pot files in `folder` and for each lang, merge it with the po file
# with the same name.
potfiles = files_with_ext(folder, '.pot')
potfiles = files_with_ext(folder, ".pot")
for potfile in potfiles:
refpot = polib.pofile(potfile)
refname = op.splitext(op.basename(potfile))[0]
for lang in get_langs(folder):
po = polib.pofile(op.join(folder, lang, LC_MESSAGES, refname + '.po'))
po = polib.pofile(op.join(folder, lang, LC_MESSAGES, refname + ".po"))
po.merge(refpot)
po.save()
def merge_po_and_preserve(source, dest):
# Merges source entries into dest, but keeps old entries intact
sourcepo = polib.pofile(source)
@@ -86,36 +93,41 @@ def merge_po_and_preserve(source, dest):
destpo.append(entry)
destpo.save()
def normalize_all_pos(base_folder):
"""Normalize the format of .po files in base_folder.
When getting POs from external sources, such as Transifex, we end up with spurious diffs because
of a difference in the way line wrapping is handled. It wouldn't be a big deal if it happened
once, but these spurious diffs keep overwriting each other, and it's annoying.
Our PO files will keep polib's format. Call this function to ensure that freshly pulled POs
are of the right format before committing them.
"""
langs = get_langs(base_folder)
for lang in langs:
pofolder = op.join(base_folder, lang, LC_MESSAGES)
pofiles = files_with_ext(pofolder, '.po')
pofiles = files_with_ext(pofolder, ".po")
for pofile in pofiles:
p = polib.pofile(pofile)
p.save()
#--- Cocoa
# --- Cocoa
def all_lproj_paths(folder):
return files_with_ext(folder, '.lproj')
return files_with_ext(folder, ".lproj")
def escape_cocoa_strings(s):
return s.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
return s.replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")
def unescape_cocoa_strings(s):
return s.replace('\\\\', '\\').replace('\\"', '"').replace('\\n', '\n')
return s.replace("\\\\", "\\").replace('\\"', '"').replace("\\n", "\n")
def strings2pot(target, dest):
with open(target, 'rt', encoding='utf-8') as fp:
with open(target, "rt", encoding="utf-8") as fp:
contents = fp.read()
# We're reading an en.lproj file. We only care about the righthand part of the translation.
re_trans = re.compile(r'".*" = "(.*)";')
@@ -131,17 +143,21 @@ def strings2pot(target, dest):
entry = polib.POEntry(msgid=s)
po.append(entry)
# we don't know or care about a line number so we put 0
entry.occurrences.append((target, '0'))
entry.occurrences.append((target, "0"))
entry.occurrences = dedupe(entry.occurrences)
po.save(dest)
def allstrings2pot(lprojpath, dest, excludes=None):
allstrings = files_with_ext(lprojpath, '.strings')
allstrings = files_with_ext(lprojpath, ".strings")
if excludes:
allstrings = [p for p in allstrings if op.splitext(op.basename(p))[0] not in excludes]
allstrings = [
p for p in allstrings if op.splitext(op.basename(p))[0] not in excludes
]
for strings_path in allstrings:
strings2pot(strings_path, dest)
def po2strings(pofile, en_strings, dest):
# Takes en_strings and replaces all righthand parts of "foo" = "bar"; entries with translations
# in pofile, then puts the result in dest.
@@ -150,9 +166,10 @@ def po2strings(pofile, en_strings, dest):
return
ensure_folder(op.dirname(dest))
print("Creating {} from {}".format(dest, pofile))
with open(en_strings, 'rt', encoding='utf-8') as fp:
with open(en_strings, "rt", encoding="utf-8") as fp:
contents = fp.read()
re_trans = re.compile(r'(?<= = ").*(?=";\n)')
def repl(match):
s = match.group(0)
unescaped = unescape_cocoa_strings(s)
@@ -162,10 +179,12 @@ def po2strings(pofile, en_strings, dest):
return s
trans = entry.msgstr
return escape_cocoa_strings(trans) if trans else s
contents = re_trans.sub(repl, contents)
with open(dest, 'wt', encoding='utf-8') as fp:
with open(dest, "wt", encoding="utf-8") as fp:
fp.write(contents)
def generate_cocoa_strings_from_code(code_folder, dest_folder):
# Uses the "genstrings" command to generate strings file from all .m files in "code_folder".
# The strings file (their name depends on the localization table used in the source) will be
@@ -173,36 +192,49 @@ def generate_cocoa_strings_from_code(code_folder, dest_folder):
# genstrings produces utf-16 files with comments. After having generated the files, we convert
# them to utf-8 and remove the comments.
ensure_empty_folder(dest_folder)
print_and_do('genstrings -o "{}" `find "{}" -name *.m | xargs`'.format(dest_folder, code_folder))
print_and_do(
'genstrings -o "{}" `find "{}" -name *.m | xargs`'.format(
dest_folder, code_folder
)
)
for stringsfile in os.listdir(dest_folder):
stringspath = op.join(dest_folder, stringsfile)
with open(stringspath, 'rt', encoding='utf-16') as fp:
with open(stringspath, "rt", encoding="utf-16") as fp:
content = fp.read()
content = re.sub('/\*.*?\*/', '', content)
content = re.sub('\n{2,}', '\n', content)
content = re.sub(r"/\*.*?\*/", "", content)
content = re.sub(r"\n{2,}", "\n", content)
# I have no idea why, but genstrings seems to have problems with "%" character in strings
# and inserts (number)$ after it. Find these bogus inserts and remove them.
content = re.sub('%\d\$', '%', content)
with open(stringspath, 'wt', encoding='utf-8') as fp:
content = re.sub(r"%\d\$", "%", content)
with open(stringspath, "wt", encoding="utf-8") as fp:
fp.write(content)
def generate_cocoa_strings_from_xib(xib_folder):
xibs = [op.join(xib_folder, fn) for fn in os.listdir(xib_folder) if fn.endswith('.xib')]
xibs = [
op.join(xib_folder, fn) for fn in os.listdir(xib_folder) if fn.endswith(".xib")
]
for xib in xibs:
dest = xib.replace('.xib', '.strings')
print_and_do('ibtool {} --generate-strings-file {}'.format(xib, dest))
print_and_do('iconv -f utf-16 -t utf-8 {0} | tee {0}'.format(dest))
dest = xib.replace(".xib", ".strings")
print_and_do("ibtool {} --generate-strings-file {}".format(xib, dest))
print_and_do("iconv -f utf-16 -t utf-8 {0} | tee {0}".format(dest))
def localize_stringsfile(stringsfile, dest_root_folder):
stringsfile_name = op.basename(stringsfile)
for lang in get_langs('locale'):
pofile = op.join('locale', lang, 'LC_MESSAGES', 'ui.po')
for lang in get_langs("locale"):
pofile = op.join("locale", lang, "LC_MESSAGES", "ui.po")
cocoa_lang = PO2COCOA.get(lang, lang)
dest_lproj = op.join(dest_root_folder, cocoa_lang + '.lproj')
dest_lproj = op.join(dest_root_folder, cocoa_lang + ".lproj")
ensure_folder(dest_lproj)
po2strings(pofile, stringsfile, op.join(dest_lproj, stringsfile_name))
def localize_all_stringsfiles(src_folder, dest_root_folder):
stringsfiles = [op.join(src_folder, fn) for fn in os.listdir(src_folder) if fn.endswith('.strings')]
stringsfiles = [
op.join(src_folder, fn)
for fn in os.listdir(src_folder)
if fn.endswith(".strings")
]
for path in stringsfiles:
localize_stringsfile(path, dest_root_folder)
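# A hedged sketch of the gettext round trip these helpers support; the folder
# names and the "tr" keyword are illustrative assumptions.
generate_pot(["core", "qt"], op.join("locale", "core.pot"), ["tr"], merge=True)
# ... translators update locale/<lang>/LC_MESSAGES/core.po from the new .pot ...
merge_pots_into_pos("locale")
compile_all_po("locale")  # writes a .mo next to every .po
normalize_all_pos("locale")  # keep polib's wrapping to avoid spurious diffs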

View File

@@ -1,7 +1,7 @@
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
"""Very simple inter-object notification system.
@@ -14,55 +14,58 @@ the method with the same name as the broadcasted message is called on the listen
from collections import defaultdict
class Broadcaster:
"""Broadcasts messages that are received by all listeners.
"""
def __init__(self):
self.listeners = set()
def add_listener(self, listener):
self.listeners.add(listener)
def notify(self, msg):
"""Notify all connected listeners of ``msg``.
That means that each listener will have their method with the same name as ``msg`` called.
"""
for listener in self.listeners.copy(): # listeners can change during iteration
if listener in self.listeners: # disconnected during notification
for listener in self.listeners.copy(): # listeners can change during iteration
if listener in self.listeners: # disconnected during notification
listener.dispatch(msg)
def remove_listener(self, listener):
self.listeners.discard(listener)
class Listener:
"""A listener is initialized with the broadcaster it's going to listen to. Initially, it is not connected.
"""
def __init__(self, broadcaster):
self.broadcaster = broadcaster
self._bound_notifications = defaultdict(list)
def bind_messages(self, messages, func):
"""Binds multiple message to the same function.
Often, we perform the same thing on multiple messages. Instead of having the same function
repeated again and again in our class, we can use this method to bind multiple messages to
the same function.
"""
for message in messages:
self._bound_notifications[message].append(func)
def connect(self):
"""Connects the listener to its broadcaster.
"""
self.broadcaster.add_listener(self)
def disconnect(self):
"""Disconnects the listener from its broadcaster.
"""
self.broadcaster.remove_listener(self)
def dispatch(self, msg):
if msg in self._bound_notifications:
for func in self._bound_notifications[msg]:
@@ -70,20 +73,19 @@ class Listener:
if hasattr(self, msg):
method = getattr(self, msg)
method()
class Repeater(Broadcaster, Listener):
REPEATED_NOTIFICATIONS = None
def __init__(self, broadcaster):
Broadcaster.__init__(self)
Listener.__init__(self, broadcaster)
def _repeat_message(self, msg):
if not self.REPEATED_NOTIFICATIONS or msg in self.REPEATED_NOTIFICATIONS:
self.notify(msg)
def dispatch(self, msg):
Listener.dispatch(self, msg)
self._repeat_message(msg)
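# An illustrative sketch of the notification flow; SavedListener is a made-up name.
class SavedListener(Listener):
    def document_saved(self):  # method name matches the broadcasted message
        print("document was saved")

broadcaster = Broadcaster()
listener = SavedListener(broadcaster)
listener.connect()  # listeners start out disconnected
broadcaster.notify("document_saved")  # calls listener.document_saved()
listener.disconnect()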

View File

@@ -2,8 +2,8 @@
# Created On: 2006/02/21
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import logging
@@ -15,19 +15,21 @@ from itertools import takewhile
from functools import wraps
from inspect import signature
class Path(tuple):
"""A handy class to work with paths.
We subclass ``tuple``; each element of the tuple represents an element of the path.
* ``Path('/foo/bar/baz')[1]`` --> ``'bar'``
* ``Path('/foo/bar/baz')[1:2]`` --> ``Path('bar/baz')``
* ``Path('/foo/bar')['baz']`` --> ``Path('/foo/bar/baz')``
* ``str(Path('/foo/bar/baz'))`` --> ``'/foo/bar/baz'``
"""
# Saves a little bit of memory usage
__slots__ = ()
def __new__(cls, value, separator=None):
def unicode_if_needed(s):
if isinstance(s, str):
@@ -38,7 +40,7 @@ class Path(tuple):
except UnicodeDecodeError:
logging.warning("Could not decode %r", s)
raise
if isinstance(value, Path):
return value
if not separator:
@@ -47,44 +49,53 @@ class Path(tuple):
value = unicode_if_needed(value)
if isinstance(value, str):
if value:
if (separator not in value) and ('/' in value):
separator = '/'
if (separator not in value) and ("/" in value):
separator = "/"
value = value.split(separator)
else:
value = ()
else:
if any(isinstance(x, bytes) for x in value):
value = [unicode_if_needed(x) for x in value]
#value is a tuple/list
# value is a tuple/list
if any(separator in x for x in value):
#We have a component with a separator in it. Let's rejoin it, and generate another path.
# We have a component with a separator in it. Let's rejoin it, and generate another path.
return Path(separator.join(value), separator)
if (len(value) > 1) and (not value[-1]):
value = value[:-1] #We never want a path to end with a '' (because Path() can be called with a trailing slash ending path)
value = value[
:-1
] # We never want a path to end with a '' (because Path() can be called with a trailing slash ending path)
return tuple.__new__(cls, value)
def __add__(self, other):
other = Path(other)
if other and (not other[0]):
other = other[1:]
return Path(tuple.__add__(self, other))
def __contains__(self, item):
if isinstance(item, Path):
return item[:len(self)] == self
return item[: len(self)] == self
else:
return tuple.__contains__(self, item)
def __eq__(self, other):
return tuple.__eq__(self, Path(other))
def __getitem__(self, key):
if isinstance(key, slice):
if isinstance(key.start, Path):
equal_elems = list(takewhile(lambda pair: pair[0] == pair[1], zip(self, key.start)))
equal_elems = list(
takewhile(lambda pair: pair[0] == pair[1], zip(self, key.start))
)
key = slice(len(equal_elems), key.stop, key.step)
if isinstance(key.stop, Path):
equal_elems = list(takewhile(lambda pair: pair[0] == pair[1], zip(reversed(self), reversed(key.stop))))
equal_elems = list(
takewhile(
lambda pair: pair[0] == pair[1],
zip(reversed(self), reversed(key.stop)),
)
)
stop = -len(equal_elems) if equal_elems else None
key = slice(key.start, stop, key.step)
return Path(tuple.__getitem__(self, key))
@@ -92,31 +103,31 @@ class Path(tuple):
return self + key
else:
return tuple.__getitem__(self, key)
def __hash__(self):
return tuple.__hash__(self)
def __ne__(self, other):
return not self.__eq__(other)
def __radd__(self, other):
return Path(other) + self
def __str__(self):
if len(self) == 1:
first = self[0]
if (len(first) == 2) and (first[1] == ':'): #Windows drive letter
return first + '\\'
elif not len(first): #root directory
return '/'
if (len(first) == 2) and (first[1] == ":"): # Windows drive letter
return first + "\\"
elif not len(first): # root directory
return "/"
return os.sep.join(self)
def has_drive_letter(self):
if not self:
return False
first = self[0]
return (len(first) == 2) and (first[1] == ':')
return (len(first) == 2) and (first[1] == ":")
def is_parent_of(self, other):
"""Whether ``other`` is a subpath of ``self``.
@@ -133,29 +144,29 @@ class Path(tuple):
return self[1:]
else:
return self
def tobytes(self):
return str(self).encode(sys.getfilesystemencoding())
def parent(self):
"""Returns the parent path.
``Path('/foo/bar/baz').parent()`` --> ``Path('/foo/bar')``
"""
return self[:-1]
@property
def name(self):
"""Last element of the path (filename), with extension.
``Path('/foo/bar/baz').name`` --> ``'baz'``
"""
return self[-1]
# OS method wrappers
def exists(self):
return op.exists(str(self))
def copy(self, dest_path):
return shutil.copy(str(self), str(dest_path))
@@ -200,36 +211,44 @@ class Path(tuple):
def stat(self):
return os.stat(str(self))
def pathify(f):
"""Ensure that every annotated :class:`Path` arguments are actually paths.
When a function is decorated with ``@pathify``, every argument with annotated as Path will be
converted to a Path if it wasn't already. Example::
@pathify
def foo(path: Path, otherarg):
return path.listdir()
Calling ``foo('/bar', 0)`` will convert ``'/bar'`` to ``Path('/bar')``.
"""
sig = signature(f)
pindexes = {i for i, p in enumerate(sig.parameters.values()) if p.annotation is Path}
pindexes = {
i for i, p in enumerate(sig.parameters.values()) if p.annotation is Path
}
pkeys = {k: v for k, v in sig.parameters.items() if v.annotation is Path}
def path_or_none(p):
return None if p is None else Path(p)
@wraps(f)
def wrapped(*args, **kwargs):
args = tuple((path_or_none(a) if i in pindexes else a) for i, a in enumerate(args))
args = tuple(
(path_or_none(a) if i in pindexes else a) for i, a in enumerate(args)
)
kwargs = {k: (path_or_none(v) if k in pkeys else v) for k, v in kwargs.items()}
return f(*args, **kwargs)
return wrapped
def log_io_error(func):
""" Catches OSError, IOError and WindowsError and log them
"""
@wraps(func)
def wrapper(path, *args, **kwargs):
try:
@@ -239,5 +258,5 @@ def log_io_error(func):
classname = e.__class__.__name__
funcname = func.__name__
logging.warn(msg.format(classname, funcname, str(path), str(e)))
return wrapper
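# A small sketch of Path arithmetic and @pathify, assuming a POSIX os.sep;
# backup_name() is an illustrative helper, not part of this module.
p = Path("/foo/bar/baz.txt")
assert str(p.parent()) == "/foo/bar"
assert p.name == "baz.txt"
assert p in Path("/foo")  # p is a subpath of /foo

@pathify
def backup_name(path: Path):
    return path.parent() + (path.name + ".bak")

assert str(backup_name("/foo/bar/baz.txt")) == "/foo/bar/baz.txt.bak"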

View File

@@ -1,8 +1,8 @@
# Created On: 2011/09/22
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
# Yes, I know, there's the 'platform' unit for this kind of stuff, but the thing is that I got a
@@ -11,6 +11,6 @@
import sys
ISWINDOWS = sys.platform == 'win32'
ISOSX = sys.platform == 'darwin'
ISLINUX = sys.platform.startswith('linux')
ISWINDOWS = sys.platform == "win32"
ISOSX = sys.platform == "darwin"
ISLINUX = sys.platform.startswith("linux")

View File

@@ -18,20 +18,17 @@ import os
import imp
import sys
import glob
import time
import token
import tokenize
import operator
__version__ = '1.5'
__version__ = "1.5"
default_keywords = ['_']
DEFAULTKEYWORDS = ', '.join(default_keywords)
default_keywords = ["_"]
DEFAULTKEYWORDS = ", ".join(default_keywords)
EMPTYSTRING = ''
EMPTYSTRING = ""
# The normal pot-file header. msgmerge and Emacs's po-mode work better if it's
# there.
pot_header = """
@@ -41,17 +38,17 @@ msgstr ""
"Content-Transfer-Encoding: utf-8\\n"
"""
def usage(code, msg=''):
def usage(code, msg=""):
print(__doc__ % globals(), file=sys.stderr)
if msg:
print(msg, file=sys.stderr)
sys.exit(code)
escapes = []
def make_escapes(pass_iso8859):
global escapes
if pass_iso8859:
@@ -66,11 +63,11 @@ def make_escapes(pass_iso8859):
escapes.append(chr(i))
else:
escapes.append("\\%03o" % i)
escapes[ord('\\')] = '\\\\'
escapes[ord('\t')] = '\\t'
escapes[ord('\r')] = '\\r'
escapes[ord('\n')] = '\\n'
escapes[ord('\"')] = '\\"'
escapes[ord("\\")] = "\\\\"
escapes[ord("\t")] = "\\t"
escapes[ord("\r")] = "\\r"
escapes[ord("\n")] = "\\n"
escapes[ord('"')] = '\\"'
def escape(s):
@@ -83,26 +80,26 @@ def escape(s):
def safe_eval(s):
# unwrap quotes, safely
return eval(s, {'__builtins__':{}}, {})
return eval(s, {"__builtins__": {}}, {})
def normalize(s):
# This converts the various Python string types into a format that is
# appropriate for .po files, namely much closer to C style.
lines = s.split('\n')
lines = s.split("\n")
if len(lines) == 1:
s = '"' + escape(s) + '"'
else:
if not lines[-1]:
del lines[-1]
lines[-1] = lines[-1] + '\n'
lines[-1] = lines[-1] + "\n"
for i in range(len(lines)):
lines[i] = escape(lines[i])
lineterm = '\\n"\n"'
s = '""\n"' + lineterm.join(lines) + '"'
return s
def containsAny(str, set):
"""Check whether 'str' contains ANY of the chars in 'set'"""
return 1 in [c in str for c in set]
@@ -111,20 +108,24 @@ def containsAny(str, set):
def _visit_pyfiles(list, dirname, names):
"""Helper for getFilesForName()."""
# get extension for python source files
if '_py_ext' not in globals():
if "_py_ext" not in globals():
global _py_ext
_py_ext = [triple[0] for triple in imp.get_suffixes()
if triple[2] == imp.PY_SOURCE][0]
_py_ext = [
triple[0] for triple in imp.get_suffixes() if triple[2] == imp.PY_SOURCE
][0]
# don't recurse into CVS directories
if 'CVS' in names:
names.remove('CVS')
if "CVS" in names:
names.remove("CVS")
# add all *.py files to list
list.extend(
[os.path.join(dirname, file) for file in names
if os.path.splitext(file)[1] == _py_ext]
)
[
os.path.join(dirname, file)
for file in names
if os.path.splitext(file)[1] == _py_ext
]
)
def _get_modpkg_path(dotted_name, pathlist=None):
@@ -135,13 +136,14 @@ def _get_modpkg_path(dotted_name, pathlist=None):
extension module.
"""
# split off top-most name
parts = dotted_name.split('.', 1)
parts = dotted_name.split(".", 1)
if len(parts) > 1:
# we have a dotted path, import top-level package
try:
file, pathname, description = imp.find_module(parts[0], pathlist)
if file: file.close()
if file:
file.close()
except ImportError:
return None
@@ -154,8 +156,7 @@ def _get_modpkg_path(dotted_name, pathlist=None):
else:
# plain name
try:
file, pathname, description = imp.find_module(
dotted_name, pathlist)
file, pathname, description = imp.find_module(dotted_name, pathlist)
if file:
file.close()
if description[2] not in [imp.PY_SOURCE, imp.PKG_DIRECTORY]:
@@ -195,7 +196,7 @@ def getFilesForName(name):
return []
class TokenEater:
def __init__(self, options):
self.__options = options
@@ -208,9 +209,9 @@ class TokenEater:
def __call__(self, ttype, tstring, stup, etup, line):
# dispatch
## import token
## print >> sys.stderr, 'ttype:', token.tok_name[ttype], \
## 'tstring:', tstring
# import token
# print >> sys.stderr, 'ttype:', token.tok_name[ttype], \
# 'tstring:', tstring
self.__state(ttype, tstring, stup[0])
def __waiting(self, ttype, tstring, lineno):
@@ -226,7 +227,7 @@ class TokenEater:
self.__freshmodule = 0
return
# class docstring?
if ttype == tokenize.NAME and tstring in ('class', 'def'):
if ttype == tokenize.NAME and tstring in ("class", "def"):
self.__state = self.__suiteseen
return
if ttype == tokenize.NAME and tstring in opts.keywords:
@@ -234,7 +235,7 @@ class TokenEater:
def __suiteseen(self, ttype, tstring, lineno):
# ignore anything until we see the colon
if ttype == tokenize.OP and tstring == ':':
if ttype == tokenize.OP and tstring == ":":
self.__state = self.__suitedocstring
def __suitedocstring(self, ttype, tstring, lineno):
@@ -242,13 +243,12 @@ class TokenEater:
if ttype == tokenize.STRING:
self.__addentry(safe_eval(tstring), lineno, isdocstring=1)
self.__state = self.__waiting
elif ttype not in (tokenize.NEWLINE, tokenize.INDENT,
tokenize.COMMENT):
elif ttype not in (tokenize.NEWLINE, tokenize.INDENT, tokenize.COMMENT):
# there was no class docstring
self.__state = self.__waiting
def __keywordseen(self, ttype, tstring, lineno):
if ttype == tokenize.OP and tstring == '(':
if ttype == tokenize.OP and tstring == "(":
self.__data = []
self.__lineno = lineno
self.__state = self.__openseen
@@ -256,7 +256,7 @@ class TokenEater:
self.__state = self.__waiting
def __openseen(self, ttype, tstring, lineno):
if ttype == tokenize.OP and tstring == ')':
if ttype == tokenize.OP and tstring == ")":
# We've seen the last of the translatable strings. Record the
# line number of the first line of the strings and update the list
# of messages seen. Reset state for the next batch. If there
@@ -266,20 +266,25 @@ class TokenEater:
self.__state = self.__waiting
elif ttype == tokenize.STRING:
self.__data.append(safe_eval(tstring))
elif ttype not in [tokenize.COMMENT, token.INDENT, token.DEDENT,
token.NEWLINE, tokenize.NL]:
elif ttype not in [
tokenize.COMMENT,
token.INDENT,
token.DEDENT,
token.NEWLINE,
tokenize.NL,
]:
# warn if we see anything other than STRING or whitespace
print('*** %(file)s:%(lineno)s: Seen unexpected token "%(token)s"' % {
'token': tstring,
'file': self.__curfile,
'lineno': self.__lineno
}, file=sys.stderr)
print(
'*** %(file)s:%(lineno)s: Seen unexpected token "%(token)s"'
% {"token": tstring, "file": self.__curfile, "lineno": self.__lineno},
file=sys.stderr,
)
self.__state = self.__waiting
def __addentry(self, msg, lineno=None, isdocstring=0):
if lineno is None:
lineno = self.__lineno
if not msg in self.__options.toexclude:
if msg not in self.__options.toexclude:
entry = (self.__curfile, lineno)
self.__messages.setdefault(msg, {})[entry] = isdocstring
@@ -289,7 +294,6 @@ class TokenEater:
def write(self, fp):
options = self.__options
timestamp = time.strftime('%Y-%m-%d %H:%M+%Z')
# The time stamp in the header doesn't have the same format as that
# generated by xgettext...
print(pot_header, file=fp)
@@ -317,15 +321,15 @@ class TokenEater:
# location comments are different b/w Solaris and GNU:
elif options.locationstyle == options.SOLARIS:
for filename, lineno in v:
d = {'filename': filename, 'lineno': lineno}
print('# File: %(filename)s, line: %(lineno)d' % d, file=fp)
d = {"filename": filename, "lineno": lineno}
print("# File: %(filename)s, line: %(lineno)d" % d, file=fp)
elif options.locationstyle == options.GNU:
# fit as many locations on one line, as long as the
# resulting line length doesn't exceed 'options.width'
locline = '#:'
locline = "#:"
for filename, lineno in v:
d = {'filename': filename, 'lineno': lineno}
s = ' %(filename)s:%(lineno)d' % d
d = {"filename": filename, "lineno": lineno}
s = " %(filename)s:%(lineno)d" % d
if len(locline) + len(s) <= options.width:
locline = locline + s
else:
@@ -334,37 +338,34 @@ class TokenEater:
if len(locline) > 2:
print(locline, file=fp)
if isdocstring:
print('#, docstring', file=fp)
print('msgid', normalize(k), file=fp)
print("#, docstring", file=fp)
print("msgid", normalize(k), file=fp)
print('msgstr ""\n', file=fp)
def main(source_files, outpath, keywords=None):
global default_keywords
# for holding option values
class Options:
# constants
GNU = 1
SOLARIS = 2
# defaults
extractall = 0 # FIXME: currently this option has no effect at all.
extractall = 0 # FIXME: currently this option has no effect at all.
escape = 0
keywords = []
outfile = 'messages.pot'
outfile = "messages.pot"
writelocations = 1
locationstyle = GNU
verbose = 0
width = 78
excludefilename = ''
excludefilename = ""
docstrings = 0
nodocstrings = {}
options = Options()
locations = {'gnu' : options.GNU,
'solaris' : options.SOLARIS,
}
options.outfile = outpath
if keywords:
options.keywords = keywords
@@ -378,11 +379,14 @@ def main(source_files, outpath, keywords=None):
# initialize list of strings to exclude
if options.excludefilename:
try:
fp = open(options.excludefilename, encoding='utf-8')
fp = open(options.excludefilename, encoding="utf-8")
options.toexclude = fp.readlines()
fp.close()
except IOError:
print("Can't read --exclude-file: %s" % options.excludefilename, file=sys.stderr)
print(
"Can't read --exclude-file: %s" % options.excludefilename,
file=sys.stderr,
)
sys.exit(1)
else:
options.toexclude = []
@@ -391,8 +395,8 @@ def main(source_files, outpath, keywords=None):
eater = TokenEater(options)
for filename in source_files:
if options.verbose:
print('Working on %s' % filename)
fp = open(filename, encoding='utf-8')
print("Working on %s" % filename)
fp = open(filename, encoding="utf-8")
closep = 1
try:
eater.set_filename(filename)
@@ -401,14 +405,16 @@ def main(source_files, outpath, keywords=None):
for _token in tokens:
eater(*_token)
except tokenize.TokenError as e:
print('%s: %s, line %d, column %d' % (
e.args[0], filename, e.args[1][0], e.args[1][1]),
file=sys.stderr)
print(
"%s: %s, line %d, column %d"
% (e.args[0], filename, e.args[1][0], e.args[1][1]),
file=sys.stderr,
)
finally:
if closep:
fp.close()
fp = open(options.outfile, 'w', encoding='utf-8')
fp = open(options.outfile, "w", encoding="utf-8")
closep = 1
try:
eater.write(fp)
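# A hedged sketch of calling this extractor directly; the paths and the "tr"
# keyword are illustrative assumptions.
import glob

sources = glob.glob("myapp/**/*.py", recursive=True)
main(sources, "locale/messages.pot", keywords=["tr"])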

View File

@@ -19,16 +19,28 @@ CHANGELOG_FORMAT = """
{description}
"""
def tixgen(tixurl):
"""This is a filter *generator*. tixurl is a url pattern for the tix with a {0} placeholder
for the tix #
"""
urlpattern = tixurl.format('\\1') # will be replaced by the content of the first group in re
R = re.compile(r'#(\d+)')
repl = '`#\\1 <{}>`__'.format(urlpattern)
urlpattern = tixurl.format(
"\\1"
) # will be replaced by the content of the first group in re
R = re.compile(r"#(\d+)")
repl = "`#\\1 <{}>`__".format(urlpattern)
return lambda text: R.sub(repl, text)
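# An illustrative use of tixgen(); the issue-tracker URL pattern is made up.
linkify = tixgen("https://example.org/issues/{0}")
print(linkify("Fixed the crash on scan (#123)"))
# -> Fixed the crash on scan (`#123 <https://example.org/issues/123>`__)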
def gen(basepath, destpath, changelogpath, tixurl, confrepl=None, confpath=None, changelogtmpl=None):
def gen(
basepath,
destpath,
changelogpath,
tixurl,
confrepl=None,
confpath=None,
changelogtmpl=None,
):
"""Generate sphinx docs with all bells and whistles.
basepath: The base sphinx source path.
@@ -40,41 +52,47 @@ def gen(basepath, destpath, changelogpath, tixurl, confrepl=None, confpath=None,
if confrepl is None:
confrepl = {}
if confpath is None:
confpath = op.join(basepath, 'conf.tmpl')
confpath = op.join(basepath, "conf.tmpl")
if changelogtmpl is None:
changelogtmpl = op.join(basepath, 'changelog.tmpl')
changelogtmpl = op.join(basepath, "changelog.tmpl")
changelog = read_changelog_file(changelogpath)
tix = tixgen(tixurl)
rendered_logs = []
for log in changelog:
description = tix(log['description'])
description = tix(log["description"])
# The format of the changelog descriptions is in markdown, but since we only use bulleted lists
# and links, it's not worth depending on the markdown package. A simple regexp suffices.
description = re.sub(r'\[(.*?)\]\((.*?)\)', '`\\1 <\\2>`__', description)
rendered = CHANGELOG_FORMAT.format(version=log['version'], date=log['date_str'],
description=description)
description = re.sub(r"\[(.*?)\]\((.*?)\)", "`\\1 <\\2>`__", description)
rendered = CHANGELOG_FORMAT.format(
version=log["version"], date=log["date_str"], description=description
)
rendered_logs.append(rendered)
confrepl['version'] = changelog[0]['version']
changelog_out = op.join(basepath, 'changelog.rst')
filereplace(changelogtmpl, changelog_out, changelog='\n'.join(rendered_logs))
confrepl["version"] = changelog[0]["version"]
changelog_out = op.join(basepath, "changelog.rst")
filereplace(changelogtmpl, changelog_out, changelog="\n".join(rendered_logs))
if op.exists(confpath):
conf_out = op.join(basepath, 'conf.py')
conf_out = op.join(basepath, "conf.py")
filereplace(confpath, conf_out, **confrepl)
if LooseVersion(get_distribution("sphinx").version) >= LooseVersion("1.7.0"):
from sphinx.cmd.build import build_main as sphinx_build
# Call the sphinx_build function, which is the same as doing sphinx-build from cli
try:
sphinx_build([basepath, destpath])
except SystemExit:
print("Sphinx called sys.exit(), but we're cancelling it because we don't actually want to exit")
print(
"Sphinx called sys.exit(), but we're cancelling it because we don't actually want to exit"
)
else:
# We used to call sphinx-build with print_and_do(), but the problem was that the virtualenv
# of the calling python wasn't correctly considered and caused problems with documentation
# relying on autodoc (which tries to import the module to auto-document, but fail because of
# missing dependencies which are in the virtualenv). Here, we do exactly what is done when
# calling the command from bash.
cmd = load_entry_point('Sphinx', 'console_scripts', 'sphinx-build')
cmd = load_entry_point("Sphinx", "console_scripts", "sphinx-build")
try:
cmd(['sphinx-build', basepath, destpath])
cmd(["sphinx-build", basepath, destpath])
except SystemExit:
print("Sphinx called sys.exit(), but we're cancelling it because we don't actually want to exit")
print(
"Sphinx called sys.exit(), but we're cancelling it because we don't actually want to exit"
)
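# An illustrative call; the paths and the ticket URL pattern are made up.
gen(
    "help/en",  # sphinx source folder
    "build/help",  # destination for the built docs
    "help/changelog",  # changelog file read by read_changelog_file()
    "https://example.org/issues/{0}",  # tix URL pattern passed to tixgen()
    confrepl={"language": "en"},
)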

View File

@@ -2,39 +2,39 @@
# Created On: 2007/05/19
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import sys
import os
import os.path as op
import threading
from queue import Queue
import time
import sqlite3 as sqlite
STOP = object()
COMMIT = object()
ROLLBACK = object()
class FakeCursor(list):
# It's not possible to use sqlite cursors on a thread other than the connection's. Thus,
# we can't directly return the cursor. We have to fetch all results, and support its interface.
def fetchall(self):
return self
def fetchone(self):
try:
return self.pop(0)
except IndexError:
return None
class _ActualThread(threading.Thread):
''' We can't use this class directly because thread objects are not automatically freed when
""" We can't use this class directly because thread objects are not automatically freed when
nothing refers to it, making it hang the application if not explicitly closed.
'''
"""
def __init__(self, dbname, autocommit):
threading.Thread.__init__(self)
self._queries = Queue()
@@ -47,7 +47,7 @@ class _ActualThread(threading.Thread):
self.lastrowid = -1
self.setDaemon(True)
self.start()
def _query(self, query):
with self._lock:
wait_token = object()
@@ -56,30 +56,30 @@ class _ActualThread(threading.Thread):
self._waiting_list.remove(wait_token)
result = self._results.get()
return result
def close(self):
if not self._run:
return
self._query(STOP)
def commit(self):
if not self._run:
return None # Connection closed
return None # Connection closed
self._query(COMMIT)
def execute(self, sql, values=()):
if not self._run:
return None # Connection closed
return None # Connection closed
result = self._query((sql, values))
if isinstance(result, Exception):
raise result
return result
def rollback(self):
if not self._run:
return None # Connection closed
return None # Connection closed
self._query(ROLLBACK)
def run(self):
# The whole chdir thing is because sqlite doesn't handle directory names with non-ascii chars in them AT ALL.
oldpath = os.getcwd()
@@ -111,31 +111,31 @@ class _ActualThread(threading.Thread):
result = e
self._results.put(result)
con.close()
class ThreadedConn:
"""``sqlite`` connections can't be used across threads. ``TheadedConn`` opens a sqlite
connection in its own thread and sends it queries through a queue, making it suitable in
multi-threaded environment.
"""
def __init__(self, dbname, autocommit):
self._t = _ActualThread(dbname, autocommit)
self.lastrowid = -1
def __del__(self):
self.close()
def close(self):
self._t.close()
def commit(self):
self._t.commit()
def execute(self, sql, values=()):
result = self._t.execute(sql, values)
self.lastrowid = self._t.lastrowid
return result
def rollback(self):
self._t.rollback()
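# A hedged sketch of ThreadedConn; the database name, table and worker thread are
# illustrative.
import threading

conn = ThreadedConn("example.db", autocommit=False)
conn.execute("CREATE TABLE files (path TEXT)")

def worker():
    # Safe from another thread: every query goes through the internal queue.
    conn.execute("INSERT INTO files VALUES (?)", ("/tmp/foo",))

t = threading.Thread(target=worker)
t.start()
t.join()
conn.commit()
rows = conn.execute("SELECT path FROM files").fetchall()
conn.close()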

View File

@@ -2,103 +2,105 @@
# Created On: 2008-01-08
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from ..conflict import *
from ..path import Path
from ..testutil import eq_
class TestCase_GetConflictedName:
def test_simple(self):
name = get_conflicted_name(['bar'], 'bar')
eq_('[000] bar', name)
name = get_conflicted_name(['bar', '[000] bar'], 'bar')
eq_('[001] bar', name)
name = get_conflicted_name(["bar"], "bar")
eq_("[000] bar", name)
name = get_conflicted_name(["bar", "[000] bar"], "bar")
eq_("[001] bar", name)
def test_no_conflict(self):
name = get_conflicted_name(['bar'], 'foobar')
eq_('foobar', name)
name = get_conflicted_name(["bar"], "foobar")
eq_("foobar", name)
def test_fourth_digit(self):
# This test is long because every time we have to add a conflicted name,
# a test must be made for every other conflicted name existing...
# Anyway, this is very unlikely to happen.
names = ['bar'] + ['[%03d] bar' % i for i in range(1000)]
name = get_conflicted_name(names, 'bar')
eq_('[1000] bar', name)
names = ["bar"] + ["[%03d] bar" % i for i in range(1000)]
name = get_conflicted_name(names, "bar")
eq_("[1000] bar", name)
def test_auto_unconflict(self):
# Automatically unconflict the name if it's already conflicted.
name = get_conflicted_name([], '[000] foobar')
eq_('foobar', name)
name = get_conflicted_name(['bar'], '[001] bar')
eq_('[000] bar', name)
name = get_conflicted_name([], "[000] foobar")
eq_("foobar", name)
name = get_conflicted_name(["bar"], "[001] bar")
eq_("[000] bar", name)
class TestCase_GetUnconflictedName:
def test_main(self):
eq_('foobar',get_unconflicted_name('[000] foobar'))
eq_('foobar',get_unconflicted_name('[9999] foobar'))
eq_('[000]foobar',get_unconflicted_name('[000]foobar'))
eq_('[000a] foobar',get_unconflicted_name('[000a] foobar'))
eq_('foobar',get_unconflicted_name('foobar'))
eq_('foo [000] bar',get_unconflicted_name('foo [000] bar'))
eq_("foobar", get_unconflicted_name("[000] foobar"))
eq_("foobar", get_unconflicted_name("[9999] foobar"))
eq_("[000]foobar", get_unconflicted_name("[000]foobar"))
eq_("[000a] foobar", get_unconflicted_name("[000a] foobar"))
eq_("foobar", get_unconflicted_name("foobar"))
eq_("foo [000] bar", get_unconflicted_name("foo [000] bar"))
class TestCase_IsConflicted:
def test_main(self):
assert is_conflicted('[000] foobar')
assert is_conflicted('[9999] foobar')
assert not is_conflicted('[000]foobar')
assert not is_conflicted('[000a] foobar')
assert not is_conflicted('foobar')
assert not is_conflicted('foo [000] bar')
assert is_conflicted("[000] foobar")
assert is_conflicted("[9999] foobar")
assert not is_conflicted("[000]foobar")
assert not is_conflicted("[000a] foobar")
assert not is_conflicted("foobar")
assert not is_conflicted("foo [000] bar")
class TestCase_move_copy:
def pytest_funcarg__do_setup(self, request):
tmpdir = request.getfuncargvalue('tmpdir')
tmpdir = request.getfuncargvalue("tmpdir")
self.path = Path(str(tmpdir))
self.path['foo'].open('w').close()
self.path['bar'].open('w').close()
self.path['dir'].mkdir()
self.path["foo"].open("w").close()
self.path["bar"].open("w").close()
self.path["dir"].mkdir()
def test_move_no_conflict(self, do_setup):
smart_move(self.path + 'foo', self.path + 'baz')
assert self.path['baz'].exists()
assert not self.path['foo'].exists()
def test_copy_no_conflict(self, do_setup): # No need to duplicate the rest of the tests... Let's just test on move
smart_copy(self.path + 'foo', self.path + 'baz')
assert self.path['baz'].exists()
assert self.path['foo'].exists()
smart_move(self.path + "foo", self.path + "baz")
assert self.path["baz"].exists()
assert not self.path["foo"].exists()
def test_copy_no_conflict(
self, do_setup
): # No need to duplicate the rest of the tests... Let's just test on move
smart_copy(self.path + "foo", self.path + "baz")
assert self.path["baz"].exists()
assert self.path["foo"].exists()
def test_move_no_conflict_dest_is_dir(self, do_setup):
smart_move(self.path + 'foo', self.path + 'dir')
assert self.path['dir']['foo'].exists()
assert not self.path['foo'].exists()
smart_move(self.path + "foo", self.path + "dir")
assert self.path["dir"]["foo"].exists()
assert not self.path["foo"].exists()
def test_move_conflict(self, do_setup):
smart_move(self.path + 'foo', self.path + 'bar')
assert self.path['[000] bar'].exists()
assert not self.path['foo'].exists()
smart_move(self.path + "foo", self.path + "bar")
assert self.path["[000] bar"].exists()
assert not self.path["foo"].exists()
def test_move_conflict_dest_is_dir(self, do_setup):
smart_move(self.path['foo'], self.path['dir'])
smart_move(self.path['bar'], self.path['foo'])
smart_move(self.path['foo'], self.path['dir'])
assert self.path['dir']['foo'].exists()
assert self.path['dir']['[000] foo'].exists()
assert not self.path['foo'].exists()
assert not self.path['bar'].exists()
smart_move(self.path["foo"], self.path["dir"])
smart_move(self.path["bar"], self.path["foo"])
smart_move(self.path["foo"], self.path["dir"])
assert self.path["dir"]["foo"].exists()
assert self.path["dir"]["[000] foo"].exists()
assert not self.path["foo"].exists()
assert not self.path["bar"].exists()
def test_copy_folder(self, tmpdir):
# smart_copy also works on folders
path = Path(str(tmpdir))
path['foo'].mkdir()
path['bar'].mkdir()
smart_copy(path['foo'], path['bar']) # no crash
assert path['[000] bar'].exists()
path["foo"].mkdir()
path["bar"].mkdir()
smart_copy(path["foo"], path["bar"]) # no crash
assert path["[000] bar"].exists()

View File

@@ -1,12 +1,13 @@
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from ..testutil import eq_
from ..notify import Broadcaster, Listener, Repeater
class HelloListener(Listener):
def __init__(self, broadcaster):
Listener.__init__(self, broadcaster)
@@ -15,6 +16,7 @@ class HelloListener(Listener):
def hello(self):
self.hello_count += 1
class HelloRepeater(Repeater):
def __init__(self, broadcaster):
Repeater.__init__(self, broadcaster)
@@ -23,13 +25,15 @@ class HelloRepeater(Repeater):
def hello(self):
self.hello_count += 1
def create_pair():
b = Broadcaster()
l = HelloListener(b)
return b, l
def test_disconnect_during_notification():
# When a listener disconnects another listener the other listener will not receive a
# When a listener disconnects another listener the other listener will not receive a
# notification.
# This whole complicated scheme below is because the order of notification is not
# guaranteed. We could disconnect everything from self.broadcaster.listeners, but this
@@ -38,103 +42,116 @@ def test_disconnect_during_notification():
def __init__(self, broadcaster):
Listener.__init__(self, broadcaster)
self.hello_count = 0
def hello(self):
self.hello_count += 1
self.other.disconnect()
broadcaster = Broadcaster()
first = Disconnecter(broadcaster)
second = Disconnecter(broadcaster)
first.other, second.other = second, first
first.connect()
second.connect()
broadcaster.notify('hello')
broadcaster.notify("hello")
# only one of them was notified
eq_(first.hello_count + second.hello_count, 1)
def test_disconnect():
# After a disconnect, the listener doesn't hear anything.
b, l = create_pair()
l.connect()
l.disconnect()
b.notify('hello')
b.notify("hello")
eq_(l.hello_count, 0)
def test_disconnect_when_not_connected():
# When disconnecting an already disconnected listener, nothing happens.
b, l = create_pair()
l.disconnect()
def test_not_connected_on_init():
# A listener is not initialized connected.
b, l = create_pair()
b.notify('hello')
b.notify("hello")
eq_(l.hello_count, 0)
def test_notify():
# The listener listens to the broadcaster.
b, l = create_pair()
l.connect()
b.notify('hello')
b.notify("hello")
eq_(l.hello_count, 1)
def test_reconnect():
# It's possible to reconnect a listener after disconnection.
b, l = create_pair()
l.connect()
l.disconnect()
l.connect()
b.notify('hello')
b.notify("hello")
eq_(l.hello_count, 1)
def test_repeater():
b = Broadcaster()
r = HelloRepeater(b)
l = HelloListener(r)
r.connect()
l.connect()
b.notify('hello')
b.notify("hello")
eq_(r.hello_count, 1)
eq_(l.hello_count, 1)
def test_repeater_with_repeated_notifications():
# If REPEATED_NOTIFICATIONS is not empty, only notifs in this set are repeated (but they're
# still dispatched locally).
class MyRepeater(HelloRepeater):
REPEATED_NOTIFICATIONS = set(['hello'])
REPEATED_NOTIFICATIONS = set(["hello"])
def __init__(self, broadcaster):
HelloRepeater.__init__(self, broadcaster)
self.foo_count = 0
def foo(self):
self.foo_count += 1
b = Broadcaster()
r = MyRepeater(b)
l = HelloListener(r)
r.connect()
l.connect()
b.notify('hello')
b.notify('foo') # if the repeater repeated this notif, we'd get a crash on HelloListener
b.notify("hello")
b.notify(
"foo"
) # if the repeater repeated this notif, we'd get a crash on HelloListener
eq_(r.hello_count, 1)
eq_(l.hello_count, 1)
eq_(r.foo_count, 1)
def test_repeater_doesnt_try_to_dispatch_to_self_if_it_cant():
# if a repeater doesn't handle a particular message, it doesn't crash and simply repeats it.
b = Broadcaster()
r = Repeater(b) # doesnt handle hello
r = Repeater(b) # doesnt handle hello
l = HelloListener(r)
r.connect()
l.connect()
b.notify('hello') # no crash
b.notify("hello") # no crash
eq_(l.hello_count, 1)
def test_bind_messages():
b, l = create_pair()
l.bind_messages({'foo', 'bar'}, l.hello)
l.bind_messages({"foo", "bar"}, l.hello)
l.connect()
b.notify('foo')
b.notify('bar')
b.notify('hello') # Normal dispatching still work
b.notify("foo")
b.notify("bar")
b.notify("hello") # Normal dispatching still work
eq_(l.hello_count, 3)

View File

@@ -2,8 +2,8 @@
# Created On: 2006/02/21
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import sys
@@ -14,33 +14,39 @@ from pytest import raises, mark
from ..path import Path, pathify
from ..testutil import eq_
def pytest_funcarg__force_ossep(request):
monkeypatch = request.getfuncargvalue('monkeypatch')
monkeypatch.setattr(os, 'sep', '/')
monkeypatch = request.getfuncargvalue("monkeypatch")
monkeypatch.setattr(os, "sep", "/")
def test_empty(force_ossep):
path = Path('')
eq_('',str(path))
eq_(0,len(path))
path = Path("")
eq_("", str(path))
eq_(0, len(path))
path = Path(())
eq_('',str(path))
eq_(0,len(path))
eq_("", str(path))
eq_(0, len(path))
def test_single(force_ossep):
path = Path('foobar')
eq_('foobar',path)
eq_(1,len(path))
path = Path("foobar")
eq_("foobar", path)
eq_(1, len(path))
def test_multiple(force_ossep):
path = Path('foo/bar')
eq_('foo/bar',path)
eq_(2,len(path))
path = Path("foo/bar")
eq_("foo/bar", path)
eq_(2, len(path))
def test_init_with_tuple_and_list(force_ossep):
path = Path(('foo','bar'))
eq_('foo/bar',path)
path = Path(['foo','bar'])
eq_('foo/bar',path)
path = Path(("foo", "bar"))
eq_("foo/bar", path)
path = Path(["foo", "bar"])
eq_("foo/bar", path)
def test_init_with_invalid_value(force_ossep):
try:
@@ -49,208 +55,236 @@ def test_init_with_invalid_value(force_ossep):
except TypeError:
pass
def test_access(force_ossep):
path = Path('foo/bar/bleh')
eq_('foo',path[0])
eq_('foo',path[-3])
eq_('bar',path[1])
eq_('bar',path[-2])
eq_('bleh',path[2])
eq_('bleh',path[-1])
path = Path("foo/bar/bleh")
eq_("foo", path[0])
eq_("foo", path[-3])
eq_("bar", path[1])
eq_("bar", path[-2])
eq_("bleh", path[2])
eq_("bleh", path[-1])
def test_slicing(force_ossep):
path = Path('foo/bar/bleh')
path = Path("foo/bar/bleh")
subpath = path[:2]
eq_('foo/bar',subpath)
assert isinstance(subpath,Path)
def test_parent(force_ossep):
path = Path('foo/bar/bleh')
subpath = path.parent()
eq_('foo/bar', subpath)
eq_("foo/bar", subpath)
assert isinstance(subpath, Path)
def test_parent(force_ossep):
path = Path("foo/bar/bleh")
subpath = path.parent()
eq_("foo/bar", subpath)
assert isinstance(subpath, Path)
def test_filename(force_ossep):
path = Path('foo/bar/bleh.ext')
eq_(path.name, 'bleh.ext')
path = Path("foo/bar/bleh.ext")
eq_(path.name, "bleh.ext")
def test_deal_with_empty_components(force_ossep):
"""Keep ONLY a leading space, which means we want a leading slash.
"""
eq_('foo//bar',str(Path(('foo','','bar'))))
eq_('/foo/bar',str(Path(('','foo','bar'))))
eq_('foo/bar',str(Path('foo/bar/')))
eq_("foo//bar", str(Path(("foo", "", "bar"))))
eq_("/foo/bar", str(Path(("", "foo", "bar"))))
eq_("foo/bar", str(Path("foo/bar/")))
def test_old_compare_paths(force_ossep):
eq_(Path('foobar'),Path('foobar'))
eq_(Path('foobar/'),Path('foobar\\','\\'))
eq_(Path('/foobar/'),Path('\\foobar\\','\\'))
eq_(Path('/foo/bar'),Path('\\foo\\bar','\\'))
eq_(Path('/foo/bar'),Path('\\foo\\bar\\','\\'))
assert Path('/foo/bar') != Path('\\foo\\foo','\\')
#We also have to test __ne__
assert not (Path('foobar') != Path('foobar'))
assert Path('/a/b/c.x') != Path('/a/b/c.y')
eq_(Path("foobar"), Path("foobar"))
eq_(Path("foobar/"), Path("foobar\\", "\\"))
eq_(Path("/foobar/"), Path("\\foobar\\", "\\"))
eq_(Path("/foo/bar"), Path("\\foo\\bar", "\\"))
eq_(Path("/foo/bar"), Path("\\foo\\bar\\", "\\"))
assert Path("/foo/bar") != Path("\\foo\\foo", "\\")
# We also have to test __ne__
assert not (Path("foobar") != Path("foobar"))
assert Path("/a/b/c.x") != Path("/a/b/c.y")
def test_old_split_path(force_ossep):
eq_(Path('foobar'),('foobar',))
eq_(Path('foo/bar'),('foo','bar'))
eq_(Path('/foo/bar/'),('','foo','bar'))
eq_(Path('\\foo\\bar','\\'),('','foo','bar'))
eq_(Path("foobar"), ("foobar",))
eq_(Path("foo/bar"), ("foo", "bar"))
eq_(Path("/foo/bar/"), ("", "foo", "bar"))
eq_(Path("\\foo\\bar", "\\"), ("", "foo", "bar"))
def test_representation(force_ossep):
eq_("('foo', 'bar')",repr(Path(('foo','bar'))))
eq_("('foo', 'bar')", repr(Path(("foo", "bar"))))
def test_add(force_ossep):
eq_('foo/bar/bar/foo',Path(('foo','bar')) + Path('bar/foo'))
eq_('foo/bar/bar/foo',Path('foo/bar') + 'bar/foo')
eq_('foo/bar/bar/foo',Path('foo/bar') + ('bar','foo'))
eq_('foo/bar/bar/foo',('foo','bar') + Path('bar/foo'))
eq_('foo/bar/bar/foo','foo/bar' + Path('bar/foo'))
#Invalid concatenation
eq_("foo/bar/bar/foo", Path(("foo", "bar")) + Path("bar/foo"))
eq_("foo/bar/bar/foo", Path("foo/bar") + "bar/foo")
eq_("foo/bar/bar/foo", Path("foo/bar") + ("bar", "foo"))
eq_("foo/bar/bar/foo", ("foo", "bar") + Path("bar/foo"))
eq_("foo/bar/bar/foo", "foo/bar" + Path("bar/foo"))
# Invalid concatenation
try:
Path(('foo','bar')) + 1
Path(("foo", "bar")) + 1
assert False
except TypeError:
pass
def test_path_slice(force_ossep):
foo = Path('foo')
bar = Path('bar')
foobar = Path('foo/bar')
eq_('bar',foobar[foo:])
eq_('foo',foobar[:bar])
eq_('foo/bar',foobar[bar:])
eq_('foo/bar',foobar[:foo])
eq_((),foobar[foobar:])
eq_((),foobar[:foobar])
abcd = Path('a/b/c/d')
a = Path('a')
b = Path('b')
c = Path('c')
d = Path('d')
z = Path('z')
eq_('b/c',abcd[a:d])
eq_('b/c/d',abcd[a:d+z])
eq_('b/c',abcd[a:z+d])
eq_('a/b/c/d',abcd[:z])
foo = Path("foo")
bar = Path("bar")
foobar = Path("foo/bar")
eq_("bar", foobar[foo:])
eq_("foo", foobar[:bar])
eq_("foo/bar", foobar[bar:])
eq_("foo/bar", foobar[:foo])
eq_((), foobar[foobar:])
eq_((), foobar[:foobar])
abcd = Path("a/b/c/d")
a = Path("a")
b = Path("b")
c = Path("c")
d = Path("d")
z = Path("z")
eq_("b/c", abcd[a:d])
eq_("b/c/d", abcd[a : d + z])
eq_("b/c", abcd[a : z + d])
eq_("a/b/c/d", abcd[:z])
def test_add_with_root_path(force_ossep):
"""if I perform /a/b/c + /d/e/f, I want /a/b/c/d/e/f, not /a/b/c//d/e/f
"""
eq_('/foo/bar',str(Path('/foo') + Path('/bar')))
eq_("/foo/bar", str(Path("/foo") + Path("/bar")))
def test_create_with_tuple_that_have_slash_inside(force_ossep, monkeypatch):
eq_(('','foo','bar'), Path(('/foo','bar')))
monkeypatch.setattr(os, 'sep', '\\')
eq_(('','foo','bar'), Path(('\\foo','bar')))
eq_(("", "foo", "bar"), Path(("/foo", "bar")))
monkeypatch.setattr(os, "sep", "\\")
eq_(("", "foo", "bar"), Path(("\\foo", "bar")))
def test_auto_decode_os_sep(force_ossep, monkeypatch):
"""Path should decode any either / or os.sep, but always encode in os.sep.
"""
eq_(('foo\\bar','bleh'),Path('foo\\bar/bleh'))
monkeypatch.setattr(os, 'sep', '\\')
eq_(('foo','bar/bleh'),Path('foo\\bar/bleh'))
path = Path('foo/bar')
eq_(('foo','bar'),path)
eq_('foo\\bar',str(path))
eq_(("foo\\bar", "bleh"), Path("foo\\bar/bleh"))
monkeypatch.setattr(os, "sep", "\\")
eq_(("foo", "bar/bleh"), Path("foo\\bar/bleh"))
path = Path("foo/bar")
eq_(("foo", "bar"), path)
eq_("foo\\bar", str(path))
def test_contains(force_ossep):
p = Path(('foo','bar'))
assert Path(('foo','bar','bleh')) in p
assert Path(('foo','bar')) in p
assert 'foo' in p
assert 'bleh' not in p
assert Path('foo') not in p
p = Path(("foo", "bar"))
assert Path(("foo", "bar", "bleh")) in p
assert Path(("foo", "bar")) in p
assert "foo" in p
assert "bleh" not in p
assert Path("foo") not in p
def test_is_parent_of(force_ossep):
assert Path(('foo','bar')).is_parent_of(Path(('foo','bar','bleh')))
assert not Path(('foo','bar')).is_parent_of(Path(('foo','baz')))
assert not Path(('foo','bar')).is_parent_of(Path(('foo','bar')))
assert Path(("foo", "bar")).is_parent_of(Path(("foo", "bar", "bleh")))
assert not Path(("foo", "bar")).is_parent_of(Path(("foo", "baz")))
assert not Path(("foo", "bar")).is_parent_of(Path(("foo", "bar")))
def test_windows_drive_letter(force_ossep):
p = Path(('c:',))
eq_('c:\\',str(p))
p = Path(("c:",))
eq_("c:\\", str(p))
def test_root_path(force_ossep):
p = Path('/')
eq_('/',str(p))
p = Path("/")
eq_("/", str(p))
def test_str_encodes_unicode_to_getfilesystemencoding(force_ossep):
p = Path(('foo','bar\u00e9'))
eq_('foo/bar\u00e9'.encode(sys.getfilesystemencoding()), p.tobytes())
p = Path(("foo", "bar\u00e9"))
eq_("foo/bar\u00e9".encode(sys.getfilesystemencoding()), p.tobytes())
def test_unicode(force_ossep):
p = Path(('foo','bar\u00e9'))
eq_('foo/bar\u00e9',str(p))
p = Path(("foo", "bar\u00e9"))
eq_("foo/bar\u00e9", str(p))
def test_str_repr_of_mix_between_non_ascii_str_and_unicode(force_ossep):
u = 'foo\u00e9'
u = "foo\u00e9"
encoded = u.encode(sys.getfilesystemencoding())
p = Path((encoded,'bar'))
p = Path((encoded, "bar"))
print(repr(tuple(p)))
eq_('foo\u00e9/bar'.encode(sys.getfilesystemencoding()), p.tobytes())
eq_("foo\u00e9/bar".encode(sys.getfilesystemencoding()), p.tobytes())
def test_Path_of_a_Path_returns_self(force_ossep):
#if Path() is called with a path as value, just return value.
p = Path('foo/bar')
# if Path() is called with a path as value, just return value.
p = Path("foo/bar")
assert Path(p) is p
def test_getitem_str(force_ossep):
# path['something'] returns the child path corresponding to the name
p = Path('/foo/bar')
eq_(p['baz'], Path('/foo/bar/baz'))
p = Path("/foo/bar")
eq_(p["baz"], Path("/foo/bar/baz"))
def test_getitem_path(force_ossep):
# path[Path('something')] returns the child path corresponding to the name (or subpath)
p = Path('/foo/bar')
eq_(p[Path('baz/bleh')], Path('/foo/bar/baz/bleh'))
p = Path("/foo/bar")
eq_(p[Path("baz/bleh")], Path("/foo/bar/baz/bleh"))
@mark.xfail(reason="pytest's capture mechanism is flaky, I have to investigate")
def test_log_unicode_errors(force_ossep, monkeypatch, capsys):
# When an there's a UnicodeDecodeError on path creation, log it so it can be possible
# to debug the cause of it.
monkeypatch.setattr(sys, 'getfilesystemencoding', lambda: 'ascii')
monkeypatch.setattr(sys, "getfilesystemencoding", lambda: "ascii")
with raises(UnicodeDecodeError):
Path(['', b'foo\xe9'])
Path(["", b"foo\xe9"])
out, err = capsys.readouterr()
assert repr(b'foo\xe9') in err
assert repr(b"foo\xe9") in err
def test_has_drive_letter(monkeypatch):
monkeypatch.setattr(os, 'sep', '\\')
p = Path('foo\\bar')
monkeypatch.setattr(os, "sep", "\\")
p = Path("foo\\bar")
assert not p.has_drive_letter()
p = Path('C:\\')
p = Path("C:\\")
assert p.has_drive_letter()
p = Path('z:\\foo')
p = Path("z:\\foo")
assert p.has_drive_letter()
def test_remove_drive_letter(monkeypatch):
monkeypatch.setattr(os, 'sep', '\\')
p = Path('foo\\bar')
eq_(p.remove_drive_letter(), Path('foo\\bar'))
p = Path('C:\\')
eq_(p.remove_drive_letter(), Path(''))
p = Path('z:\\foo')
eq_(p.remove_drive_letter(), Path('foo'))
monkeypatch.setattr(os, "sep", "\\")
p = Path("foo\\bar")
eq_(p.remove_drive_letter(), Path("foo\\bar"))
p = Path("C:\\")
eq_(p.remove_drive_letter(), Path(""))
p = Path("z:\\foo")
eq_(p.remove_drive_letter(), Path("foo"))
def test_pathify():
@pathify
def foo(a: Path, b, c:Path):
def foo(a: Path, b, c: Path):
return a, b, c
a, b, c = foo('foo', 0, c=Path('bar'))
a, b, c = foo("foo", 0, c=Path("bar"))
assert isinstance(a, Path)
assert a == Path('foo')
assert a == Path("foo")
assert b == 0
assert isinstance(c, Path)
assert c == Path('bar')
assert c == Path("bar")
def test_pathify_preserve_none():
# @pathify preserves None value and doesn't try to return a Path
@pathify
def foo(a: Path):
return a
a = foo(None)
assert a is None

View File

@@ -1,14 +1,15 @@
# Created By: Virgil Dupras
# Created On: 2011-09-06
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from ..testutil import eq_, callcounter, CallLogger
from ..gui.selectable_list import SelectableList, GUISelectableList
def test_in():
# When a SelectableList is in a list, doing "in list" with another instance returns false, even
# if they're the same as lists.
@@ -16,50 +17,56 @@ def test_in():
some_list = [sl]
assert SelectableList() not in some_list
def test_selection_range():
# selection is correctly adjusted on deletion
sl = SelectableList(['foo', 'bar', 'baz'])
sl = SelectableList(["foo", "bar", "baz"])
sl.selected_index = 3
eq_(sl.selected_index, 2)
del sl[2]
eq_(sl.selected_index, 1)
def test_update_selection_called():
# _update_selection_is called after a change in selection. However, we only do so on select()
# calls. I follow the old behavior of the Table class. At the moment, I don't quite remember
# why there was a specific select() method for triggering _update_selection(), but I think I
# remember there was a reason, so I keep it that way.
sl = SelectableList(['foo', 'bar'])
sl = SelectableList(["foo", "bar"])
sl._update_selection = callcounter()
sl.select(1)
eq_(sl._update_selection.callcount, 1)
sl.selected_index = 0
eq_(sl._update_selection.callcount, 1) # no call
eq_(sl._update_selection.callcount, 1) # no call
def test_guicalls():
# A GUISelectableList appropriately calls its view.
sl = GUISelectableList(['foo', 'bar'])
sl = GUISelectableList(["foo", "bar"])
sl.view = CallLogger()
sl.view.check_gui_calls(['refresh']) # Upon setting the view, we get a call to refresh()
sl[1] = 'baz'
sl.view.check_gui_calls(['refresh'])
sl.append('foo')
sl.view.check_gui_calls(['refresh'])
sl.view.check_gui_calls(
["refresh"]
) # Upon setting the view, we get a call to refresh()
sl[1] = "baz"
sl.view.check_gui_calls(["refresh"])
sl.append("foo")
sl.view.check_gui_calls(["refresh"])
del sl[2]
sl.view.check_gui_calls(['refresh'])
sl.remove('baz')
sl.view.check_gui_calls(['refresh'])
sl.insert(0, 'foo')
sl.view.check_gui_calls(['refresh'])
sl.view.check_gui_calls(["refresh"])
sl.remove("baz")
sl.view.check_gui_calls(["refresh"])
sl.insert(0, "foo")
sl.view.check_gui_calls(["refresh"])
sl.select(1)
sl.view.check_gui_calls(['update_selection'])
sl.view.check_gui_calls(["update_selection"])
# XXX We have to give up on this for now because of a breakage it causes in the tables.
# sl.select(1) # don't update when selection stays the same
# gui.check_gui_calls([])
def test_search_by_prefix():
sl = SelectableList(['foo', 'bAr', 'baZ'])
eq_(sl.search_by_prefix('b'), 1)
eq_(sl.search_by_prefix('BA'), 1)
eq_(sl.search_by_prefix('BAZ'), 2)
eq_(sl.search_by_prefix('BAZZ'), -1)
sl = SelectableList(["foo", "bAr", "baZ"])
eq_(sl.search_by_prefix("b"), 1)
eq_(sl.search_by_prefix("BA"), 1)
eq_(sl.search_by_prefix("BAZ"), 2)
eq_(sl.search_by_prefix("BAZZ"), -1)

View File

@@ -2,8 +2,8 @@
# Created On: 2007/05/19
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import time
@@ -19,69 +19,75 @@ from ..sqlite import ThreadedConn
# Threading is hard to test. In a lot of those tests, a failure means that the test run will
# hang forever. Well... I don't know a better alternative.
def test_can_access_from_multiple_threads():
def run():
con.execute('insert into foo(bar) values(\'baz\')')
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con.execute("insert into foo(bar) values('baz')")
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
t = threading.Thread(target=run)
t.start()
t.join()
result = con.execute('select * from foo')
result = con.execute("select * from foo")
eq_(1, len(result))
eq_('baz', result[0][0])
eq_("baz", result[0][0])
def test_exception_during_query():
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
with raises(sqlite.OperationalError):
con.execute('select * from bleh')
con.execute("select * from bleh")
def test_not_autocommit(tmpdir):
dbpath = str(tmpdir.join('foo.db'))
dbpath = str(tmpdir.join("foo.db"))
con = ThreadedConn(dbpath, False)
con.execute('create table foo(bar TEXT)')
con.execute('insert into foo(bar) values(\'baz\')')
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
del con
#The data shouldn't have been inserted
# The data shouldn't have been inserted
con = ThreadedConn(dbpath, False)
result = con.execute('select * from foo')
result = con.execute("select * from foo")
eq_(0, len(result))
con.execute('insert into foo(bar) values(\'baz\')')
con.execute("insert into foo(bar) values('baz')")
con.commit()
del con
# Now the data should be there
con = ThreadedConn(dbpath, False)
result = con.execute('select * from foo')
result = con.execute("select * from foo")
eq_(1, len(result))
def test_rollback():
con = ThreadedConn(':memory:', False)
con.execute('create table foo(bar TEXT)')
con.execute('insert into foo(bar) values(\'baz\')')
con = ThreadedConn(":memory:", False)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
con.rollback()
result = con.execute('select * from foo')
result = con.execute("select * from foo")
eq_(0, len(result))
def test_query_palceholders():
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con.execute('insert into foo(bar) values(?)', ['baz'])
result = con.execute('select * from foo')
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values(?)", ["baz"])
result = con.execute("select * from foo")
eq_(1, len(result))
eq_('baz', result[0][0])
eq_("baz", result[0][0])
def test_make_sure_theres_no_messup_between_queries():
def run(expected_rowid):
time.sleep(0.1)
result = con.execute('select rowid from foo where rowid = ?', [expected_rowid])
result = con.execute("select rowid from foo where rowid = ?", [expected_rowid])
assert expected_rowid == result[0][0]
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
for i in range(100):
con.execute('insert into foo(bar) values(\'baz\')')
con.execute("insert into foo(bar) values('baz')")
threads = []
for i in range(1, 101):
t = threading.Thread(target=run, args=(i,))
@@ -91,36 +97,41 @@ def test_make_sure_theres_no_messup_between_queries():
time.sleep(0.1)
threads = [t for t in threads if t.isAlive()]
def test_query_after_close():
con = ThreadedConn(':memory:', True)
con = ThreadedConn(":memory:", True)
con.close()
con.execute('select 1')
con.execute("select 1")
def test_lastrowid():
# It's not possible to return a cursor because of the threading, but lastrowid should be
# fetchable from the connection itself
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con.execute('insert into foo(bar) values(\'baz\')')
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
eq_(1, con.lastrowid)
def test_add_fetchone_fetchall_interface_to_results():
con = ThreadedConn(':memory:', True)
con.execute('create table foo(bar TEXT)')
con.execute('insert into foo(bar) values(\'baz1\')')
con.execute('insert into foo(bar) values(\'baz2\')')
result = con.execute('select * from foo')
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz1')")
con.execute("insert into foo(bar) values('baz2')")
result = con.execute("select * from foo")
ref = result[:]
eq_(ref, result.fetchall())
eq_(ref[0], result.fetchone())
eq_(ref[1], result.fetchone())
assert result.fetchone() is None
def test_non_ascii_dbname(tmpdir):
ThreadedConn(str(tmpdir.join('foo\u00e9.db')), True)
ThreadedConn(str(tmpdir.join("foo\u00e9.db")), True)
def test_non_ascii_dbdir(tmpdir):
# when this test fails, it doesn't fail gracefully, it brings the whole test suite with it.
dbdir = tmpdir.join('foo\u00e9')
dbdir = tmpdir.join("foo\u00e9")
os.mkdir(str(dbdir))
ThreadedConn(str(dbdir.join('foo.db')), True)
ThreadedConn(str(dbdir.join("foo.db")), True)

View File

@@ -9,6 +9,7 @@
from ..testutil import CallLogger, eq_
from ..gui.table import Table, GUITable, Row
class TestRow(Row):
def __init__(self, table, index, is_new=False):
Row.__init__(self, table)
@@ -55,6 +56,7 @@ def table_with_footer():
table.footer = footer
return table, footer
def table_with_header():
table = Table()
table.append(TestRow(table, 1))
@@ -62,24 +64,28 @@ def table_with_header():
table.header = header
return table, header
#--- Tests
# --- Tests
def test_allow_edit_when_attr_is_property_with_fset():
# When a row has a property that has a fset, by default, make that cell editable.
class TestRow(Row):
@property
def foo(self):
pass
@property
def bar(self):
pass
@bar.setter
def bar(self, value):
pass
row = TestRow(Table())
assert row.can_edit_cell('bar')
assert not row.can_edit_cell('foo')
assert not row.can_edit_cell('baz') # doesn't exist, can't edit
assert row.can_edit_cell("bar")
assert not row.can_edit_cell("foo")
assert not row.can_edit_cell("baz") # doesn't exist, can't edit
def test_can_edit_prop_has_priority_over_fset_checks():
# When a row has a cen_edit_* property, it's the result of that property that is used, not the
@@ -88,13 +94,16 @@ def test_can_edit_prop_has_priority_over_fset_checks():
@property
def bar(self):
pass
@bar.setter
def bar(self, value):
pass
can_edit_bar = False
row = TestRow(Table())
assert not row.can_edit_cell('bar')
assert not row.can_edit_cell("bar")
def test_in():
# When a table is in a list, doing "in list" with another instance returns false, even if
@@ -103,12 +112,14 @@ def test_in():
some_list = [table]
assert Table() not in some_list
def test_footer_del_all():
# Removing all rows doesn't crash when doing the footer check.
table, footer = table_with_footer()
del table[:]
assert table.footer is None
def test_footer_del_row():
# Removing the footer row sets it to None
table, footer = table_with_footer()
@@ -116,18 +127,21 @@ def test_footer_del_row():
assert table.footer is None
eq_(len(table), 1)
def test_footer_is_appened_to_table():
# A footer is appended at the table's bottom
table, footer = table_with_footer()
eq_(len(table), 2)
assert table[1] is footer
def test_footer_remove():
# remove() on footer sets it to None
table, footer = table_with_footer()
table.remove(footer)
assert table.footer is None
def test_footer_replaces_old_footer():
table, footer = table_with_footer()
other = Row(table)
@@ -136,18 +150,21 @@ def test_footer_replaces_old_footer():
eq_(len(table), 2)
assert table[1] is other
def test_footer_rows_and_row_count():
# rows() and row_count() ignore footer.
table, footer = table_with_footer()
eq_(table.row_count, 1)
eq_(table.rows, table[:-1])
def test_footer_setting_to_none_removes_old_one():
table, footer = table_with_footer()
table.footer = None
assert table.footer is None
eq_(len(table), 1)
def test_footer_stays_there_on_append():
# Appending another row puts it above the footer
table, footer = table_with_footer()
@@ -155,6 +172,7 @@ def test_footer_stays_there_on_append():
eq_(len(table), 3)
assert table[2] is footer
def test_footer_stays_there_on_insert():
# Inserting another row puts it above the footer
table, footer = table_with_footer()
@@ -162,12 +180,14 @@ def test_footer_stays_there_on_insert():
eq_(len(table), 3)
assert table[2] is footer
def test_header_del_all():
# Removing all rows doesn't crash when doing the header check.
table, header = table_with_header()
del table[:]
assert table.header is None
def test_header_del_row():
# Removing the header row sets it to None
table, header = table_with_header()
@@ -175,18 +195,21 @@ def test_header_del_row():
assert table.header is None
eq_(len(table), 1)
def test_header_is_inserted_in_table():
# A header is inserted at the table's top
table, header = table_with_header()
eq_(len(table), 2)
assert table[0] is header
def test_header_remove():
# remove() on header sets it to None
table, header = table_with_header()
table.remove(header)
assert table.header is None
def test_header_replaces_old_header():
table, header = table_with_header()
other = Row(table)
@@ -195,18 +218,21 @@ def test_header_replaces_old_header():
eq_(len(table), 2)
assert table[0] is other
def test_header_rows_and_row_count():
# rows() and row_count() ignore header.
table, header = table_with_header()
eq_(table.row_count, 1)
eq_(table.rows, table[1:])
def test_header_setting_to_none_removes_old_one():
table, header = table_with_header()
table.header = None
assert table.header is None
eq_(len(table), 1)
def test_header_stays_there_on_insert():
# Inserting another row at the top puts it below the header
table, header = table_with_header()
@@ -214,21 +240,24 @@ def test_header_stays_there_on_insert():
eq_(len(table), 3)
assert table[0] is header
def test_refresh_view_on_refresh():
# If refresh_view is not False, we refresh the table's view on refresh()
table = TestGUITable(1)
table.refresh()
table.view.check_gui_calls(['refresh'])
table.view.check_gui_calls(["refresh"])
table.view.clear_calls()
table.refresh(refresh_view=False)
table.view.check_gui_calls([])
def test_restore_selection():
# By default, after a refresh, selection goes on the last row
table = TestGUITable(10)
table.refresh()
eq_(table.selected_indexes, [9])
def test_restore_selection_after_cancel_edits():
# _restore_selection() is called after cancel_edits(). Previously, only _update_selection would
# be called.
@@ -242,6 +271,7 @@ def test_restore_selection_after_cancel_edits():
table.cancel_edits()
eq_(table.selected_indexes, [6])
def test_restore_selection_with_previous_selection():
# By default, we try to restore the selection that was there before a refresh
table = TestGUITable(10)
@@ -250,6 +280,7 @@ def test_restore_selection_with_previous_selection():
table.refresh()
eq_(table.selected_indexes, [2, 4])
def test_restore_selection_custom():
# After a _fill() called, the virtual _restore_selection() is called so that it's possible for a
# GUITable subclass to customize its post-refresh selection behavior.
@@ -261,58 +292,64 @@ def test_restore_selection_custom():
table.refresh()
eq_(table.selected_indexes, [6])
def test_row_cell_value():
# *_cell_value() correctly mangles attrnames that are Python reserved words.
row = Row(Table())
row.from_ = 'foo'
eq_(row.get_cell_value('from'), 'foo')
row.set_cell_value('from', 'bar')
eq_(row.get_cell_value('from'), 'bar')
row.from_ = "foo"
eq_(row.get_cell_value("from"), "foo")
row.set_cell_value("from", "bar")
eq_(row.get_cell_value("from"), "bar")
def test_sort_table_also_tries_attributes_without_underscores():
# When determining a sort key, after having unsuccessfully tried the attribute with the,
# underscore, try the one without one.
table = Table()
row1 = Row(table)
row1._foo = 'a' # underscored attr must be checked first
row1.foo = 'b'
row1.bar = 'c'
row1._foo = "a" # underscored attr must be checked first
row1.foo = "b"
row1.bar = "c"
row2 = Row(table)
row2._foo = 'b'
row2.foo = 'a'
row2.bar = 'b'
row2._foo = "b"
row2.foo = "a"
row2.bar = "b"
table.append(row1)
table.append(row2)
table.sort_by('foo')
table.sort_by("foo")
assert table[0] is row1
assert table[1] is row2
table.sort_by('bar')
table.sort_by("bar")
assert table[0] is row2
assert table[1] is row1
def test_sort_table_updates_selection():
table = TestGUITable(10)
table.refresh()
table.select([2, 4])
table.sort_by('index', desc=True)
table.sort_by("index", desc=True)
# Now, the updated rows should be 7 and 5
eq_(len(table.updated_rows), 2)
r1, r2 = table.updated_rows
eq_(r1.index, 7)
eq_(r2.index, 5)
def test_sort_table_with_footer():
# Sorting a table with a footer keeps it at the bottom
table, footer = table_with_footer()
table.sort_by('index', desc=True)
table.sort_by("index", desc=True)
assert table[-1] is footer
def test_sort_table_with_header():
# Sorting a table with a header keeps it at the top
table, header = table_with_header()
table.sort_by('index', desc=True)
table.sort_by("index", desc=True)
assert table[0] is header
def test_add_with_view_that_saves_during_refresh():
# Calling save_edits during refresh() called by add() is ignored.
class TableView(CallLogger):
@@ -321,5 +358,4 @@ def test_add_with_view_that_saves_during_refresh():
table = TestGUITable(10, viewclass=TableView)
table.add()
assert table.edited is not None # still in edit mode
assert table.edited is not None # still in edit mode

View File

@@ -1,23 +1,25 @@
# Created By: Virgil Dupras
# Created On: 2010-02-12
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
#
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
from ..testutil import eq_
from ..gui.tree import Tree, Node
def tree_with_some_nodes():
t = Tree()
t.append(Node('foo'))
t.append(Node('bar'))
t.append(Node('baz'))
t[0].append(Node('sub1'))
t[0].append(Node('sub2'))
t.append(Node("foo"))
t.append(Node("bar"))
t.append(Node("baz"))
t[0].append(Node("sub1"))
t[0].append(Node("sub2"))
return t
def test_selection():
t = tree_with_some_nodes()
assert t.selected_node is None
@@ -25,6 +27,7 @@ def test_selection():
assert t.selected_path is None
eq_(t.selected_paths, [])
def test_select_one_node():
t = tree_with_some_nodes()
t.selected_node = t[0][0]
@@ -33,33 +36,39 @@ def test_select_one_node():
eq_(t.selected_path, [0, 0])
eq_(t.selected_paths, [[0, 0]])
def test_select_one_path():
t = tree_with_some_nodes()
t.selected_path = [0, 1]
assert t.selected_node is t[0][1]
def test_select_multiple_nodes():
t = tree_with_some_nodes()
t.selected_nodes = [t[0], t[1]]
eq_(t.selected_paths, [[0], [1]])
def test_select_multiple_paths():
t = tree_with_some_nodes()
t.selected_paths = [[0], [1]]
eq_(t.selected_nodes, [t[0], t[1]])
def test_select_none_path():
# setting selected_path to None clears the selection
t = Tree()
t.selected_path = None
assert t.selected_path is None
def test_select_none_node():
# setting selected_node to None clears the selection
t = Tree()
t.selected_node = None
eq_(t.selected_nodes, [])
def test_clear_removes_selection():
# When clearing a tree, we want to clear the selection as well or else we end up with a crash
# when calling selected_paths.
@@ -68,15 +77,16 @@ def test_clear_removes_selection():
t.clear()
assert t.selected_node is None
def test_selection_override():
# All selection changed pass through the _select_node() method so it's easy for subclasses to
# customize the tree's behavior.
class MyTree(Tree):
called = False
def _select_nodes(self, nodes):
self.called = True
t = MyTree()
t.selected_paths = []
assert t.called
@@ -84,26 +94,32 @@ def test_selection_override():
t.selected_node = None
assert t.called
def test_findall():
t = tree_with_some_nodes()
r = t.findall(lambda n: n.name.startswith('sub'))
r = t.findall(lambda n: n.name.startswith("sub"))
eq_(set(r), set([t[0][0], t[0][1]]))
def test_findall_dont_include_self():
# When calling findall with include_self=False, the node itself is never evaluated.
t = tree_with_some_nodes()
del t._name # so that if the predicate is called on `t`, we crash
r = t.findall(lambda n: not n.name.startswith('sub'), include_self=False) # no crash
del t._name # so that if the predicate is called on `t`, we crash
r = t.findall(
lambda n: not n.name.startswith("sub"), include_self=False
) # no crash
eq_(set(r), set([t[0], t[1], t[2]]))
def test_find_dont_include_self():
# When calling find with include_self=False, the node itself is never evaluated.
t = tree_with_some_nodes()
del t._name # so that if the predicate is called on `t`, we crash
r = t.find(lambda n: not n.name.startswith('sub'), include_self=False) # no crash
del t._name # so that if the predicate is called on `t`, we crash
r = t.find(lambda n: not n.name.startswith("sub"), include_self=False) # no crash
assert r is t[0]
def test_find_none():
# when find() yields no result, return None
t = Tree()
assert t.find(lambda n: False) is None # no StopIteration exception
assert t.find(lambda n: False) is None # no StopIteration exception

View File

@@ -14,43 +14,53 @@ from ..testutil import eq_
from ..path import Path
from ..util import *
def test_nonone():
eq_('foo', nonone('foo', 'bar'))
eq_('bar', nonone(None, 'bar'))
eq_("foo", nonone("foo", "bar"))
eq_("bar", nonone(None, "bar"))
def test_tryint():
eq_(42,tryint('42'))
eq_(0,tryint('abc'))
eq_(0,tryint(None))
eq_(42,tryint(None, 42))
eq_(42, tryint("42"))
eq_(0, tryint("abc"))
eq_(0, tryint(None))
eq_(42, tryint(None, 42))
def test_minmax():
eq_(minmax(2, 1, 3), 2)
eq_(minmax(0, 1, 3), 1)
eq_(minmax(4, 1, 3), 3)
#--- Sequence
# --- Sequence
def test_first():
eq_(first([3, 2, 1]), 3)
eq_(first(i for i in [3, 2, 1] if i < 3), 2)
def test_flatten():
eq_([1,2,3,4],flatten([[1,2],[3,4]]))
eq_([],flatten([]))
eq_([1, 2, 3, 4], flatten([[1, 2], [3, 4]]))
eq_([], flatten([]))
def test_dedupe():
reflist = [0,7,1,2,3,4,4,5,6,7,1,2,3]
eq_(dedupe(reflist),[0,7,1,2,3,4,5,6])
reflist = [0, 7, 1, 2, 3, 4, 4, 5, 6, 7, 1, 2, 3]
eq_(dedupe(reflist), [0, 7, 1, 2, 3, 4, 5, 6])
def test_stripfalse():
eq_([1, 2, 3], stripfalse([None, 0, 1, 2, 3, None]))
def test_extract():
wheat, shaft = extract(lambda n: n % 2 == 0, list(range(10)))
eq_(wheat, [0, 2, 4, 6, 8])
eq_(shaft, [1, 3, 5, 7, 9])
def test_allsame():
assert allsame([42, 42, 42])
assert not allsame([42, 43, 42])
@@ -58,25 +68,32 @@ def test_allsame():
# Works on non-sequence as well
assert allsame(iter([42, 42, 42]))
def test_trailiter():
eq_(list(trailiter([])), [])
eq_(list(trailiter(['foo'])), [(None, 'foo')])
eq_(list(trailiter(['foo', 'bar'])), [(None, 'foo'), ('foo', 'bar')])
eq_(list(trailiter(['foo', 'bar'], skipfirst=True)), [('foo', 'bar')])
eq_(list(trailiter([], skipfirst=True)), []) # no crash
eq_(list(trailiter(["foo"])), [(None, "foo")])
eq_(list(trailiter(["foo", "bar"])), [(None, "foo"), ("foo", "bar")])
eq_(list(trailiter(["foo", "bar"], skipfirst=True)), [("foo", "bar")])
eq_(list(trailiter([], skipfirst=True)), []) # no crash
def test_iterconsume():
# We just want to make sure that we return *all* items and that we're not mistakenly skipping
# one.
eq_(list(range(2500)), list(iterconsume(list(range(2500)))))
eq_(list(reversed(range(2500))), list(iterconsume(list(range(2500)), reverse=False)))
eq_(
list(reversed(range(2500))), list(iterconsume(list(range(2500)), reverse=False))
)
# --- String
#--- String
def test_escape():
eq_('f\\o\\ob\\ar', escape('foobar', 'oa'))
eq_('f*o*ob*ar', escape('foobar', 'oa', '*'))
eq_('f*o*ob*ar', escape('foobar', set('oa'), '*'))
eq_("f\\o\\ob\\ar", escape("foobar", "oa"))
eq_("f*o*ob*ar", escape("foobar", "oa", "*"))
eq_("f*o*ob*ar", escape("foobar", set("oa"), "*"))
def test_get_file_ext():
eq_(get_file_ext("foobar"), "")
@@ -84,146 +101,155 @@ def test_get_file_ext():
eq_(get_file_ext("foobar."), "")
eq_(get_file_ext(".foobar"), "foobar")
def test_rem_file_ext():
eq_(rem_file_ext("foobar"), "foobar")
eq_(rem_file_ext("foo.bar"), "foo")
eq_(rem_file_ext("foobar."), "foobar")
eq_(rem_file_ext(".foobar"), "")
def test_pluralize():
eq_('0 song', pluralize(0,'song'))
eq_('1 song', pluralize(1,'song'))
eq_('2 songs', pluralize(2,'song'))
eq_('1 song', pluralize(1.1,'song'))
eq_('2 songs', pluralize(1.5,'song'))
eq_('1.1 songs', pluralize(1.1,'song',1))
eq_('1.5 songs', pluralize(1.5,'song',1))
eq_('2 entries', pluralize(2,'entry', plural_word='entries'))
eq_("0 song", pluralize(0, "song"))
eq_("1 song", pluralize(1, "song"))
eq_("2 songs", pluralize(2, "song"))
eq_("1 song", pluralize(1.1, "song"))
eq_("2 songs", pluralize(1.5, "song"))
eq_("1.1 songs", pluralize(1.1, "song", 1))
eq_("1.5 songs", pluralize(1.5, "song", 1))
eq_("2 entries", pluralize(2, "entry", plural_word="entries"))
def test_format_time():
eq_(format_time(0),'00:00:00')
eq_(format_time(1),'00:00:01')
eq_(format_time(23),'00:00:23')
eq_(format_time(60),'00:01:00')
eq_(format_time(101),'00:01:41')
eq_(format_time(683),'00:11:23')
eq_(format_time(3600),'01:00:00')
eq_(format_time(3754),'01:02:34')
eq_(format_time(36000),'10:00:00')
eq_(format_time(366666),'101:51:06')
eq_(format_time(0, with_hours=False),'00:00')
eq_(format_time(1, with_hours=False),'00:01')
eq_(format_time(23, with_hours=False),'00:23')
eq_(format_time(60, with_hours=False),'01:00')
eq_(format_time(101, with_hours=False),'01:41')
eq_(format_time(683, with_hours=False),'11:23')
eq_(format_time(3600, with_hours=False),'60:00')
eq_(format_time(6036, with_hours=False),'100:36')
eq_(format_time(60360, with_hours=False),'1006:00')
eq_(format_time(0), "00:00:00")
eq_(format_time(1), "00:00:01")
eq_(format_time(23), "00:00:23")
eq_(format_time(60), "00:01:00")
eq_(format_time(101), "00:01:41")
eq_(format_time(683), "00:11:23")
eq_(format_time(3600), "01:00:00")
eq_(format_time(3754), "01:02:34")
eq_(format_time(36000), "10:00:00")
eq_(format_time(366666), "101:51:06")
eq_(format_time(0, with_hours=False), "00:00")
eq_(format_time(1, with_hours=False), "00:01")
eq_(format_time(23, with_hours=False), "00:23")
eq_(format_time(60, with_hours=False), "01:00")
eq_(format_time(101, with_hours=False), "01:41")
eq_(format_time(683, with_hours=False), "11:23")
eq_(format_time(3600, with_hours=False), "60:00")
eq_(format_time(6036, with_hours=False), "100:36")
eq_(format_time(60360, with_hours=False), "1006:00")
def test_format_time_decimal():
eq_(format_time_decimal(0), '0.0 second')
eq_(format_time_decimal(1), '1.0 second')
eq_(format_time_decimal(23), '23.0 seconds')
eq_(format_time_decimal(60), '1.0 minute')
eq_(format_time_decimal(101), '1.7 minutes')
eq_(format_time_decimal(683), '11.4 minutes')
eq_(format_time_decimal(3600), '1.0 hour')
eq_(format_time_decimal(6036), '1.7 hours')
eq_(format_time_decimal(86400), '1.0 day')
eq_(format_time_decimal(160360), '1.9 days')
eq_(format_time_decimal(0), "0.0 second")
eq_(format_time_decimal(1), "1.0 second")
eq_(format_time_decimal(23), "23.0 seconds")
eq_(format_time_decimal(60), "1.0 minute")
eq_(format_time_decimal(101), "1.7 minutes")
eq_(format_time_decimal(683), "11.4 minutes")
eq_(format_time_decimal(3600), "1.0 hour")
eq_(format_time_decimal(6036), "1.7 hours")
eq_(format_time_decimal(86400), "1.0 day")
eq_(format_time_decimal(160360), "1.9 days")
def test_format_size():
eq_(format_size(1024), '1 KB')
eq_(format_size(1024,2), '1.00 KB')
eq_(format_size(1024,0,2), '1 MB')
eq_(format_size(1024,2,2), '0.01 MB')
eq_(format_size(1024,3,2), '0.001 MB')
eq_(format_size(1024,3,2,False), '0.001')
eq_(format_size(1023), '1023 B')
eq_(format_size(1023,0,1), '1 KB')
eq_(format_size(511,0,1), '1 KB')
eq_(format_size(9), '9 B')
eq_(format_size(99), '99 B')
eq_(format_size(999), '999 B')
eq_(format_size(9999), '10 KB')
eq_(format_size(99999), '98 KB')
eq_(format_size(999999), '977 KB')
eq_(format_size(9999999), '10 MB')
eq_(format_size(99999999), '96 MB')
eq_(format_size(999999999), '954 MB')
eq_(format_size(9999999999), '10 GB')
eq_(format_size(99999999999), '94 GB')
eq_(format_size(999999999999), '932 GB')
eq_(format_size(9999999999999), '10 TB')
eq_(format_size(99999999999999), '91 TB')
eq_(format_size(999999999999999), '910 TB')
eq_(format_size(9999999999999999), '9 PB')
eq_(format_size(99999999999999999), '89 PB')
eq_(format_size(999999999999999999), '889 PB')
eq_(format_size(9999999999999999999), '9 EB')
eq_(format_size(99999999999999999999), '87 EB')
eq_(format_size(999999999999999999999), '868 EB')
eq_(format_size(9999999999999999999999), '9 ZB')
eq_(format_size(99999999999999999999999), '85 ZB')
eq_(format_size(999999999999999999999999), '848 ZB')
eq_(format_size(1024), "1 KB")
eq_(format_size(1024, 2), "1.00 KB")
eq_(format_size(1024, 0, 2), "1 MB")
eq_(format_size(1024, 2, 2), "0.01 MB")
eq_(format_size(1024, 3, 2), "0.001 MB")
eq_(format_size(1024, 3, 2, False), "0.001")
eq_(format_size(1023), "1023 B")
eq_(format_size(1023, 0, 1), "1 KB")
eq_(format_size(511, 0, 1), "1 KB")
eq_(format_size(9), "9 B")
eq_(format_size(99), "99 B")
eq_(format_size(999), "999 B")
eq_(format_size(9999), "10 KB")
eq_(format_size(99999), "98 KB")
eq_(format_size(999999), "977 KB")
eq_(format_size(9999999), "10 MB")
eq_(format_size(99999999), "96 MB")
eq_(format_size(999999999), "954 MB")
eq_(format_size(9999999999), "10 GB")
eq_(format_size(99999999999), "94 GB")
eq_(format_size(999999999999), "932 GB")
eq_(format_size(9999999999999), "10 TB")
eq_(format_size(99999999999999), "91 TB")
eq_(format_size(999999999999999), "910 TB")
eq_(format_size(9999999999999999), "9 PB")
eq_(format_size(99999999999999999), "89 PB")
eq_(format_size(999999999999999999), "889 PB")
eq_(format_size(9999999999999999999), "9 EB")
eq_(format_size(99999999999999999999), "87 EB")
eq_(format_size(999999999999999999999), "868 EB")
eq_(format_size(9999999999999999999999), "9 ZB")
eq_(format_size(99999999999999999999999), "85 ZB")
eq_(format_size(999999999999999999999999), "848 ZB")
def test_remove_invalid_xml():
eq_(remove_invalid_xml('foo\0bar\x0bbaz'), 'foo bar baz')
eq_(remove_invalid_xml("foo\0bar\x0bbaz"), "foo bar baz")
# surrogate blocks have to be replaced, but not the rest
eq_(remove_invalid_xml('foo\ud800bar\udfffbaz\ue000'), 'foo bar baz\ue000')
eq_(remove_invalid_xml("foo\ud800bar\udfffbaz\ue000"), "foo bar baz\ue000")
# replace with something else
eq_(remove_invalid_xml('foo\0baz', replace_with='bar'), 'foobarbaz')
eq_(remove_invalid_xml("foo\0baz", replace_with="bar"), "foobarbaz")
def test_multi_replace():
eq_('136',multi_replace('123456',('2','45')))
eq_('1 3 6',multi_replace('123456',('2','45'),' '))
eq_('1 3 6',multi_replace('123456','245',' '))
eq_('173896',multi_replace('123456','245','789'))
eq_('173896',multi_replace('123456','245',('7','8','9')))
eq_('17386',multi_replace('123456',('2','45'),'78'))
eq_('17386',multi_replace('123456',('2','45'),('7','8')))
eq_("136", multi_replace("123456", ("2", "45")))
eq_("1 3 6", multi_replace("123456", ("2", "45"), " "))
eq_("1 3 6", multi_replace("123456", "245", " "))
eq_("173896", multi_replace("123456", "245", "789"))
eq_("173896", multi_replace("123456", "245", ("7", "8", "9")))
eq_("17386", multi_replace("123456", ("2", "45"), "78"))
eq_("17386", multi_replace("123456", ("2", "45"), ("7", "8")))
with raises(ValueError):
multi_replace('123456',('2','45'),('7','8','9'))
eq_('17346',multi_replace('12346',('2','45'),'78'))
multi_replace("123456", ("2", "45"), ("7", "8", "9"))
eq_("17346", multi_replace("12346", ("2", "45"), "78"))
# --- Files
#--- Files
class TestCase_modified_after:
def test_first_is_modified_after(self, monkeyplus):
monkeyplus.patch_osstat('first', st_mtime=42)
monkeyplus.patch_osstat('second', st_mtime=41)
assert modified_after('first', 'second')
monkeyplus.patch_osstat("first", st_mtime=42)
monkeyplus.patch_osstat("second", st_mtime=41)
assert modified_after("first", "second")
def test_second_is_modified_after(self, monkeyplus):
monkeyplus.patch_osstat('first', st_mtime=42)
monkeyplus.patch_osstat('second', st_mtime=43)
assert not modified_after('first', 'second')
monkeyplus.patch_osstat("first", st_mtime=42)
monkeyplus.patch_osstat("second", st_mtime=43)
assert not modified_after("first", "second")
def test_same_mtime(self, monkeyplus):
monkeyplus.patch_osstat('first', st_mtime=42)
monkeyplus.patch_osstat('second', st_mtime=42)
assert not modified_after('first', 'second')
monkeyplus.patch_osstat("first", st_mtime=42)
monkeyplus.patch_osstat("second", st_mtime=42)
assert not modified_after("first", "second")
def test_first_file_does_not_exist(self, monkeyplus):
# when the first file doesn't exist, we return False
monkeyplus.patch_osstat('second', st_mtime=42)
assert not modified_after('does_not_exist', 'second') # no crash
monkeyplus.patch_osstat("second", st_mtime=42)
assert not modified_after("does_not_exist", "second") # no crash
def test_second_file_does_not_exist(self, monkeyplus):
# when the second file doesn't exist, we return True
monkeyplus.patch_osstat('first', st_mtime=42)
assert modified_after('first', 'does_not_exist') # no crash
monkeyplus.patch_osstat("first", st_mtime=42)
assert modified_after("first", "does_not_exist") # no crash
def test_first_file_is_none(self, monkeyplus):
# when the first file is None, we return False
monkeyplus.patch_osstat('second', st_mtime=42)
assert not modified_after(None, 'second') # no crash
monkeyplus.patch_osstat("second", st_mtime=42)
assert not modified_after(None, "second") # no crash
def test_second_file_is_none(self, monkeyplus):
# when the second file is None, we return True
monkeyplus.patch_osstat('first', st_mtime=42)
assert modified_after('first', None) # no crash
monkeyplus.patch_osstat("first", st_mtime=42)
assert modified_after("first", None) # no crash
class TestCase_delete_if_empty:
@@ -234,92 +260,91 @@ class TestCase_delete_if_empty:
def test_not_empty(self, tmpdir):
testpath = Path(str(tmpdir))
testpath['foo'].mkdir()
testpath["foo"].mkdir()
assert not delete_if_empty(testpath)
assert testpath.exists()
def test_with_files_to_delete(self, tmpdir):
testpath = Path(str(tmpdir))
testpath['foo'].open('w')
testpath['bar'].open('w')
assert delete_if_empty(testpath, ['foo', 'bar'])
testpath["foo"].open("w")
testpath["bar"].open("w")
assert delete_if_empty(testpath, ["foo", "bar"])
assert not testpath.exists()
def test_directory_in_files_to_delete(self, tmpdir):
testpath = Path(str(tmpdir))
testpath['foo'].mkdir()
assert not delete_if_empty(testpath, ['foo'])
testpath["foo"].mkdir()
assert not delete_if_empty(testpath, ["foo"])
assert testpath.exists()
def test_delete_files_to_delete_only_if_dir_is_empty(self, tmpdir):
testpath = Path(str(tmpdir))
testpath['foo'].open('w')
testpath['bar'].open('w')
assert not delete_if_empty(testpath, ['foo'])
testpath["foo"].open("w")
testpath["bar"].open("w")
assert not delete_if_empty(testpath, ["foo"])
assert testpath.exists()
assert testpath['foo'].exists()
assert testpath["foo"].exists()
def test_doesnt_exist(self):
# When the 'path' doesn't exist, just do nothing.
delete_if_empty(Path('does_not_exist')) # no crash
delete_if_empty(Path("does_not_exist")) # no crash
def test_is_file(self, tmpdir):
# When 'path' is a file, do nothing.
p = Path(str(tmpdir)) + 'filename'
p.open('w').close()
delete_if_empty(p) # no crash
p = Path(str(tmpdir)) + "filename"
p.open("w").close()
delete_if_empty(p) # no crash
def test_ioerror(self, tmpdir, monkeypatch):
# if an IO error happens during the operation, ignore it.
def do_raise(*args, **kw):
raise OSError()
monkeypatch.setattr(Path, 'rmdir', do_raise)
delete_if_empty(Path(str(tmpdir))) # no crash
monkeypatch.setattr(Path, "rmdir", do_raise)
delete_if_empty(Path(str(tmpdir))) # no crash
class TestCase_open_if_filename:
def test_file_name(self, tmpdir):
filepath = str(tmpdir.join('test.txt'))
open(filepath, 'wb').write(b'test_data')
filepath = str(tmpdir.join("test.txt"))
open(filepath, "wb").write(b"test_data")
file, close = open_if_filename(filepath)
assert close
eq_(b'test_data', file.read())
eq_(b"test_data", file.read())
file.close()
def test_opened_file(self):
sio = StringIO()
sio.write('test_data')
sio.write("test_data")
sio.seek(0)
file, close = open_if_filename(sio)
assert not close
eq_('test_data', file.read())
eq_("test_data", file.read())
def test_mode_is_passed_to_open(self, tmpdir):
filepath = str(tmpdir.join('test.txt'))
open(filepath, 'w').close()
file, close = open_if_filename(filepath, 'a')
eq_('a', file.mode)
filepath = str(tmpdir.join("test.txt"))
open(filepath, "w").close()
file, close = open_if_filename(filepath, "a")
eq_("a", file.mode)
file.close()
class TestCase_FileOrPath:
def test_path(self, tmpdir):
filepath = str(tmpdir.join('test.txt'))
open(filepath, 'wb').write(b'test_data')
filepath = str(tmpdir.join("test.txt"))
open(filepath, "wb").write(b"test_data")
with FileOrPath(filepath) as fp:
eq_(b'test_data', fp.read())
eq_(b"test_data", fp.read())
def test_opened_file(self):
sio = StringIO()
sio.write('test_data')
sio.write("test_data")
sio.seek(0)
with FileOrPath(sio) as fp:
eq_('test_data', fp.read())
eq_("test_data", fp.read())
def test_mode_is_passed_to_open(self, tmpdir):
filepath = str(tmpdir.join('test.txt'))
open(filepath, 'w').close()
with FileOrPath(filepath, 'a') as fp:
eq_('a', fp.mode)
filepath = str(tmpdir.join("test.txt"))
open(filepath, "w").close()
with FileOrPath(filepath, "a") as fp:
eq_("a", fp.mode)

View File

@@ -9,10 +9,12 @@
import threading
import py.path
def eq_(a, b, msg=None):
__tracebackhide__ = True
assert a == b, msg or "%r != %r" % (a, b)
def eq_sorted(a, b, msg=None):
"""If both a and b are iterable sort them and compare using eq_, otherwise just pass them through to eq_ anyway."""
try:
@@ -20,10 +22,12 @@ def eq_sorted(a, b, msg=None):
except TypeError:
eq_(a, b, msg)
def assert_almost_equal(a, b, places=7):
__tracebackhide__ = True
assert round(a, ndigits=places) == round(b, ndigits=places)
def callcounter():
def f(*args, **kwargs):
f.callcount += 1
@@ -31,6 +35,7 @@ def callcounter():
f.callcount = 0
return f
class TestData:
def __init__(self, datadirpath):
self.datadirpath = py.path.local(datadirpath)
@@ -53,12 +58,14 @@ class CallLogger:
It is used to simulate the GUI layer.
"""
def __init__(self):
self.calls = []
def __getattr__(self, func_name):
def func(*args, **kw):
self.calls.append(func_name)
return func
def clear_calls(self):
@@ -77,7 +84,9 @@ class CallLogger:
eq_(set(self.calls), set(expected))
self.clear_calls()
def check_gui_calls_partial(self, expected=None, not_expected=None, verify_order=False):
def check_gui_calls_partial(
self, expected=None, not_expected=None, verify_order=False
):
"""Checks that the expected calls have been made to 'self', then clears the log.
`expected` is an iterable of strings representing method names. Order doesn't matter.
@@ -88,17 +97,25 @@ class CallLogger:
__tracebackhide__ = True
if expected is not None:
not_called = set(expected) - set(self.calls)
assert not not_called, "These calls haven't been made: {0}".format(not_called)
assert not not_called, "These calls haven't been made: {0}".format(
not_called
)
if verify_order:
max_index = 0
for call in expected:
index = self.calls.index(call)
if index < max_index:
raise AssertionError("The call {0} hasn't been made in the correct order".format(call))
raise AssertionError(
"The call {0} hasn't been made in the correct order".format(
call
)
)
max_index = index
if not_expected is not None:
called = set(not_expected) & set(self.calls)
assert not called, "These calls shouldn't have been made: {0}".format(called)
assert not called, "These calls shouldn't have been made: {0}".format(
called
)
self.clear_calls()
@@ -124,7 +141,7 @@ class TestApp:
parent = self.default_parent
if holder is None:
holder = self
setattr(holder, '{0}_gui'.format(name), view)
setattr(holder, "{0}_gui".format(name), view)
gui = class_(parent)
gui.view = view
setattr(holder, name, gui)
@@ -136,38 +153,44 @@ def with_app(setupfunc):
def decorator(func):
func.setupfunc = setupfunc
return func
return decorator
def pytest_funcarg__app(request):
setupfunc = request.function.setupfunc
if hasattr(setupfunc, '__code__'):
argnames = setupfunc.__code__.co_varnames[:setupfunc.__code__.co_argcount]
if hasattr(setupfunc, "__code__"):
argnames = setupfunc.__code__.co_varnames[: setupfunc.__code__.co_argcount]
def getarg(name):
if name == 'self':
if name == "self":
return request.function.__self__
else:
return request.getfixturevalue(name)
args = [getarg(argname) for argname in argnames]
else:
args = []
app = setupfunc(*args)
return app
def jointhreads():
"""Join all threads to the main thread"""
for thread in threading.enumerate():
if hasattr(thread, 'BUGGY'):
if hasattr(thread, "BUGGY"):
continue
if thread.getName() != 'MainThread' and thread.isAlive():
if hasattr(thread, 'close'):
if thread.getName() != "MainThread" and thread.isAlive():
if hasattr(thread, "close"):
thread.close()
thread.join(1)
if thread.isAlive():
print("Thread problem. Some thread doesn't want to stop.")
thread.BUGGY = True
def _unify_args(func, args, kwargs, args_to_ignore=None):
''' Unify args and kwargs in the same dictionary.
""" Unify args and kwargs in the same dictionary.
The result is kwargs with args added to it. func.func_code.co_varnames is used to determine
under what key each elements of arg will be mapped in kwargs.
@@ -181,36 +204,40 @@ def _unify_args(func, args, kwargs, args_to_ignore=None):
def foo(bar, baz)
_unifyArgs(foo, (42,), {'baz': 23}) --> {'bar': 42, 'baz': 23}
_unifyArgs(foo, (42,), {'baz': 23}, ['bar']) --> {'baz': 23}
'''
"""
result = kwargs.copy()
if hasattr(func, '__code__'): # built-in functions don't have func_code
if hasattr(func, "__code__"): # built-in functions don't have func_code
args = list(args)
if getattr(func, '__self__', None) is not None: # bound method, we have to add self to args list
if (
getattr(func, "__self__", None) is not None
): # bound method, we have to add self to args list
args = [func.__self__] + args
defaults = list(func.__defaults__) if func.__defaults__ is not None else []
arg_count = func.__code__.co_argcount
arg_names = list(func.__code__.co_varnames)
if len(args) < arg_count: # We have default values
if len(args) < arg_count: # We have default values
required_arg_count = arg_count - len(args)
args = args + defaults[-required_arg_count:]
for arg_name, arg in zip(arg_names, args):
# setdefault is used because if the arg is already in kwargs, we don't want to use default values
result.setdefault(arg_name, arg)
else:
#'func' has a *args argument
result['args'] = args
# 'func' has a *args argument
result["args"] = args
if args_to_ignore:
for kw in args_to_ignore:
del result[kw]
return result
def log_calls(func):
''' Logs all func calls' arguments under func.calls.
""" Logs all func calls' arguments under func.calls.
func.calls is a list of _unify_args() result (dict).
Mostly used for unit testing.
'''
"""
def wrapper(*args, **kwargs):
unifiedArgs = _unify_args(func, args, kwargs)
wrapper.calls.append(unifiedArgs)
@@ -218,4 +245,3 @@ def log_calls(func):
wrapper.calls = []
return wrapper
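# A quick sketch of log_calls in action, assuming the helpers above are defined as shown
# (the decorated function and its values are illustrative):
@log_calls
def save(path, overwrite=False):
    pass

save("out.txt")
save("out.txt", overwrite=True)
assert save.calls == [
    {"path": "out.txt", "overwrite": False},
    {"path": "out.txt", "overwrite": True},
]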

View File

@@ -19,6 +19,7 @@ _trfunc = None
_trget = None
installed_lang = None
def tr(s, context=None):
if _trfunc is None:
return s
@@ -28,6 +29,7 @@ def tr(s, context=None):
else:
return _trfunc(s)
def trget(domain):
# Returns a tr() function for the specified domain.
if _trget is None:
@@ -35,57 +37,61 @@ def trget(domain):
else:
return _trget(domain)
def set_tr(new_tr, new_trget=None):
global _trfunc, _trget
_trfunc = new_tr
if new_trget is not None:
_trget = new_trget
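# A minimal sketch of the indirection above (assuming no translator has been installed
# yet, i.e. _trfunc is still None):
assert tr("hello") == "hello"  # falls through untranslated
set_tr(lambda s, context=None: s.upper())
assert tr("hello") == "HELLO"  # now routed through the installed function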
def get_locale_name(lang):
if ISWINDOWS:
# http://msdn.microsoft.com/en-us/library/39cwe7zf(vs.71).aspx
LANG2LOCALENAME = {
'cs': 'czy',
'de': 'deu',
'el': 'grc',
'es': 'esn',
'fr': 'fra',
'it': 'ita',
'ko': 'korean',
'nl': 'nld',
'pl_PL': 'polish_poland',
'pt_BR': 'ptb',
'ru': 'rus',
'zh_CN': 'chs',
"cs": "czy",
"de": "deu",
"el": "grc",
"es": "esn",
"fr": "fra",
"it": "ita",
"ko": "korean",
"nl": "nld",
"pl_PL": "polish_poland",
"pt_BR": "ptb",
"ru": "rus",
"zh_CN": "chs",
}
else:
LANG2LOCALENAME = {
'cs': 'cs_CZ',
'de': 'de_DE',
'el': 'el_GR',
'es': 'es_ES',
'fr': 'fr_FR',
'it': 'it_IT',
'nl': 'nl_NL',
'hy': 'hy_AM',
'ko': 'ko_KR',
'pl_PL': 'pl_PL',
'pt_BR': 'pt_BR',
'ru': 'ru_RU',
'uk': 'uk_UA',
'vi': 'vi_VN',
'zh_CN': 'zh_CN',
"cs": "cs_CZ",
"de": "de_DE",
"el": "el_GR",
"es": "es_ES",
"fr": "fr_FR",
"it": "it_IT",
"nl": "nl_NL",
"hy": "hy_AM",
"ko": "ko_KR",
"pl_PL": "pl_PL",
"pt_BR": "pt_BR",
"ru": "ru_RU",
"uk": "uk_UA",
"vi": "vi_VN",
"zh_CN": "zh_CN",
}
if lang not in LANG2LOCALENAME:
return None
result = LANG2LOCALENAME[lang]
if ISLINUX:
result += '.UTF-8'
result += ".UTF-8"
return result
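# Illustrative results of get_locale_name(), given the tables above:
#   on Windows:   get_locale_name("fr") -> "fra"
#   on Linux:     get_locale_name("fr") -> "fr_FR.UTF-8"
#   elsewhere:    get_locale_name("fr") -> "fr_FR"
#   any platform: get_locale_name("xx") -> None (not in LANG2LOCALENAME)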
#--- Qt
# --- Qt
def install_qt_trans(lang=None):
from PyQt5.QtCore import QCoreApplication, QTranslator, QLocale
if not lang:
lang = str(QLocale.system().name())[:2]
localename = get_locale_name(lang)
@@ -95,54 +101,66 @@ def install_qt_trans(lang=None):
except locale.Error:
logging.warning("Couldn't set locale %s", localename)
else:
lang = 'en'
lang = "en"
qtr1 = QTranslator(QCoreApplication.instance())
qtr1.load(':/qt_%s' % lang)
qtr1.load(":/qt_%s" % lang)
QCoreApplication.installTranslator(qtr1)
qtr2 = QTranslator(QCoreApplication.instance())
qtr2.load(':/%s' % lang)
qtr2.load(":/%s" % lang)
QCoreApplication.installTranslator(qtr2)
def qt_tr(s, context='core'):
def qt_tr(s, context="core"):
return str(QCoreApplication.translate(context, s, None))
set_tr(qt_tr)
set_tr(qt_tr)
#--- gettext
# --- gettext
def install_gettext_trans(base_folder, lang):
import gettext
def gettext_trget(domain):
if not lang:
return lambda s: s
try:
return gettext.translation(domain, localedir=base_folder, languages=[lang]).gettext
return gettext.translation(
domain, localedir=base_folder, languages=[lang]
).gettext
except IOError:
return lambda s: s
default_gettext = gettext_trget('core')
default_gettext = gettext_trget("core")
def gettext_tr(s, context=None):
if not context:
return default_gettext(s)
else:
trfunc = gettext_trget(context)
return trfunc(s)
set_tr(gettext_tr, gettext_trget)
global installed_lang
installed_lang = lang
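# A hedged usage sketch (assumes compiled .mo catalogs exist under the given base folder):
# install_gettext_trans("locale", "fr")
# tr("Scan")            # looked up in the default "core" domain
# tr_ui = trget("ui")
# tr_ui("Scan")         # looked up in the "ui" domain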
def install_gettext_trans_under_cocoa():
from cocoa import proxy
resFolder = proxy.getResourcePath()
baseFolder = op.join(resFolder, 'locale')
baseFolder = op.join(resFolder, "locale")
currentLang = proxy.systemLang()
install_gettext_trans(baseFolder, currentLang)
localename = get_locale_name(currentLang)
if localename is not None:
locale.setlocale(locale.LC_ALL, localename)
def install_gettext_trans_under_qt(base_folder, lang=None):
# So, we install the gettext locale, great, but we also should try to install qt_*.qm if
# available, so that strings inside Qt itself (over which I have no control) are in the
# right language.
from PyQt5.QtCore import QCoreApplication, QTranslator, QLocale, QLibraryInfo
if not lang:
lang = str(QLocale.system().name())[:2]
localename = get_locale_name(lang)
@@ -151,7 +169,7 @@ def install_gettext_trans_under_qt(base_folder, lang=None):
locale.setlocale(locale.LC_ALL, localename)
except locale.Error:
logging.warning("Couldn't set locale %s", localename)
qmname = 'qt_%s' % lang
qmname = "qt_%s" % lang
if ISLINUX:
# Under Linux, a full Qt installation is already available on the system and we didn't bundle
# up the qm files in our package, so we have to load translations from the system.

View File

@@ -17,6 +17,7 @@ from datetime import timedelta
from .path import Path, pathify, log_io_error
def nonone(value, replace_value):
"""Returns ``value`` if ``value`` is not ``None``. Returns ``replace_value`` otherwise.
"""
@@ -25,6 +26,7 @@ def nonone(value, replace_value):
else:
return value
def tryint(value, default=0):
"""Tries to convert ``value`` to in ``int`` and returns ``default`` if it fails.
"""
@@ -33,12 +35,15 @@ def tryint(value, default=0):
except (TypeError, ValueError):
return default
def minmax(value, min_value, max_value):
"""Returns `value` or one of the min/max bounds if `value` is not between them.
"""
return min(max(value, min_value), max_value)
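# Quick illustrations of the helpers above, assuming they are in scope as defined:
assert nonone(None, 0) == 0
assert tryint("42") == 42
assert tryint("spam", default=-1) == -1
assert minmax(15, 0, 10) == 10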
#--- Sequence related
# --- Sequence related
def dedupe(iterable):
"""Returns a list of elements in ``iterable`` with all dupes removed.
@@ -54,6 +59,7 @@ def dedupe(iterable):
result.append(item)
return result
def flatten(iterables, start_with=None):
"""Takes a list of lists ``iterables`` and returns a list containing elements of every list.
@@ -67,6 +73,7 @@ def flatten(iterables, start_with=None):
result.extend(iterable)
return result
def first(iterable):
"""Returns the first item of ``iterable``.
"""
@@ -75,11 +82,13 @@ def first(iterable):
except StopIteration:
return None
def stripfalse(seq):
"""Returns a sequence with all false elements stripped out of seq.
"""
return [x for x in seq if x]
def extract(predicate, iterable):
"""Separates the wheat from the shaft (`predicate` defines what's the wheat), and returns both.
"""
@@ -92,6 +101,7 @@ def extract(predicate, iterable):
shaft.append(item)
return wheat, shaft
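# Illustrative behavior of the sequence helpers above:
assert dedupe([1, 2, 1, 3, 2]) == [1, 2, 3]
assert flatten([[1, 2], [3]]) == [1, 2, 3]
assert first(x for x in []) is None
evens, odds = extract(lambda n: n % 2 == 0, range(5))
assert evens == [0, 2, 4] and odds == [1, 3]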
def allsame(iterable):
"""Returns whether all elements of 'iterable' are the same.
"""
@@ -102,6 +112,7 @@ def allsame(iterable):
raise ValueError("iterable cannot be empty")
return all(element == first_item for element in it)
def trailiter(iterable, skipfirst=False):
"""Yields (prev_element, element), starting with (None, first_element).
@@ -120,6 +131,7 @@ def trailiter(iterable, skipfirst=False):
yield prev, item
prev = item
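# Illustrative: trailiter() pairs each element with the one before it.
assert list(trailiter("abc")) == [(None, "a"), ("a", "b"), ("b", "c")]
assert list(trailiter("abc", skipfirst=True)) == [("a", "b"), ("b", "c")]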
def iterconsume(seq, reverse=True):
"""Iterate over ``seq`` and pops yielded objects.
@@ -135,31 +147,36 @@ def iterconsume(seq, reverse=True):
while seq:
yield seq.pop()
#--- String related
def escape(s, to_escape, escape_with='\\'):
# --- String related
def escape(s, to_escape, escape_with="\\"):
"""Returns ``s`` with characters in ``to_escape`` all prepended with ``escape_with``.
"""
return ''.join((escape_with + c if c in to_escape else c) for c in s)
return "".join((escape_with + c if c in to_escape else c) for c in s)
def get_file_ext(filename):
"""Returns the lowercase extension part of filename, without the dot.
"""
pos = filename.rfind('.')
pos = filename.rfind(".")
if pos > -1:
return filename[pos + 1:].lower()
return filename[pos + 1 :].lower()
else:
return ''
return ""
def rem_file_ext(filename):
"""Returns the filename without extension.
"""
pos = filename.rfind('.')
pos = filename.rfind(".")
if pos > -1:
return filename[:pos]
else:
return filename
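# Illustrative behavior of the string/filename helpers above:
assert escape("a.b*c", ".*") == "a\\.b\\*c"
assert get_file_ext("photo.JPG") == "jpg"
assert get_file_ext("README") == ""
assert rem_file_ext("photo.JPG") == "photo"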
def pluralize(number, word, decimals=0, plural_word=None):
"""Returns a pluralized string with ``number`` in front of ``word``.
@@ -173,11 +190,12 @@ def pluralize(number, word, decimals=0, plural_word=None):
format = "%%1.%df %%s" % decimals
if number > 1:
if plural_word is None:
word += 's'
word += "s"
else:
word = plural_word
return format % (number, word)
def format_time(seconds, with_hours=True):
"""Transforms seconds in a hh:mm:ss string.
@@ -189,14 +207,15 @@ def format_time(seconds, with_hours=True):
m, s = divmod(seconds, 60)
if with_hours:
h, m = divmod(m, 60)
r = '%02d:%02d:%02d' % (h, m, s)
r = "%02d:%02d:%02d" % (h, m, s)
else:
r = '%02d:%02d' % (m,s)
r = "%02d:%02d" % (m, s)
if minus:
return '-' + r
return "-" + r
else:
return r
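# Illustrative output of pluralize() and format_time(), given the definitions above:
assert pluralize(1, "file") == "1 file"
assert pluralize(3, "file") == "3 files"
assert pluralize(1.5, "hour", decimals=1) == "1.5 hours"
assert format_time(3661) == "01:01:01"
assert format_time(61, with_hours=False) == "01:01"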
def format_time_decimal(seconds):
"""Transforms seconds in a strings like '3.4 minutes'.
"""
@@ -204,20 +223,23 @@ def format_time_decimal(seconds):
if minus:
seconds *= -1
if seconds < 60:
r = pluralize(seconds, 'second', 1)
r = pluralize(seconds, "second", 1)
elif seconds < 3600:
r = pluralize(seconds / 60.0, 'minute', 1)
r = pluralize(seconds / 60.0, "minute", 1)
elif seconds < 86400:
r = pluralize(seconds / 3600.0, 'hour', 1)
r = pluralize(seconds / 3600.0, "hour", 1)
else:
r = pluralize(seconds / 86400.0, 'day', 1)
r = pluralize(seconds / 86400.0, "day", 1)
if minus:
return '-' + r
return "-" + r
else:
return r
SIZE_DESC = ('B','KB','MB','GB','TB','PB','EB','ZB','YB')
SIZE_VALS = tuple(1024 ** i for i in range(1,9))
SIZE_DESC = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
SIZE_VALS = tuple(1024 ** i for i in range(1, 9))
def format_size(size, decimal=0, forcepower=-1, showdesc=True):
"""Transform a byte count in a formatted string (KB, MB etc..).
@@ -238,12 +260,12 @@ def format_size(size, decimal=0, forcepower=-1, showdesc=True):
else:
i = forcepower
if i > 0:
div = SIZE_VALS[i-1]
div = SIZE_VALS[i - 1]
else:
div = 1
format = '%%%d.%df' % (decimal,decimal)
format = "%%%d.%df" % (decimal, decimal)
negative = size < 0
divided_size = ((0.0 + abs(size)) / div)
divided_size = (0.0 + abs(size)) / div
if decimal == 0:
divided_size = ceil(divided_size)
else:
@@ -252,18 +274,21 @@ def format_size(size, decimal=0, forcepower=-1, showdesc=True):
divided_size *= -1
result = format % divided_size
if showdesc:
result += ' ' + SIZE_DESC[i]
result += " " + SIZE_DESC[i]
return result
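# Illustrative output of format_time_decimal() and format_size(); note that with
# decimal=0, format_size() rounds up to the next whole unit:
assert format_time_decimal(90) == "1.5 minutes"
assert format_size(500) == "500 B"
assert format_size(2048) == "2 KB"
assert format_size(1500, decimal=1) == "1.5 KB"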
_valid_xml_range = '\x09\x0A\x0D\x20-\uD7FF\uE000-\uFFFD'
if sys.maxunicode > 0x10000:
_valid_xml_range += '%s-%s' % (chr(0x10000), chr(min(sys.maxunicode, 0x10FFFF)))
RE_INVALID_XML_SUB = re.compile('[^%s]' % _valid_xml_range, re.U).sub
def remove_invalid_xml(s, replace_with=' '):
_valid_xml_range = "\x09\x0A\x0D\x20-\uD7FF\uE000-\uFFFD"
if sys.maxunicode > 0x10000:
_valid_xml_range += "%s-%s" % (chr(0x10000), chr(min(sys.maxunicode, 0x10FFFF)))
RE_INVALID_XML_SUB = re.compile("[^%s]" % _valid_xml_range, re.U).sub
def remove_invalid_xml(s, replace_with=" "):
return RE_INVALID_XML_SUB(replace_with, s)
def multi_replace(s, replace_from, replace_to=''):
def multi_replace(s, replace_from, replace_to=""):
"""A function like str.replace() with multiple replacements.
``replace_from`` is a list of things you want to replace. Ex: ['a','bc','d']
@@ -280,17 +305,20 @@ def multi_replace(s, replace_from, replace_to=''):
if isinstance(replace_to, str) and (len(replace_from) != len(replace_to)):
replace_to = [replace_to for r in replace_from]
if len(replace_from) != len(replace_to):
raise ValueError('len(replace_from) must be equal to len(replace_to)')
raise ValueError("len(replace_from) must be equal to len(replace_to)")
replace = list(zip(replace_from, replace_to))
for r_from, r_to in [r for r in replace if r[0] in s]:
s = s.replace(r_from, r_to)
return s
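# Illustrative behavior of the two helpers above:
assert multi_replace("abc", ["a", "c"], "-") == "-b-"
assert multi_replace("abc", "ac", ["1", "3"]) == "1b3"
assert remove_invalid_xml("foo\x00bar") == "foo bar"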
#--- Date related
# --- Date related
# It might seem like needless namespace pollution, but the speedup gained by this constant is
# significant, so it stays.
ONE_DAY = timedelta(1)
def iterdaterange(start, end):
"""Yields every day between ``start`` and ``end``.
"""
@@ -299,7 +327,9 @@ def iterdaterange(start, end):
yield date
date += ONE_DAY
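# Illustrative (assuming the range is inclusive of both endpoints, as the docstring suggests):
from datetime import date
assert list(iterdaterange(date(2019, 12, 30), date(2020, 1, 1))) == [
    date(2019, 12, 30),
    date(2019, 12, 31),
    date(2020, 1, 1),
]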
#--- Files related
# --- Files related
@pathify
def modified_after(first_path: Path, second_path: Path):
@@ -317,19 +347,21 @@ def modified_after(first_path: Path, second_path: Path):
return True
return first_mtime > second_mtime
def find_in_path(name, paths=None):
"""Search for `name` in all directories of `paths` and return the absolute path of the first
occurrence. If `paths` is None, $PATH is used.
"""
if paths is None:
paths = os.environ['PATH']
if isinstance(paths, str): # if it's not a string, it's already a list
paths = os.environ["PATH"]
if isinstance(paths, str): # if it's not a string, it's already a list
paths = paths.split(os.pathsep)
for path in paths:
if op.exists(op.join(path, name)):
return op.join(path, name)
return None
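# A hedged sketch of find_in_path(); results depend on the environment:
# find_in_path("python3")                       # e.g. "/usr/bin/python3" on many systems
# find_in_path("mytool", paths=["/opt/bin"])    # None unless /opt/bin/mytool exists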
@log_io_error
@pathify
def delete_if_empty(path: Path, files_to_delete=[]):
@@ -345,7 +377,8 @@ def delete_if_empty(path: Path, files_to_delete=[]):
path.rmdir()
return True
def open_if_filename(infile, mode='rb'):
def open_if_filename(infile, mode="rb"):
"""If ``infile`` is a string, it opens and returns it. If it's already a file object, it simply returns it.
This function returns ``(file, should_close_flag)``. The should_close_flag is True if a file has
@@ -364,15 +397,18 @@ def open_if_filename(infile, mode='rb'):
else:
return (infile, False)
def ensure_folder(path):
"Create `path` as a folder if it doesn't exist."
if not op.exists(path):
os.makedirs(path)
def ensure_file(path):
"Create `path` as an empty file if it doesn't exist."
if not op.exists(path):
open(path, 'w').close()
open(path, "w").close()
def delete_files_with_pattern(folder_path, pattern, recursive=True):
"""Delete all files (or folders) in `folder_path` that match the glob `pattern`.
@@ -389,6 +425,7 @@ def delete_files_with_pattern(folder_path, pattern, recursive=True):
for p in subfolders:
delete_files_with_pattern(p, pattern, True)
class FileOrPath:
"""Does the same as :func:`open_if_filename`, but it can be used with a ``with`` statement.
@@ -397,7 +434,8 @@ class FileOrPath:
with FileOrPath(infile):
dostuff()
"""
def __init__(self, file_or_path, mode='rb'):
def __init__(self, file_or_path, mode="rb"):
self.file_or_path = file_or_path
self.mode = mode
self.mustclose = False
@@ -410,4 +448,3 @@ class FileOrPath:
def __exit__(self, exc_type, exc_value, traceback):
if self.fp and self.mustclose:
self.fp.close()
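# A hedged sketch: both helpers accept either a path or an already-open file object.
# fp, should_close = open_if_filename("data.bin")        # opened here; caller should close it
# fp, should_close = open_if_filename(sys.stdin.buffer)  # passed through; should_close is False
# with FileOrPath("data.bin") as fp:                     # fp is closed automatically on exit
#     data = fp.read()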