Removed unused code in hscommon/util

Also added type hints throughout
This commit is contained in:
Andrew Senetar 2022-05-09 00:47:57 -05:00
parent 40ff40bea8
commit f587c7b5d8
Signed by: arsenetar
GPG Key ID: C63300DCE48AB2F1
2 changed files with 30 additions and 174 deletions

View File

@ -15,18 +15,14 @@ from pathlib import Path
from ..util import ( from ..util import (
nonone, nonone,
tryint, tryint,
minmax,
first, first,
flatten, flatten,
dedupe, dedupe,
stripfalse,
extract, extract,
allsame, allsame,
trailiter,
format_time, format_time,
format_time_decimal, format_time_decimal,
format_size, format_size,
remove_invalid_xml,
multi_replace, multi_replace,
delete_if_empty, delete_if_empty,
open_if_filename, open_if_filename,
@ -51,12 +47,6 @@ def test_tryint():
eq_(42, tryint(None, 42)) eq_(42, tryint(None, 42))
def test_minmax():
eq_(minmax(2, 1, 3), 2)
eq_(minmax(0, 1, 3), 1)
eq_(minmax(4, 1, 3), 3)
# --- Sequence # --- Sequence
@ -75,10 +65,6 @@ def test_dedupe():
eq_(dedupe(reflist), [0, 7, 1, 2, 3, 4, 5, 6]) eq_(dedupe(reflist), [0, 7, 1, 2, 3, 4, 5, 6])
def test_stripfalse():
eq_([1, 2, 3], stripfalse([None, 0, 1, 2, 3, None]))
def test_extract(): def test_extract():
wheat, shaft = extract(lambda n: n % 2 == 0, list(range(10))) wheat, shaft = extract(lambda n: n % 2 == 0, list(range(10)))
eq_(wheat, [0, 2, 4, 6, 8]) eq_(wheat, [0, 2, 4, 6, 8])
@ -93,14 +79,6 @@ def test_allsame():
assert allsame(iter([42, 42, 42])) assert allsame(iter([42, 42, 42]))
def test_trailiter():
eq_(list(trailiter([])), [])
eq_(list(trailiter(["foo"])), [(None, "foo")])
eq_(list(trailiter(["foo", "bar"])), [(None, "foo"), ("foo", "bar")])
eq_(list(trailiter(["foo", "bar"], skipfirst=True)), [("foo", "bar")])
eq_(list(trailiter([], skipfirst=True)), []) # no crash
def test_iterconsume(): def test_iterconsume():
# We just want to make sure that we return *all* items and that we're not mistakenly skipping # We just want to make sure that we return *all* items and that we're not mistakenly skipping
# one. # one.
@ -213,14 +191,6 @@ def test_format_size():
eq_(format_size(999999999999999999999999), "848 ZB") eq_(format_size(999999999999999999999999), "848 ZB")
def test_remove_invalid_xml():
eq_(remove_invalid_xml("foo\0bar\x0bbaz"), "foo bar baz")
# surrogate blocks have to be replaced, but not the rest
eq_(remove_invalid_xml("foo\ud800bar\udfffbaz\ue000"), "foo bar baz\ue000")
# replace with something else
eq_(remove_invalid_xml("foo\0baz", replace_with="bar"), "foobarbaz")
def test_multi_replace(): def test_multi_replace():
eq_("136", multi_replace("123456", ("2", "45"))) eq_("136", multi_replace("123456", ("2", "45")))
eq_("1 3 6", multi_replace("123456", ("2", "45"), " ")) eq_("1 3 6", multi_replace("123456", ("2", "45"), " "))

View File

@ -6,20 +6,14 @@
# which should be included with this package. The terms are also available at # which should be included with this package. The terms are also available at
# http://www.gnu.org/licenses/gpl-3.0.html # http://www.gnu.org/licenses/gpl-3.0.html
import sys
import os
import os.path as op
import re
from math import ceil from math import ceil
import glob
import shutil
from datetime import timedelta
from pathlib import Path from pathlib import Path
from .path import pathify, log_io_error from .path import pathify, log_io_error
from typing import IO, Any, Callable, Generator, Iterable, List, Tuple, Union
def nonone(value, replace_value):
def nonone(value: Any, replace_value: Any) -> Any:
"""Returns ``value`` if ``value`` is not ``None``. Returns ``replace_value`` otherwise.""" """Returns ``value`` if ``value`` is not ``None``. Returns ``replace_value`` otherwise."""
if value is None: if value is None:
return replace_value return replace_value
@ -27,7 +21,7 @@ def nonone(value, replace_value):
return value return value
def tryint(value, default=0): def tryint(value: Any, default: int = 0) -> int:
"""Tries to convert ``value`` to in ``int`` and returns ``default`` if it fails.""" """Tries to convert ``value`` to in ``int`` and returns ``default`` if it fails."""
try: try:
return int(value) return int(value)
@ -35,15 +29,10 @@ def tryint(value, default=0):
return default return default
def minmax(value, min_value, max_value):
"""Returns `value` or one of the min/max bounds if `value` is not between them."""
return min(max(value, min_value), max_value)
# --- Sequence related # --- Sequence related
def dedupe(iterable): def dedupe(iterable: Iterable[Any]) -> List[Any]:
"""Returns a list of elements in ``iterable`` with all dupes removed. """Returns a list of elements in ``iterable`` with all dupes removed.
The order of the elements is preserved. The order of the elements is preserved.
@ -58,13 +47,13 @@ def dedupe(iterable):
return result return result
def flatten(iterables, start_with=None): def flatten(iterables: Iterable[Iterable], start_with: Iterable[Any] = None) -> List[Any]:
"""Takes a list of lists ``iterables`` and returns a list containing elements of every list. """Takes a list of lists ``iterables`` and returns a list containing elements of every list.
If ``start_with`` is not ``None``, the result will start with ``start_with`` items, exactly as If ``start_with`` is not ``None``, the result will start with ``start_with`` items, exactly as
if ``start_with`` would be the first item of lists. if ``start_with`` would be the first item of lists.
""" """
result = [] result: List[Any] = []
if start_with: if start_with:
result.extend(start_with) result.extend(start_with)
for iterable in iterables: for iterable in iterables:
@ -72,7 +61,7 @@ def flatten(iterables, start_with=None):
return result return result
def first(iterable): def first(iterable: Iterable[Any]):
"""Returns the first item of ``iterable``.""" """Returns the first item of ``iterable``."""
try: try:
return next(iter(iterable)) return next(iter(iterable))
@ -80,12 +69,7 @@ def first(iterable):
return None return None
def stripfalse(seq): def extract(predicate: Callable[[Any], bool], iterable: Iterable[Any]) -> Tuple[List[Any], List[Any]]:
"""Returns a sequence with all false elements stripped out of seq."""
return [x for x in seq if x]
def extract(predicate, iterable):
"""Separates the wheat from the shaft (`predicate` defines what's the wheat), and returns both.""" """Separates the wheat from the shaft (`predicate` defines what's the wheat), and returns both."""
wheat = [] wheat = []
shaft = [] shaft = []
@ -97,7 +81,7 @@ def extract(predicate, iterable):
return wheat, shaft return wheat, shaft
def allsame(iterable): def allsame(iterable: Iterable[Any]) -> bool:
"""Returns whether all elements of 'iterable' are the same.""" """Returns whether all elements of 'iterable' are the same."""
it = iter(iterable) it = iter(iterable)
try: try:
@ -107,26 +91,7 @@ def allsame(iterable):
return all(element == first_item for element in it) return all(element == first_item for element in it)
def trailiter(iterable, skipfirst=False): def iterconsume(seq: List[Any], reverse: bool = True) -> Generator[Any, None, None]:
"""Yields (prev_element, element), starting with (None, first_element).
If skipfirst is True, there will be no (None, item1) element and we'll start
directly with (item1, item2).
"""
it = iter(iterable)
if skipfirst:
try:
prev = next(it)
except StopIteration:
return
else:
prev = None
for item in it:
yield prev, item
prev = item
def iterconsume(seq, reverse=True):
"""Iterate over ``seq`` and pops yielded objects. """Iterate over ``seq`` and pops yielded objects.
Because we use the ``pop()`` method, we reverse ``seq`` before proceeding. If you don't need Because we use the ``pop()`` method, we reverse ``seq`` before proceeding. If you don't need
@ -145,12 +110,12 @@ def iterconsume(seq, reverse=True):
# --- String related # --- String related
def escape(s, to_escape, escape_with="\\"): def escape(s: str, to_escape: str, escape_with: str = "\\") -> str:
"""Returns ``s`` with characters in ``to_escape`` all prepended with ``escape_with``.""" """Returns ``s`` with characters in ``to_escape`` all prepended with ``escape_with``."""
return "".join((escape_with + c if c in to_escape else c) for c in s) return "".join((escape_with + c if c in to_escape else c) for c in s)
def get_file_ext(filename): def get_file_ext(filename: str) -> str:
"""Returns the lowercase extension part of filename, without the dot.""" """Returns the lowercase extension part of filename, without the dot."""
pos = filename.rfind(".") pos = filename.rfind(".")
if pos > -1: if pos > -1:
@ -159,7 +124,7 @@ def get_file_ext(filename):
return "" return ""
def rem_file_ext(filename): def rem_file_ext(filename: str) -> str:
"""Returns the filename without extension.""" """Returns the filename without extension."""
pos = filename.rfind(".") pos = filename.rfind(".")
if pos > -1: if pos > -1:
@ -168,7 +133,8 @@ def rem_file_ext(filename):
return filename return filename
def pluralize(number, word, decimals=0, plural_word=None): # TODO type hint number
def pluralize(number, word: str, decimals: int = 0, plural_word: Union[str, None] = None) -> str:
"""Returns a pluralized string with ``number`` in front of ``word``. """Returns a pluralized string with ``number`` in front of ``word``.
Adds a 's' to s if ``number`` > 1. Adds a 's' to s if ``number`` > 1.
@ -187,7 +153,7 @@ def pluralize(number, word, decimals=0, plural_word=None):
return plural_format % (number, word) return plural_format % (number, word)
def format_time(seconds, with_hours=True): def format_time(seconds: int, with_hours: bool = True) -> str:
"""Transforms seconds in a hh:mm:ss string. """Transforms seconds in a hh:mm:ss string.
If ``with_hours`` if false, the format is mm:ss. If ``with_hours`` if false, the format is mm:ss.
@ -207,7 +173,7 @@ def format_time(seconds, with_hours=True):
return r return r
def format_time_decimal(seconds): def format_time_decimal(seconds: int) -> str:
"""Transforms seconds in a strings like '3.4 minutes'.""" """Transforms seconds in a strings like '3.4 minutes'."""
minus = seconds < 0 minus = seconds < 0
if minus: if minus:
@ -230,7 +196,7 @@ SIZE_DESC = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
SIZE_VALS = tuple(1024**i for i in range(1, 9)) SIZE_VALS = tuple(1024**i for i in range(1, 9))
def format_size(size, decimal=0, forcepower=-1, showdesc=True): def format_size(size: int, decimal: int = 0, forcepower: int = -1, showdesc: bool = True) -> str:
"""Transform a byte count in a formatted string (KB, MB etc..). """Transform a byte count in a formatted string (KB, MB etc..).
``size`` is the number of bytes to format. ``size`` is the number of bytes to format.
@ -268,17 +234,7 @@ def format_size(size, decimal=0, forcepower=-1, showdesc=True):
return result return result
_valid_xml_range = "\x09\x0A\x0D\x20-\uD7FF\uE000-\uFFFD" def multi_replace(s: str, replace_from: Union[str, List[str]], replace_to: Union[str, List[str]] = "") -> str:
if sys.maxunicode > 0x10000:
_valid_xml_range += "{}-{}".format(chr(0x10000), chr(min(sys.maxunicode, 0x10FFFF)))
RE_INVALID_XML_SUB = re.compile("[^%s]" % _valid_xml_range, re.U).sub
def remove_invalid_xml(s, replace_with=" "):
return RE_INVALID_XML_SUB(replace_with, s)
def multi_replace(s, replace_from, replace_to=""):
"""A function like str.replace() with multiple replacements. """A function like str.replace() with multiple replacements.
``replace_from`` is a list of things you want to replace. Ex: ['a','bc','d'] ``replace_from`` is a list of things you want to replace. Ex: ['a','bc','d']
@ -302,61 +258,15 @@ def multi_replace(s, replace_from, replace_to=""):
return s return s
# --- Date related
# It might seem like needless namespace pollution, but the speedup gained by this constant is
# significant, so it stays.
ONE_DAY = timedelta(1)
def iterdaterange(start, end):
"""Yields every day between ``start`` and ``end``."""
date = start
while date <= end:
yield date
date += ONE_DAY
# --- Files related # --- Files related
@pathify
def modified_after(first_path: Path, second_path: Path):
"""Returns ``True`` if first_path's mtime is higher than second_path's mtime.
If one of the files doesn't exist or is ``None``, it is considered "never modified".
"""
try:
first_mtime = first_path.stat().st_mtime
except (OSError, AttributeError):
return False
try:
second_mtime = second_path.stat().st_mtime
except (OSError, AttributeError):
return True
return first_mtime > second_mtime
def find_in_path(name, paths=None):
"""Search for `name` in all directories of `paths` and return the absolute path of the first
occurrence. If `paths` is None, $PATH is used.
"""
if paths is None:
paths = os.environ["PATH"]
if isinstance(paths, str): # if it's not a string, it's already a list
paths = paths.split(os.pathsep)
for path in paths:
if op.exists(op.join(path, name)):
return op.join(path, name)
return None
@log_io_error @log_io_error
@pathify @pathify
def delete_if_empty(path: Path, files_to_delete=[]): def delete_if_empty(path: Path, files_to_delete: List[str] = []) -> bool:
"""Deletes the directory at 'path' if it is empty or if it only contains files_to_delete.""" """Deletes the directory at 'path' if it is empty or if it only contains files_to_delete."""
if not path.exists() or not path.is_dir(): if not path.exists() or not path.is_dir():
return return False
contents = list(path.glob("*")) contents = list(path.glob("*"))
if any(p for p in contents if (p.name not in files_to_delete) or p.is_dir()): if any(p for p in contents if (p.name not in files_to_delete) or p.is_dir()):
return False return False
@ -366,7 +276,10 @@ def delete_if_empty(path: Path, files_to_delete=[]):
return True return True
def open_if_filename(infile, mode="rb"): def open_if_filename(
infile: Union[Path, str, IO],
mode: str = "rb",
) -> Tuple[IO, bool]:
"""If ``infile`` is a string, it opens and returns it. If it's already a file object, it simply returns it. """If ``infile`` is a string, it opens and returns it. If it's already a file object, it simply returns it.
This function returns ``(file, should_close_flag)``. The should_close_flag is True is a file has This function returns ``(file, should_close_flag)``. The should_close_flag is True is a file has
@ -386,33 +299,6 @@ def open_if_filename(infile, mode="rb"):
return (infile, False) return (infile, False)
def ensure_folder(path):
"Create `path` as a folder if it doesn't exist."
if not op.exists(path):
os.makedirs(path)
def ensure_file(path):
"Create `path` as an empty file if it doesn't exist."
if not op.exists(path):
open(path, "w").close()
def delete_files_with_pattern(folder_path, pattern, recursive=True):
"""Delete all files (or folders) in `folder_path` that match the glob `pattern`."""
to_delete = glob.glob(op.join(folder_path, pattern))
for fn in to_delete:
if op.isdir(fn):
shutil.rmtree(fn)
else:
os.remove(fn)
if recursive:
subpaths = [op.join(folder_path, fn) for fn in os.listdir(folder_path)]
subfolders = [p for p in subpaths if op.isdir(p)]
for p in subfolders:
delete_files_with_pattern(p, pattern, True)
class FileOrPath: class FileOrPath:
"""Does the same as :func:`open_if_filename`, but it can be used with a ``with`` statement. """Does the same as :func:`open_if_filename`, but it can be used with a ``with`` statement.
@ -422,16 +308,16 @@ class FileOrPath:
dostuff() dostuff()
""" """
def __init__(self, file_or_path, mode="rb"): def __init__(self, file_or_path: Union[Path, str], mode: str = "rb") -> None:
self.file_or_path = file_or_path self.file_or_path = file_or_path
self.mode = mode self.mode = mode
self.mustclose = False self.mustclose = False
self.fp = None self.fp: Union[IO, None] = None
def __enter__(self): def __enter__(self) -> IO:
self.fp, self.mustclose = open_if_filename(self.file_or_path, self.mode) self.fp, self.mustclose = open_if_filename(self.file_or_path, self.mode)
return self.fp return self.fp
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback) -> None:
if self.fp and self.mustclose: if self.fp and self.mustclose:
self.fp.close() self.fp.close()