Type hinting hscommon & cleanup

This commit is contained in:
Andrew Senetar 2022-05-09 23:36:39 -05:00
parent 58863b1728
commit 7865e4aeac
Signed by: arsenetar
GPG Key ID: C63300DCE48AB2F1
10 changed files with 88 additions and 398 deletions

5
hscommon/.gitignore vendored
View File

@ -1,5 +0,0 @@
*.pyc
*.mo
*.so
.DS_Store
/docs_html

View File

@ -9,6 +9,7 @@
"""This module is a collection of function to help in HS apps build process. """This module is a collection of function to help in HS apps build process.
""" """
from argparse import ArgumentParser
import os import os
import sys import sys
import os.path as op import os.path as op
@ -20,18 +21,19 @@ import re
import importlib import importlib
from datetime import datetime from datetime import datetime
import glob import glob
from typing import Any, AnyStr, Callable, Dict, List, Union
from hscommon.plat import ISWINDOWS from hscommon.plat import ISWINDOWS
def print_and_do(cmd): def print_and_do(cmd: str) -> int:
"""Prints ``cmd`` and executes it in the shell.""" """Prints ``cmd`` and executes it in the shell."""
print(cmd) print(cmd)
p = Popen(cmd, shell=True) p = Popen(cmd, shell=True)
return p.wait() return p.wait()
def _perform(src, dst, action, actionname): def _perform(src: os.PathLike, dst: os.PathLike, action: Callable, actionname: str) -> None:
if not op.lexists(src): if not op.lexists(src):
print("Copying %s failed: it doesn't exist." % src) print("Copying %s failed: it doesn't exist." % src)
return return
@ -44,30 +46,22 @@ def _perform(src, dst, action, actionname):
action(src, dst) action(src, dst)
def copy_file_or_folder(src, dst): def copy_file_or_folder(src: os.PathLike, dst: os.PathLike) -> None:
if op.isdir(src): if op.isdir(src):
shutil.copytree(src, dst, symlinks=True) shutil.copytree(src, dst, symlinks=True)
else: else:
shutil.copy(src, dst) shutil.copy(src, dst)
def move(src, dst): def move(src: os.PathLike, dst: os.PathLike) -> None:
_perform(src, dst, os.rename, "Moving") _perform(src, dst, os.rename, "Moving")
def copy(src, dst): def copy(src: os.PathLike, dst: os.PathLike) -> None:
_perform(src, dst, copy_file_or_folder, "Copying") _perform(src, dst, copy_file_or_folder, "Copying")
def symlink(src, dst): def _perform_on_all(pattern: AnyStr, dst: os.PathLike, action: Callable) -> None:
_perform(src, dst, os.symlink, "Symlinking")
def hardlink(src, dst):
_perform(src, dst, os.link, "Hardlinking")
def _perform_on_all(pattern, dst, action):
# pattern is a glob pattern, example "folder/foo*". The file is moved directly in dst, no folder # pattern is a glob pattern, example "folder/foo*". The file is moved directly in dst, no folder
# structure from src is kept. # structure from src is kept.
filenames = glob.glob(pattern) filenames = glob.glob(pattern)
@ -76,22 +70,15 @@ def _perform_on_all(pattern, dst, action):
action(fn, destpath) action(fn, destpath)
def move_all(pattern, dst): def move_all(pattern: AnyStr, dst: os.PathLike) -> None:
_perform_on_all(pattern, dst, move) _perform_on_all(pattern, dst, move)
def copy_all(pattern, dst): def copy_all(pattern: AnyStr, dst: os.PathLike) -> None:
_perform_on_all(pattern, dst, copy) _perform_on_all(pattern, dst, copy)
def ensure_empty_folder(path): def filereplace(filename: os.PathLike, outfilename: Union[os.PathLike, None] = None, **kwargs) -> None:
"""Make sure that the path exists and that it's an empty folder."""
if op.exists(path):
shutil.rmtree(path)
os.mkdir(path)
def filereplace(filename, outfilename=None, **kwargs):
"""Reads `filename`, replaces all {variables} in kwargs, and writes the result to `outfilename`.""" """Reads `filename`, replaces all {variables} in kwargs, and writes the result to `outfilename`."""
if outfilename is None: if outfilename is None:
outfilename = filename outfilename = filename
@ -106,12 +93,12 @@ def filereplace(filename, outfilename=None, **kwargs):
fp.close() fp.close()
def get_module_version(modulename): def get_module_version(modulename: str) -> str:
mod = importlib.import_module(modulename) mod = importlib.import_module(modulename)
return mod.__version__ return mod.__version__
def setup_package_argparser(parser): def setup_package_argparser(parser: ArgumentParser):
parser.add_argument( parser.add_argument(
"--sign", "--sign",
dest="sign_identity", dest="sign_identity",
@ -138,7 +125,7 @@ def setup_package_argparser(parser):
# `args` come from an ArgumentParser updated with setup_package_argparser() # `args` come from an ArgumentParser updated with setup_package_argparser()
def package_cocoa_app_in_dmg(app_path, destfolder, args): def package_cocoa_app_in_dmg(app_path: os.PathLike, destfolder: os.PathLike, args) -> None:
# Rather than signing our app in XCode during the build phase, we sign it during the package # Rather than signing our app in XCode during the build phase, we sign it during the package
# phase because running the app before packaging can modify it and we want to be sure to have # phase because running the app before packaging can modify it and we want to be sure to have
# a valid signature. # a valid signature.
@ -154,13 +141,14 @@ def package_cocoa_app_in_dmg(app_path, destfolder, args):
build_dmg(app_path, destfolder) build_dmg(app_path, destfolder)
def build_dmg(app_path, destfolder): def build_dmg(app_path: os.PathLike, destfolder: os.PathLike) -> None:
"""Builds a DMG volume with application at ``app_path`` and puts it in ``dest_path``. """Builds a DMG volume with application at ``app_path`` and puts it in ``dest_path``.
The name of the resulting DMG volume is determined by the app's name and version. The name of the resulting DMG volume is determined by the app's name and version.
""" """
print(repr(op.join(app_path, "Contents", "Info.plist"))) print(repr(op.join(app_path, "Contents", "Info.plist")))
plist = plistlib.readPlist(op.join(app_path, "Contents", "Info.plist")) with open(op.join(app_path, "Contents", "Info.plist"), "rb") as fp:
plist = plistlib.load(fp)
workpath = tempfile.mkdtemp() workpath = tempfile.mkdtemp()
dmgpath = op.join(workpath, plist["CFBundleName"]) dmgpath = op.join(workpath, plist["CFBundleName"])
os.mkdir(dmgpath) os.mkdir(dmgpath)
@ -178,7 +166,7 @@ def build_dmg(app_path, destfolder):
print("Build Complete") print("Build Complete")
def add_to_pythonpath(path): def add_to_pythonpath(path: os.PathLike) -> None:
"""Adds ``path`` to both ``PYTHONPATH`` env and ``sys.path``.""" """Adds ``path`` to both ``PYTHONPATH`` env and ``sys.path``."""
abspath = op.abspath(path) abspath = op.abspath(path)
pythonpath = os.environ.get("PYTHONPATH", "") pythonpath = os.environ.get("PYTHONPATH", "")
@ -191,7 +179,12 @@ def add_to_pythonpath(path):
# This is a method to hack around those freakingly tricky data inclusion/exlusion rules # This is a method to hack around those freakingly tricky data inclusion/exlusion rules
# in setuptools. We copy the packages *without data* in a build folder and then build the plugin # in setuptools. We copy the packages *without data* in a build folder and then build the plugin
# from there. # from there.
def copy_packages(packages_names, dest, create_links=False, extra_ignores=None): def copy_packages(
packages_names: List[str],
dest: os.PathLike,
create_links: bool = False,
extra_ignores: Union[List[str], None] = None,
) -> None:
"""Copy python packages ``packages_names`` to ``dest``, spurious data. """Copy python packages ``packages_names`` to ``dest``, spurious data.
Copy will happen without tests, testdata, mercurial data or C extension module source with it. Copy will happen without tests, testdata, mercurial data or C extension module source with it.
@ -229,13 +222,13 @@ def copy_packages(packages_names, dest, create_links=False, extra_ignores=None):
def build_debian_changelog( def build_debian_changelog(
changelogpath, changelogpath: os.PathLike,
destfile, destfile: os.PathLike,
pkgname, pkgname: str,
from_version=None, from_version: Union[str, None] = None,
distribution="precise", distribution: str = "precise",
fix_version=None, fix_version: Union[str, None] = None,
): ) -> None:
"""Builds a debian changelog out of a YAML changelog. """Builds a debian changelog out of a YAML changelog.
Use fix_version to patch the top changelog to that version (if, for example, there was a Use fix_version to patch the top changelog to that version (if, for example, there was a
@ -288,7 +281,7 @@ def build_debian_changelog(
re_changelog_header = re.compile(r"=== ([\d.b]*) \(([\d\-]*)\)") re_changelog_header = re.compile(r"=== ([\d.b]*) \(([\d\-]*)\)")
def read_changelog_file(filename): def read_changelog_file(filename: os.PathLike) -> List[Dict[str, Any]]:
def iter_by_three(it): def iter_by_three(it):
while True: while True:
try: try:
@ -315,7 +308,7 @@ def read_changelog_file(filename):
return result return result
def fix_qt_resource_file(path): def fix_qt_resource_file(path: os.PathLike) -> None:
# pyrcc5 under Windows, if the locale is non-english, can produce a source file with a date # pyrcc5 under Windows, if the locale is non-english, can produce a source file with a date
# containing accented characters. If it does, the encoding is wrong and it prevents the file # containing accented characters. If it does, the encoding is wrong and it prevents the file
# from being correctly frozen by cx_freeze. To work around that, we open the file, strip all # from being correctly frozen by cx_freeze. To work around that, we open the file, strip all

View File

@ -1,30 +0,0 @@
# Copyright 2016 Virgil Dupras
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import argparse
from setuptools import setup, Extension
def get_parser():
parser = argparse.ArgumentParser(description="Build an arbitrary Python extension.")
parser.add_argument("source_files", nargs="+", help="List of source files to compile")
parser.add_argument("name", nargs=1, help="Name of the resulting extension")
return parser
def main():
args = get_parser().parse_args()
print(f"Building {args.name[0]}...")
ext = Extension(args.name[0], args.source_files)
setup(
script_args=["build_ext", "--inplace"],
ext_modules=[ext],
)
if __name__ == "__main__":
main()

View File

@ -15,6 +15,7 @@ import os
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Callable, List
# This matches [123], but not [12] (3 digits being the minimum). # This matches [123], but not [12] (3 digits being the minimum).
# It also matches [1234] [12345] etc.. # It also matches [1234] [12345] etc..
@ -22,7 +23,7 @@ from pathlib import Path
re_conflict = re.compile(r"^\[\d{3}\d*\] ") re_conflict = re.compile(r"^\[\d{3}\d*\] ")
def get_conflicted_name(other_names, name): def get_conflicted_name(other_names: List[str], name: str) -> str:
"""Returns name with a ``[000]`` number in front of it. """Returns name with a ``[000]`` number in front of it.
The number between brackets depends on how many conlicted filenames The number between brackets depends on how many conlicted filenames
@ -39,7 +40,7 @@ def get_conflicted_name(other_names, name):
i += 1 i += 1
def get_unconflicted_name(name): def get_unconflicted_name(name: str) -> str:
"""Returns ``name`` without ``[]`` brackets. """Returns ``name`` without ``[]`` brackets.
Brackets which, of course, might have been added by func:`get_conflicted_name`. Brackets which, of course, might have been added by func:`get_conflicted_name`.
@ -47,12 +48,12 @@ def get_unconflicted_name(name):
return re_conflict.sub("", name, 1) return re_conflict.sub("", name, 1)
def is_conflicted(name): def is_conflicted(name: str) -> bool:
"""Returns whether ``name`` is prepended with a bracketed number.""" """Returns whether ``name`` is prepended with a bracketed number."""
return re_conflict.match(name) is not None return re_conflict.match(name) is not None
def _smart_move_or_copy(operation, source_path: Path, dest_path: Path): def _smart_move_or_copy(operation: Callable, source_path: Path, dest_path: Path) -> None:
"""Use move() or copy() to move and copy file with the conflict management.""" """Use move() or copy() to move and copy file with the conflict management."""
if dest_path.is_dir() and not source_path.is_dir(): if dest_path.is_dir() and not source_path.is_dir():
dest_path = dest_path.joinpath(source_path.name) dest_path = dest_path.joinpath(source_path.name)
@ -64,12 +65,12 @@ def _smart_move_or_copy(operation, source_path: Path, dest_path: Path):
operation(str(source_path), str(dest_path)) operation(str(source_path), str(dest_path))
def smart_move(source_path, dest_path): def smart_move(source_path: Path, dest_path: Path) -> None:
"""Same as :func:`smart_copy`, but it moves files instead.""" """Same as :func:`smart_copy`, but it moves files instead."""
_smart_move_or_copy(shutil.move, source_path, dest_path) _smart_move_or_copy(shutil.move, source_path, dest_path)
def smart_copy(source_path, dest_path): def smart_copy(source_path: Path, dest_path: Path) -> None:
"""Copies ``source_path`` to ``dest_path``, recursively and with conflict resolution.""" """Copies ``source_path`` to ``dest_path``, recursively and with conflict resolution."""
try: try:
_smart_move_or_copy(shutil.copy, source_path, dest_path) _smart_move_or_copy(shutil.copy, source_path, dest_path)

View File

@ -2,6 +2,7 @@ import os
import os.path as op import os.path as op
import shutil import shutil
import tempfile import tempfile
from typing import Any, List
import polib import polib
@ -10,15 +11,15 @@ from hscommon import pygettext
LC_MESSAGES = "LC_MESSAGES" LC_MESSAGES = "LC_MESSAGES"
def get_langs(folder): def get_langs(folder: str) -> List[str]:
return [name for name in os.listdir(folder) if op.isdir(op.join(folder, name))] return [name for name in os.listdir(folder) if op.isdir(op.join(folder, name))]
def files_with_ext(folder, ext): def files_with_ext(folder: str, ext: str) -> List[str]:
return [op.join(folder, fn) for fn in os.listdir(folder) if fn.endswith(ext)] return [op.join(folder, fn) for fn in os.listdir(folder) if fn.endswith(ext)]
def generate_pot(folders, outpath, keywords, merge=False): def generate_pot(folders: List[str], outpath: str, keywords: Any, merge: bool = False) -> None:
if merge and not op.exists(outpath): if merge and not op.exists(outpath):
merge = False merge = False
if merge: if merge:
@ -39,7 +40,7 @@ def generate_pot(folders, outpath, keywords, merge=False):
print("Exception while removing temporary folder %s\n", genpath) print("Exception while removing temporary folder %s\n", genpath)
def compile_all_po(base_folder): def compile_all_po(base_folder: str) -> None:
langs = get_langs(base_folder) langs = get_langs(base_folder)
for lang in langs: for lang in langs:
pofolder = op.join(base_folder, lang, LC_MESSAGES) pofolder = op.join(base_folder, lang, LC_MESSAGES)
@ -49,7 +50,7 @@ def compile_all_po(base_folder):
p.save_as_mofile(pofile[:-3] + ".mo") p.save_as_mofile(pofile[:-3] + ".mo")
def merge_locale_dir(target, mergeinto): def merge_locale_dir(target: str, mergeinto: str) -> None:
langs = get_langs(target) langs = get_langs(target)
for lang in langs: for lang in langs:
if not op.exists(op.join(mergeinto, lang)): if not op.exists(op.join(mergeinto, lang)):
@ -60,7 +61,7 @@ def merge_locale_dir(target, mergeinto):
shutil.copy(mofile, op.join(mergeinto, lang, LC_MESSAGES)) shutil.copy(mofile, op.join(mergeinto, lang, LC_MESSAGES))
def merge_pots_into_pos(folder): def merge_pots_into_pos(folder: str) -> None:
# We're going to take all pot files in `folder` and for each lang, merge it with the po file # We're going to take all pot files in `folder` and for each lang, merge it with the po file
# with the same name. # with the same name.
potfiles = files_with_ext(folder, ".pot") potfiles = files_with_ext(folder, ".pot")
@ -73,7 +74,7 @@ def merge_pots_into_pos(folder):
po.save() po.save()
def merge_po_and_preserve(source, dest): def merge_po_and_preserve(source: str, dest: str) -> None:
# Merges source entries into dest, but keep old entries intact # Merges source entries into dest, but keep old entries intact
sourcepo = polib.pofile(source) sourcepo = polib.pofile(source)
destpo = polib.pofile(dest) destpo = polib.pofile(dest)
@ -85,7 +86,7 @@ def merge_po_and_preserve(source, dest):
destpo.save() destpo.save()
def normalize_all_pos(base_folder): def normalize_all_pos(base_folder: str) -> None:
"""Normalize the format of .po files in base_folder. """Normalize the format of .po files in base_folder.
When getting POs from external sources, such as Transifex, we end up with spurious diffs because When getting POs from external sources, such as Transifex, we end up with spurious diffs because

View File

@ -13,6 +13,7 @@ the method with the same name as the broadcasted message is called on the listen
""" """
from collections import defaultdict from collections import defaultdict
from typing import Callable, DefaultDict, List
class Broadcaster: class Broadcaster:
@ -21,10 +22,10 @@ class Broadcaster:
def __init__(self): def __init__(self):
self.listeners = set() self.listeners = set()
def add_listener(self, listener): def add_listener(self, listener: "Listener") -> None:
self.listeners.add(listener) self.listeners.add(listener)
def notify(self, msg): def notify(self, msg: str) -> None:
"""Notify all connected listeners of ``msg``. """Notify all connected listeners of ``msg``.
That means that each listeners will have their method with the same name as ``msg`` called. That means that each listeners will have their method with the same name as ``msg`` called.
@ -33,18 +34,18 @@ class Broadcaster:
if listener in self.listeners: # disconnected during notification if listener in self.listeners: # disconnected during notification
listener.dispatch(msg) listener.dispatch(msg)
def remove_listener(self, listener): def remove_listener(self, listener: "Listener") -> None:
self.listeners.discard(listener) self.listeners.discard(listener)
class Listener: class Listener:
"""A listener is initialized with the broadcaster it's going to listen to. Initially, it is not connected.""" """A listener is initialized with the broadcaster it's going to listen to. Initially, it is not connected."""
def __init__(self, broadcaster): def __init__(self, broadcaster: Broadcaster) -> None:
self.broadcaster = broadcaster self.broadcaster = broadcaster
self._bound_notifications = defaultdict(list) self._bound_notifications: DefaultDict[str, List[Callable]] = defaultdict(list)
def bind_messages(self, messages, func): def bind_messages(self, messages: str, func: Callable) -> None:
"""Binds multiple message to the same function. """Binds multiple message to the same function.
Often, we perform the same thing on multiple messages. Instead of having the same function Often, we perform the same thing on multiple messages. Instead of having the same function
@ -54,15 +55,15 @@ class Listener:
for message in messages: for message in messages:
self._bound_notifications[message].append(func) self._bound_notifications[message].append(func)
def connect(self): def connect(self) -> None:
"""Connects the listener to its broadcaster.""" """Connects the listener to its broadcaster."""
self.broadcaster.add_listener(self) self.broadcaster.add_listener(self)
def disconnect(self): def disconnect(self) -> None:
"""Disconnects the listener from its broadcaster.""" """Disconnects the listener from its broadcaster."""
self.broadcaster.remove_listener(self) self.broadcaster.remove_listener(self)
def dispatch(self, msg): def dispatch(self, msg: str) -> None:
if msg in self._bound_notifications: if msg in self._bound_notifications:
for func in self._bound_notifications[msg]: for func in self._bound_notifications[msg]:
func() func()
@ -74,14 +75,14 @@ class Listener:
class Repeater(Broadcaster, Listener): class Repeater(Broadcaster, Listener):
REPEATED_NOTIFICATIONS = None REPEATED_NOTIFICATIONS = None
def __init__(self, broadcaster): def __init__(self, broadcaster: Broadcaster) -> None:
Broadcaster.__init__(self) Broadcaster.__init__(self)
Listener.__init__(self, broadcaster) Listener.__init__(self, broadcaster)
def _repeat_message(self, msg): def _repeat_message(self, msg: str) -> None:
if not self.REPEATED_NOTIFICATIONS or msg in self.REPEATED_NOTIFICATIONS: if not self.REPEATED_NOTIFICATIONS or msg in self.REPEATED_NOTIFICATIONS:
self.notify(msg) self.notify(msg)
def dispatch(self, msg): def dispatch(self, msg: str) -> None:
Listener.dispatch(self, msg) Listener.dispatch(self, msg)
self._repeat_message(msg) self._repeat_message(msg)

View File

@ -6,6 +6,7 @@
from pathlib import Path from pathlib import Path
import re import re
from typing import Callable, Dict, Union
from hscommon.build import read_changelog_file, filereplace from hscommon.build import read_changelog_file, filereplace
from sphinx.cmd.build import build_main as sphinx_build from sphinx.cmd.build import build_main as sphinx_build
@ -18,7 +19,7 @@ CHANGELOG_FORMAT = """
""" """
def tixgen(tixurl): def tixgen(tixurl: str) -> Callable[[str], str]:
"""This is a filter *generator*. tixurl is a url pattern for the tix with a {0} placeholder """This is a filter *generator*. tixurl is a url pattern for the tix with a {0} placeholder
for the tix # for the tix #
""" """
@ -29,14 +30,14 @@ def tixgen(tixurl):
def gen( def gen(
basepath, basepath: Path,
destpath, destpath: Path,
changelogpath, changelogpath: Path,
tixurl, tixurl: str,
confrepl=None, confrepl: Union[Dict[str, str], None] = None,
confpath=None, confpath: Union[Path, None] = None,
changelogtmpl=None, changelogtmpl: Union[Path, None] = None,
): ) -> None:
"""Generate sphinx docs with all bells and whistles. """Generate sphinx docs with all bells and whistles.
basepath: The base sphinx source path. basepath: The base sphinx source path.

View File

@ -1,141 +0,0 @@
# Created By: Virgil Dupras
# Created On: 2007/05/19
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import os
import os.path as op
import threading
from queue import Queue
import sqlite3 as sqlite
STOP = object()
COMMIT = object()
ROLLBACK = object()
class FakeCursor(list):
# It's not possible to use sqlite cursors on another thread than the connection. Thus,
# we can't directly return the cursor. We have to fatch all results, and support its interface.
def fetchall(self):
return self
def fetchone(self):
try:
return self.pop(0)
except IndexError:
return None
class _ActualThread(threading.Thread):
"""We can't use this class directly because thread object are not automatically freed when
nothing refers to it, making it hang the application if not explicitely closed.
"""
def __init__(self, dbname, autocommit):
threading.Thread.__init__(self)
self._queries = Queue()
self._results = Queue()
self._dbname = dbname
self._autocommit = autocommit
self._waiting_list = set()
self._lock = threading.Lock()
self._run = True
self.lastrowid = -1
self.daemon = True
self.start()
def _query(self, query):
with self._lock:
wait_token = object()
self._waiting_list.add(wait_token)
self._queries.put(query)
self._waiting_list.remove(wait_token)
result = self._results.get()
return result
def close(self):
if not self._run:
return
self._query(STOP)
def commit(self):
if not self._run:
return None # Connection closed
self._query(COMMIT)
def execute(self, sql, values=()):
if not self._run:
return None # Connection closed
result = self._query((sql, values))
if isinstance(result, Exception):
raise result
return result
def rollback(self):
if not self._run:
return None # Connection closed
self._query(ROLLBACK)
def run(self):
# The whole chdir thing is because sqlite doesn't handle directory names with non-asci char in the AT ALL.
oldpath = os.getcwd()
dbdir, dbname = op.split(self._dbname)
if dbdir:
os.chdir(dbdir)
if self._autocommit:
con = sqlite.connect(dbname, isolation_level=None)
else:
con = sqlite.connect(dbname)
os.chdir(oldpath)
while self._run or self._waiting_list:
query = self._queries.get()
result = None
if query is STOP:
self._run = False
elif query is COMMIT:
con.commit()
elif query is ROLLBACK:
con.rollback()
else:
sql, values = query
try:
cur = con.execute(sql, values)
self.lastrowid = cur.lastrowid
result = FakeCursor(cur.fetchall())
result.lastrowid = cur.lastrowid
except Exception as e:
result = e
self._results.put(result)
con.close()
class ThreadedConn:
"""``sqlite`` connections can't be used across threads. ``TheadedConn`` opens a sqlite
connection in its own thread and sends it queries through a queue, making it suitable in
multi-threaded environment.
"""
def __init__(self, dbname, autocommit):
self._t = _ActualThread(dbname, autocommit)
self.lastrowid = -1
def __del__(self):
self.close()
def close(self):
self._t.close()
def commit(self):
self._t.commit()
def execute(self, sql, values=()):
result = self._t.execute(sql, values)
self.lastrowid = self._t.lastrowid
return result
def rollback(self):
self._t.rollback()

View File

@ -1,137 +0,0 @@
# Created By: Virgil Dupras
# Created On: 2007/05/19
# Copyright 2015 Hardcoded Software (http://www.hardcoded.net)
# This software is licensed under the "GPLv3" License as described in the "LICENSE" file,
# which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html
import time
import threading
import os
import sqlite3 as sqlite
from pytest import raises
from hscommon.testutil import eq_
from hscommon.sqlite import ThreadedConn
# Threading is hard to test. In a lot of those tests, a failure means that the test run will
# hang forever. Well... I don't know a better alternative.
def test_can_access_from_multiple_threads():
def run():
con.execute("insert into foo(bar) values('baz')")
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
t = threading.Thread(target=run)
t.start()
t.join()
result = con.execute("select * from foo")
eq_(1, len(result))
eq_("baz", result[0][0])
def test_exception_during_query():
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
with raises(sqlite.OperationalError):
con.execute("select * from bleh")
def test_not_autocommit(tmpdir):
dbpath = str(tmpdir.join("foo.db"))
con = ThreadedConn(dbpath, False)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
del con
# The data shouldn't have been inserted
con = ThreadedConn(dbpath, False)
result = con.execute("select * from foo")
eq_(0, len(result))
con.execute("insert into foo(bar) values('baz')")
con.commit()
del con
# Now the data should be there
con = ThreadedConn(dbpath, False)
result = con.execute("select * from foo")
eq_(1, len(result))
def test_rollback():
con = ThreadedConn(":memory:", False)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
con.rollback()
result = con.execute("select * from foo")
eq_(0, len(result))
def test_query_palceholders():
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values(?)", ["baz"])
result = con.execute("select * from foo")
eq_(1, len(result))
eq_("baz", result[0][0])
def test_make_sure_theres_no_messup_between_queries():
def run(expected_rowid):
time.sleep(0.1)
result = con.execute("select rowid from foo where rowid = ?", [expected_rowid])
assert expected_rowid == result[0][0]
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
for i in range(100):
con.execute("insert into foo(bar) values('baz')")
threads = []
for i in range(1, 101):
t = threading.Thread(target=run, args=(i,))
t.start()
threads.append(t)
while threads:
time.sleep(0.1)
threads = [t for t in threads if t.is_alive()]
def test_query_after_close():
con = ThreadedConn(":memory:", True)
con.close()
con.execute("select 1")
def test_lastrowid():
# It's not possible to return a cursor because of the threading, but lastrowid should be
# fetchable from the connection itself
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz')")
eq_(1, con.lastrowid)
def test_add_fetchone_fetchall_interface_to_results():
con = ThreadedConn(":memory:", True)
con.execute("create table foo(bar TEXT)")
con.execute("insert into foo(bar) values('baz1')")
con.execute("insert into foo(bar) values('baz2')")
result = con.execute("select * from foo")
ref = result[:]
eq_(ref, result.fetchall())
eq_(ref[0], result.fetchone())
eq_(ref[1], result.fetchone())
assert result.fetchone() is None
def test_non_ascii_dbname(tmpdir):
ThreadedConn(str(tmpdir.join("foo\u00e9.db")), True)
def test_non_ascii_dbdir(tmpdir):
# when this test fails, it doesn't fail gracefully, it brings the whole test suite with it.
dbdir = tmpdir.join("foo\u00e9")
os.mkdir(str(dbdir))
ThreadedConn(str(dbdir.join("foo.db")), True)

View File

@ -11,7 +11,9 @@
import locale import locale
import logging import logging
import os
import os.path as op import os.path as op
from typing import Callable, Union
from hscommon.plat import ISLINUX from hscommon.plat import ISLINUX
@ -20,7 +22,7 @@ _trget = None
installed_lang = None installed_lang = None
def tr(s, context=None): def tr(s: str, context: Union[str, None] = None) -> str:
if _trfunc is None: if _trfunc is None:
return s return s
else: else:
@ -30,7 +32,7 @@ def tr(s, context=None):
return _trfunc(s) return _trfunc(s)
def trget(domain): def trget(domain: str) -> Callable[[str], str]:
# Returns a tr() function for the specified domain. # Returns a tr() function for the specified domain.
if _trget is None: if _trget is None:
return lambda s: tr(s, domain) return lambda s: tr(s, domain)
@ -38,14 +40,16 @@ def trget(domain):
return _trget(domain) return _trget(domain)
def set_tr(new_tr, new_trget=None): def set_tr(
new_tr: Callable[[str, Union[str, None]], str], new_trget: Union[Callable[[str], Callable[[str], str]], None] = None
) -> None:
global _trfunc, _trget global _trfunc, _trget
_trfunc = new_tr _trfunc = new_tr
if new_trget is not None: if new_trget is not None:
_trget = new_trget _trget = new_trget
def get_locale_name(lang): def get_locale_name(lang: str) -> Union[str, None]:
# Removed old conversion code as windows seems to support these # Removed old conversion code as windows seems to support these
LANG2LOCALENAME = { LANG2LOCALENAME = {
"cs": "cs_CZ", "cs": "cs_CZ",
@ -77,7 +81,7 @@ def get_locale_name(lang):
# --- Qt # --- Qt
def install_qt_trans(lang=None): def install_qt_trans(lang: str = None) -> None:
from PyQt5.QtCore import QCoreApplication, QTranslator, QLocale from PyQt5.QtCore import QCoreApplication, QTranslator, QLocale
if not lang: if not lang:
@ -97,17 +101,19 @@ def install_qt_trans(lang=None):
qtr2.load(":/%s" % lang) qtr2.load(":/%s" % lang)
QCoreApplication.installTranslator(qtr2) QCoreApplication.installTranslator(qtr2)
def qt_tr(s, context="core"): def qt_tr(s: str, context: Union[str, None] = "core") -> str:
if context is None:
context = "core"
return str(QCoreApplication.translate(context, s, None)) return str(QCoreApplication.translate(context, s, None))
set_tr(qt_tr) set_tr(qt_tr)
# --- gettext # --- gettext
def install_gettext_trans(base_folder, lang): def install_gettext_trans(base_folder: os.PathLike, lang: str) -> None:
import gettext import gettext
def gettext_trget(domain): def gettext_trget(domain: str) -> Callable[[str], str]:
if not lang: if not lang:
return lambda s: s return lambda s: s
try: try:
@ -117,7 +123,7 @@ def install_gettext_trans(base_folder, lang):
default_gettext = gettext_trget("core") default_gettext = gettext_trget("core")
def gettext_tr(s, context=None): def gettext_tr(s: str, context: Union[str, None] = None) -> str:
if not context: if not context:
return default_gettext(s) return default_gettext(s)
else: else:
@ -129,7 +135,7 @@ def install_gettext_trans(base_folder, lang):
installed_lang = lang installed_lang = lang
def install_gettext_trans_under_qt(base_folder, lang=None): def install_gettext_trans_under_qt(base_folder: os.PathLike, lang: str = None) -> None:
# So, we install the gettext locale, great, but we also should try to install qt_*.qm if # So, we install the gettext locale, great, but we also should try to install qt_*.qm if
# available so that strings that are inside Qt itself over which I have no control are in the # available so that strings that are inside Qt itself over which I have no control are in the
# right language. # right language.