diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..e78aa1e --- /dev/null +++ b/.flake8 @@ -0,0 +1,11 @@ +[flake8] +exclude = + .git, + build, + dist, + *.egg-info, + venv, + __pycache__, + +extend-ignore = + E501, # max line length diff --git a/.gitignore b/.gitignore index 1f0dcc1..5057722 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ tests/media dist build tests.json +*.egg-info diff --git a/.travis.yml b/.travis.yml index 23b3bd5..d877c49 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,16 +1,14 @@ +dist: bionic language: python +python: + - 3.5 + - 3.6 + - 3.7 + - 3.8 -virtualenv: - system_site_packages: true -before_install: - - sudo apt-get update - - sudo apt-get install python-opencv - - sudo dpkg -L python-opencv - - sudo ln /dev/null /dev/raw1394 install: - - "pip install -r requirements.txt" + - pip install opencv-python-headless + - pip install -r requirements.txt -python: - - "2.7" script: - python run-tests.py diff --git a/README.md b/README.md index 9771cbd..badc9b4 100644 --- a/README.md +++ b/README.md @@ -26,9 +26,9 @@ Do note that WAV is not the only format Sushi can work with. It can process audi ### Requirements Sushi should work on Windows, Linux and OS X. Please open an issue if it doesn't. To run it, you have to have the following installed: -1. [Python 2.7.x][5] +1. [Python 3.5 or higher][5] 2. [NumPy][6] (1.8 or newer) -3. [OpenCV 2.4.x or newer][7] (on Windows putting [this file][8] in the same folder as Sushi should be enough, assuming you use x86 Python) +3. [OpenCV 2.4.x or newer][7] Optionally, you might want: @@ -41,16 +41,21 @@ The provided Windows binaries include all required components and Colorama so yo #### Installation on Mac OS X -No binary packages are provided for OS X right now so you'll have to use the script form. Assuming you have python 2, pip and [homebrew](http://brew.sh/) installed, run the following: +No binary packages are provided for OS X right now so you'll have to use the script form. Assuming you have Python 3, pip and [homebrew](http://brew.sh/) installed, run the following: ```bash brew tap homebrew/science brew install git opencv -pip install numpy -git clone https://github.com/tp7/sushi -# create a symlink if you want to run sushi globally -ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi +pip3 install numpy # install some optional dependencies brew install ffmpeg mkvtoolnix + +# fetch sushi +git clone https://github.com/tp7/sushi +# run from source +python3 -m sushi args… +# install globally (for your user) +python3 setup.py install --user +sushi args… ``` If you don't have pip, you can install numpy with homebrew, but that will probably add a few more dependencies. ```bash @@ -62,9 +67,13 @@ brew install numpy If you have apt-get available, the installation process is trivial. ```bash sudo apt-get update -sudo apt-get install git python python-numpy python-opencv +sudo apt-get install git python3 python3-numpy python3-opencv git clone https://github.com/tp7/sushi -ln -s `pwd`/sushi/sushi.py /usr/local/bin/sushi +# run from source +python3 -m sushi args… +# install globally (for your user; ensure ~/.local/bin is in your PATH) +python3 setup.py install --user +sushi args… ``` ### Limitations @@ -82,7 +91,6 @@ In short, while this might be safe for immediate viewing, you probably shouldn't [5]: https://www.python.org/downloads/ [6]: http://www.scipy.org/scipylib/download.html [7]: http://opencv.org/ - [8]: https://www.dropbox.com/s/nlylgdh4bgrjgxv/cv2.pyd?dl=0 [9]: http://www.ffmpeg.org/download.html [10]: http://www.bunkus.org/videotools/mkvtoolnix/downloads.html [11]: https://github.com/soyokaze/SCXvid-standalone/releases diff --git a/build-windows.bat b/build-windows.bat index 2cecb33..05fb79f 100644 --- a/build-windows.bat +++ b/build-windows.bat @@ -4,9 +4,10 @@ pyinstaller --noupx --onefile --noconfirm ^ --exclude-module Tkconstants ^ --exclude-module Tkinter ^ --exclude-module matplotlib ^ - sushi.py + --name sushi ^ + sushi/__main__.py mkdir dist\licenses copy /Y licenses\* dist\licenses\* copy LICENSE dist\licenses\Sushi.txt -copy README.md dist\readme.md \ No newline at end of file +copy README.md dist\readme.md diff --git a/requirements.txt b/requirements.txt index 54b39f5..24ce15a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ numpy -mock diff --git a/run-tests.py b/run-tests.py index 30c8194..171a31f 100644 --- a/run-tests.py +++ b/run-tests.py @@ -1,7 +1,8 @@ import unittest -from tests.timecodes import * -from tests.main import * -from tests.subtitles import * -from tests.demuxing import * + +from tests.timecodes import * # noqa +from tests.main import * # noqa +from tests.subtitles import * # noqa +from tests.demuxing import * # noqa unittest.main(verbosity=0) diff --git a/setup.py b/setup.py index 5a2c7dd..66d85e8 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,16 @@ -from distutils.core import setup +from setuptools import setup import sushi setup( name='Sushi', description='Automatic subtitle shifter based on audio', + packages=['sushi'], version=sushi.VERSION, url='https://github.com/tp7/Sushi', - console=['sushi.py'], - license='MIT' + license='MIT', + entry_points={ + 'console_scripts': [ + "sushi=sushi.__main__:main", + ], + }, ) - diff --git a/sushi.py b/sushi/__init__.py old mode 100755 new mode 100644 similarity index 73% rename from sushi.py rename to sushi/__init__.py index 65cb35d..7847ea8 --- a/sushi.py +++ b/sushi/__init__.py @@ -1,22 +1,17 @@ -#!/usr/bin/env python2 +import bisect +from itertools import takewhile, chain import logging -import sys import operator -import argparse import os -import bisect -import collections -from itertools import takewhile, izip, chain -import time import numpy as np -import chapters -from common import SushiError, get_extension, format_time, ensure_static_collection -from demux import Timecodes, Demuxer -import keyframes -from subs import AssScript, SrtScript -from wav import WavStream +from . import chapters +from .common import SushiError, get_extension, format_time, ensure_static_collection +from .demux import Timecodes, Demuxer +from . import keyframes +from .subs import AssScript, SrtScript +from .wav import WavStream try: @@ -25,71 +20,38 @@ except ImportError: plot_enabled = False -if sys.platform == 'win32': - try: - import colorama - colorama.init() - console_colors_supported = True - except ImportError: - console_colors_supported = False -else: - console_colors_supported = True - ALLOWED_ERROR = 0.01 MAX_GROUP_STD = 0.025 VERSION = '0.5.1' -class ColoredLogFormatter(logging.Formatter): - bold_code = "\033[1m" - reset_code = "\033[0m" - grey_code = "\033[30m\033[1m" - - error_format = "{bold}ERROR: %(message)s{reset}".format(bold=bold_code, reset=reset_code) - warn_format = "{bold}WARNING: %(message)s{reset}".format(bold=bold_code, reset=reset_code) - debug_format = "{grey}%(message)s{reset}".format(grey=grey_code, reset=reset_code) - default_format = "%(message)s" - - def format(self, record): - if record.levelno == logging.DEBUG: - self._fmt = self.debug_format - elif record.levelno == logging.WARN: - self._fmt = self.warn_format - elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: - self._fmt = self.error_format - else: - self._fmt = self.default_format - - return super(ColoredLogFormatter, self).format(record) - - def abs_diff(a, b): return abs(a - b) def interpolate_nones(data, points): data = ensure_static_collection(data) - values_lookup = {p: v for p, v in izip(points, data) if v is not None} + values_lookup = {p: v for p, v in zip(points, data) if v is not None} if not values_lookup: return [] - zero_points = {p for p, v in izip(points, data) if v is None} + zero_points = {p for p, v in zip(points, data) if v is None} if not zero_points: return data - data_list = sorted(values_lookup.iteritems()) + data_list = sorted(values_lookup.items()) zero_points = sorted(x for x in zero_points if x not in values_lookup) out = np.interp(x=zero_points, - xp=map(operator.itemgetter(0), data_list), - fp=map(operator.itemgetter(1), data_list)) + xp=list(map(operator.itemgetter(0), data_list)), + fp=list(map(operator.itemgetter(1), data_list))) - values_lookup.update(izip(zero_points, out)) + values_lookup.update(zip(zero_points, out)) return [ values_lookup[point] if value is None else value - for point, value in izip(points, data) + for point, value in zip(points, data) ] @@ -100,9 +62,9 @@ def running_median(values, window_size): half_window = window_size // 2 medians = [] items_count = len(values) - for idx in xrange(items_count): - radius = min(half_window, idx, items_count-idx-1) - med = np.median(values[idx-radius:idx+radius+1]) + for idx in range(items_count): + radius = min(half_window, idx, items_count - idx - 1) + med = np.median(values[idx - radius:idx + radius + 1]) medians.append(med) return medians @@ -110,10 +72,10 @@ def running_median(values, window_size): def smooth_events(events, radius): if not radius: return - window_size = radius*2+1 + window_size = radius * 2 + 1 shifts = [e.shift for e in events] smoothed = running_median(shifts, window_size) - for event, new_shift in izip(events, smoothed): + for event, new_shift in zip(events, smoothed): event.set_shift(new_shift, event.diff) @@ -128,7 +90,7 @@ def detect_groups(events_iter): def groups_from_chapters(events, times): - logging.info(u'Chapter start points: {0}'.format([format_time(t) for t in times])) + logging.info('Chapter start points: {0}'.format([format_time(t) for t in times])) groups = [[]] chapter_times = iter(times[1:] + [36000000000]) # very large event at the end current_chapter = next(chapter_times) @@ -141,7 +103,7 @@ def groups_from_chapters(events, times): groups[-1].append(event) - groups = filter(None, groups) # non-empty groups + groups = [g for g in groups if g] # non-empty groups # check if we have any groups where every event is linked # for example a chapter with only comments inside broken_groups = [group for group in groups if not any(e for e in group if not e.linked)] @@ -152,7 +114,7 @@ def groups_from_chapters(events, times): parent_group = next(group for group in groups if parent in group) parent_group.append(event) del group[:] - groups = filter(None, groups) + groups = [g for g in groups if g] # re-sort the groups again since we might break the order when inserting linked events # sorting everything again is far from optimal but python sorting is very fast for sorted arrays anyway for group in groups: @@ -167,9 +129,9 @@ def split_broken_groups(groups): for g in groups: std = np.std([e.shift for e in g]) if std > MAX_GROUP_STD: - logging.warn(u'Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). ' - u'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end), - std)) + logging.warn('Shift is not consistent between {0} and {1}, most likely chapters are wrong (std: {2}). ' + 'Switching to automatic grouping.'.format(format_time(g[0].start), format_time(g[-1].end), + std)) correct_groups.extend(detect_groups(g)) broken_found = True else: @@ -254,7 +216,7 @@ def find_keyframe_distance(src_time, dst_time): dst = get_distance_to_closest_kf(dst_time, dst_keytimes) snapping_limit = timecodes.get_frame_size(src_time) * max_kf_distance - if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src-dst) < snapping_limit: + if abs(src) < snapping_limit and abs(dst) < snapping_limit and abs(src - dst) < snapping_limit: return dst - src return 0 @@ -281,11 +243,11 @@ def snap_groups_to_keyframes(events, chapter_times, max_ts_duration, max_ts_dist shifts = interpolate_nones(shifts, times) if shifts: mean_shift = np.mean(shifts) - shifts = zip(*(iter(shifts), ) * 2) + shifts = zip(*[iter(shifts)] * 2) logging.info('Group {0}-{1} corrected by {2}'.format(format_time(events[0].start), format_time(events[-1].end), mean_shift)) - for group, (start_shift, end_shift) in izip(groups, shifts): - if abs(start_shift-end_shift) > 0.001 and len(group) > 1: + for group, (start_shift, end_shift) in zip(groups, shifts): + if abs(start_shift - end_shift) > 0.001 and len(group) > 1: actual_shift = min(start_shift, end_shift, key=lambda x: abs(x - mean_shift)) logging.warning("Typesetting group at {0} had different shift at start/end points ({1} and {2}). Shifting by {3}." .format(format_time(group[0].start), start_shift, end_shift, actual_shift)) @@ -336,7 +298,7 @@ def merge_short_lines_into_groups(events, chapter_times, max_ts_duration, max_ts else: group = [event] group_end = event.end - i = idx+1 + i = idx + 1 while i < len(events) and abs(group_end - events[i].start) < max_ts_distance: if events[i].end < next_chapter and events[i].duration <= max_ts_duration: processed.add(i) @@ -354,12 +316,12 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio for idx, event in enumerate(events): if event.is_comment: try: - event.link_event(events[idx+1]) + event.link_event(events[idx + 1]) except IndexError: event.link_event(last_unlinked) continue if (event.start + event.duration / 2.0) > source_duration: - logging.info('Event time outside of audio range, ignoring: %s' % unicode(event)) + logging.info('Event time outside of audio range, ignoring: %s', event) event.link_event(last_unlinked) continue elif event.end == event.start: @@ -372,8 +334,9 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio # link lines with start and end times identical to some other event # assuming scripts are sorted by start time so we don't search the entire collection - same_start = lambda x: event.start == x.start - processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end),None) + def same_start(x): + return event.start == x.start + processed = next((x for x in takewhile(same_start, reversed(events[:idx])) if not x.linked and x.end == event.end), None) if processed: event.link_event(processed) else: @@ -400,7 +363,7 @@ def prepare_search_groups(events, source_duration, chapter_times, max_ts_duratio def calculate_shifts(src_stream, dst_stream, groups_list, normal_window, max_window, rewind_thresh): def log_shift(state): logging.info('{0}-{1}: shift: {2:0.10f}, diff: {3:0.10f}' - .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"])) + .format(format_time(state["start_time"]), format_time(state["end_time"]), state["shift"], state["diff"])) def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offset): logging.debug('{0}-{1}: shift: {2:0.5f} [{3:0.5f}, {4:0.5f}], search offset: {5:0.6f}' @@ -442,7 +405,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs idx += 1 continue - left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0])/2], axis=1) + left_audio_half, right_audio_half = np.split(tv_audio, [len(tv_audio[0]) // 2], axis=1) right_half_offset = len(left_audio_half[0]) / float(src_stream.sample_rate) terminate = False # searching from last committed shift @@ -456,7 +419,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs if not terminate and uncommitted_states and uncommitted_states[-1]["shift"] is not None \ and original_time + uncommitted_states[-1]["shift"] < dst_stream.duration_seconds: - start_offset = uncommitted_states[-1]["shift"] + start_offset = uncommitted_states[-1]["shift"] diff, new_time = dst_stream.find_substream(tv_audio, original_time + start_offset, window) left_side_time = dst_stream.find_substream(left_audio_half, original_time + start_offset, window)[1] right_side_time = dst_stream.find_substream(right_audio_half, original_time + start_offset + right_half_offset, window)[1] - right_half_offset @@ -495,7 +458,7 @@ def log_uncommitted(state, shift, left_side_shift, right_side_shift, search_offs for state in uncommitted_states: log_shift(state) - for idx, (search_group, group_state) in enumerate(izip(groups_list, chain(committed_states, uncommitted_states))): + for idx, (search_group, group_state) in enumerate(zip(groups_list, chain(committed_states, uncommitted_states))): if group_state["shift"] is None: for group in reversed(groups_list[:idx]): link_to = next((x for x in reversed(group) if not x.linked), None) @@ -580,7 +543,7 @@ def run(args): src_script_path = args.script_file else: stype = src_demuxer.get_subs_type(args.src_script_idx) - src_script_path = format_full_path(args.temp_dir, args.source, '.sushi'+ stype) + src_script_path = format_full_path(args.temp_dir, args.source, '.sushi' + stype) src_demuxer.set_script(stream_idx=args.src_script_idx, output_path=src_script_path) script_extension = get_extension(src_script_path) @@ -698,8 +661,8 @@ def select_timecodes(external_file, fps_arg, demuxer): start_shift = g[0].shift end_shift = g[-1].shift avg_shift = average_shifts(g) - logging.info(u'Group (start: {0}, end: {1}, lines: {2}), ' - u'shifts (start: {3}, end: {4}, average: {5})' + logging.info('Group (start: {0}, end: {1}, lines: {2}), ' + 'shifts (start: {3}, end: {4}, average: {5})' .format(format_time(g[0].start), format_time(g[-1].end), len(g), start_shift, end_shift, avg_shift)) @@ -726,7 +689,7 @@ def select_timecodes(external_file, fps_arg, demuxer): script.save_to_file(dst_script_path) if write_plot: - plt.plot([x.shift + (x._start_shift + x._end_shift)/2.0 for x in events], label='After correction') + plt.plot([x.shift + (x._start_shift + x._end_shift) / 2.0 for x in events], label='After correction') plt.legend(fontsize=5, frameon=False, fancybox=False) plt.savefig(args.plot_path, dpi=300) @@ -734,110 +697,3 @@ def select_timecodes(external_file, fps_arg, demuxer): if args.cleanup: src_demuxer.cleanup() dst_demuxer.cleanup() - - -def create_arg_parser(): - parser = argparse.ArgumentParser(description='Sushi - Automatic Subtitle Shifter') - - parser.add_argument('--window', default=10, type=int, metavar='', dest='window', - help='Search window size. [%(default)s]') - parser.add_argument('--max-window', default=30, type=int, metavar='', dest='max_window', - help="Maximum search size Sushi is allowed to use when trying to recover from errors. [%(default)s]") - parser.add_argument('--rewind-thresh', default=5, type=int, metavar='', dest='rewind_thresh', - help="Number of consecutive errors Sushi has to encounter to consider results broken " - "and retry with larger window. Set to 0 to disable. [%(default)s]") - parser.add_argument('--no-grouping', action='store_false', dest='grouping', - help="Don't events into groups before shifting. Also disables error recovery.") - parser.add_argument('--max-kf-distance', default=2, type=float, metavar='', dest='max_kf_distance', - help='Maximum keyframe snapping distance. [%(default)s]') - parser.add_argument('--kf-mode', default='all', choices=['shift', 'snap', 'all'], dest='kf_mode', - help='Keyframes-based shift correction/snapping mode. [%(default)s]') - parser.add_argument('--smooth-radius', default=3, type=int, metavar='', dest='smooth_radius', - help='Radius of smoothing median filter. [%(default)s]') - - # 10 frames at 23.976 - parser.add_argument('--max-ts-duration', default=1001.0 / 24000.0 * 10, type=float, metavar='', - dest='max_ts_duration', - help='Maximum duration of a line to be considered typesetting. [%(default).3f]') - # 10 frames at 23.976 - parser.add_argument('--max-ts-distance', default=1001.0 / 24000.0 * 10, type=float, metavar='', - dest='max_ts_distance', - help='Maximum distance between two adjacent typesetting lines to be merged. [%(default).3f]') - - # deprecated/test options, do not use - parser.add_argument('--test-shift-plot', default=None, dest='plot_path', help=argparse.SUPPRESS) - parser.add_argument('--sample-type', default='uint8', choices=['float32', 'uint8'], dest='sample_type', - help=argparse.SUPPRESS) - - parser.add_argument('--sample-rate', default=12000, type=int, metavar='', dest='sample_rate', - help='Downsampled audio sample rate. [%(default)s]') - - parser.add_argument('--src-audio', default=None, type=int, metavar='', dest='src_audio_idx', - help='Audio stream index of the source video') - parser.add_argument('--src-script', default=None, type=int, metavar='', dest='src_script_idx', - help='Script stream index of the source video') - parser.add_argument('--dst-audio', default=None, type=int, metavar='', dest='dst_audio_idx', - help='Audio stream index of the destination video') - # files - parser.add_argument('--no-cleanup', action='store_false', dest='cleanup', - help="Don't delete demuxed streams") - parser.add_argument('--temp-dir', default=None, dest='temp_dir', metavar='', - help='Specify temporary folder to use when demuxing stream.') - parser.add_argument('--chapters', default=None, dest='chapters_file', metavar='', - help="XML or OGM chapters to use instead of any found in the source. 'none' to disable.") - parser.add_argument('--script', default=None, dest='script_file', metavar='', - help='Subtitle file path to use instead of any found in the source') - - parser.add_argument('--dst-keyframes', default=None, dest='dst_keyframes', metavar='', - help='Destination keyframes file') - parser.add_argument('--src-keyframes', default=None, dest='src_keyframes', metavar='', - help='Source keyframes file') - parser.add_argument('--dst-fps', default=None, type=float, dest='dst_fps', metavar='', - help='Fps of the destination video. Must be provided if keyframes are used.') - parser.add_argument('--src-fps', default=None, type=float, dest='src_fps', metavar='', - help='Fps of the source video. Must be provided if keyframes are used.') - parser.add_argument('--dst-timecodes', default=None, dest='dst_timecodes', metavar='', - help='Timecodes file to use instead of making one from the destination (when possible)') - parser.add_argument('--src-timecodes', default=None, dest='src_timecodes', metavar='', - help='Timecodes file to use instead of making one from the source (when possible)') - - parser.add_argument('--src', required=True, dest="source", metavar='', - help='Source audio/video') - parser.add_argument('--dst', required=True, dest="destination", metavar='', - help='Destination audio/video') - parser.add_argument('-o', '--output', default=None, dest='output_script', metavar='', - help='Output script') - - parser.add_argument('-v', '--verbose', default=False, dest='verbose', action='store_true', - help='Enable verbose logging') - parser.add_argument('--version', action='version', version=VERSION) - - return parser - - -def parse_args_and_run(cmd_keys): - def format_arg(arg): - return arg if ' ' not in arg else '"{0}"'.format(arg) - - args = create_arg_parser().parse_args(cmd_keys) - handler = logging.StreamHandler() - if console_colors_supported and os.isatty(sys.stderr.fileno()): - # enable colors - handler.setFormatter(ColoredLogFormatter()) - else: - handler.setFormatter(logging.Formatter(fmt=ColoredLogFormatter.default_format)) - logging.root.addHandler(handler) - logging.root.setLevel(logging.DEBUG if args.verbose else logging.INFO) - - logging.info("Sushi's running with arguments: {0}".format(' '.join(map(format_arg, cmd_keys)))) - start_time = time.time() - run(args) - logging.info('Done in {0}s'.format(time.time() - start_time)) - - -if __name__ == '__main__': - try: - parse_args_and_run(sys.argv[1:]) - except SushiError as e: - logging.critical(e.message) - sys.exit(2) diff --git a/sushi/__main__.py b/sushi/__main__.py new file mode 100755 index 0000000..41637c3 --- /dev/null +++ b/sushi/__main__.py @@ -0,0 +1,154 @@ +import argparse +import logging +import os +import sys +import time + +# Use absolute imports to support pyinstaller +# https://github.com/pyinstaller/pyinstaller/issues/2560 +from sushi import run, VERSION +from sushi.common import SushiError + +if sys.platform == 'win32': + try: + import colorama + colorama.init() + console_colors_supported = True + except ImportError: + console_colors_supported = False +else: + console_colors_supported = True + + +class ColoredLogFormatter(logging.Formatter): + bold_code = "\033[1m" + reset_code = "\033[0m" + grey_code = "\033[30m\033[1m" + + error_format = "{bold}ERROR: %(message)s{reset}".format(bold=bold_code, reset=reset_code) + warn_format = "{bold}WARNING: %(message)s{reset}".format(bold=bold_code, reset=reset_code) + debug_format = "{grey}%(message)s{reset}".format(grey=grey_code, reset=reset_code) + default_format = "%(message)s" + + def format(self, record): + if record.levelno == logging.DEBUG: + self._fmt = self.debug_format + elif record.levelno == logging.WARN: + self._fmt = self.warn_format + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + self._fmt = self.error_format + else: + self._fmt = self.default_format + + return super(ColoredLogFormatter, self).format(record) + + +def create_arg_parser(): + parser = argparse.ArgumentParser(description='Sushi - Automatic Subtitle Shifter') + + parser.add_argument('--window', default=10, type=int, metavar='', dest='window', + help='Search window size. [%(default)s]') + parser.add_argument('--max-window', default=30, type=int, metavar='', dest='max_window', + help="Maximum search size Sushi is allowed to use when trying to recover from errors. [%(default)s]") + parser.add_argument('--rewind-thresh', default=5, type=int, metavar='', dest='rewind_thresh', + help="Number of consecutive errors Sushi has to encounter to consider results broken " + "and retry with larger window. Set to 0 to disable. [%(default)s]") + parser.add_argument('--no-grouping', action='store_false', dest='grouping', + help="Don't events into groups before shifting. Also disables error recovery.") + parser.add_argument('--max-kf-distance', default=2, type=float, metavar='', dest='max_kf_distance', + help='Maximum keyframe snapping distance. [%(default)s]') + parser.add_argument('--kf-mode', default='all', choices=['shift', 'snap', 'all'], dest='kf_mode', + help='Keyframes-based shift correction/snapping mode. [%(default)s]') + parser.add_argument('--smooth-radius', default=3, type=int, metavar='', dest='smooth_radius', + help='Radius of smoothing median filter. [%(default)s]') + + # 10 frames at 23.976 + parser.add_argument('--max-ts-duration', default=1001.0 / 24000.0 * 10, type=float, metavar='', + dest='max_ts_duration', + help='Maximum duration of a line to be considered typesetting. [%(default).3f]') + # 10 frames at 23.976 + parser.add_argument('--max-ts-distance', default=1001.0 / 24000.0 * 10, type=float, metavar='', + dest='max_ts_distance', + help='Maximum distance between two adjacent typesetting lines to be merged. [%(default).3f]') + + # deprecated/test options, do not use + parser.add_argument('--test-shift-plot', default=None, dest='plot_path', help=argparse.SUPPRESS) + parser.add_argument('--sample-type', default='uint8', choices=['float32', 'uint8'], dest='sample_type', + help=argparse.SUPPRESS) + + parser.add_argument('--sample-rate', default=12000, type=int, metavar='', dest='sample_rate', + help='Downsampled audio sample rate. [%(default)s]') + + parser.add_argument('--src-audio', default=None, type=int, metavar='', dest='src_audio_idx', + help='Audio stream index of the source video') + parser.add_argument('--src-script', default=None, type=int, metavar='', dest='src_script_idx', + help='Script stream index of the source video') + parser.add_argument('--dst-audio', default=None, type=int, metavar='', dest='dst_audio_idx', + help='Audio stream index of the destination video') + # files + parser.add_argument('--no-cleanup', action='store_false', dest='cleanup', + help="Don't delete demuxed streams") + parser.add_argument('--temp-dir', default=None, dest='temp_dir', metavar='', + help='Specify temporary folder to use when demuxing stream.') + parser.add_argument('--chapters', default=None, dest='chapters_file', metavar='', + help="XML or OGM chapters to use instead of any found in the source. 'none' to disable.") + parser.add_argument('--script', default=None, dest='script_file', metavar='', + help='Subtitle file path to use instead of any found in the source') + + parser.add_argument('--dst-keyframes', default=None, dest='dst_keyframes', metavar='', + help='Destination keyframes file') + parser.add_argument('--src-keyframes', default=None, dest='src_keyframes', metavar='', + help='Source keyframes file') + parser.add_argument('--dst-fps', default=None, type=float, dest='dst_fps', metavar='', + help='Fps of the destination video. Must be provided if keyframes are used.') + parser.add_argument('--src-fps', default=None, type=float, dest='src_fps', metavar='', + help='Fps of the source video. Must be provided if keyframes are used.') + parser.add_argument('--dst-timecodes', default=None, dest='dst_timecodes', metavar='', + help='Timecodes file to use instead of making one from the destination (when possible)') + parser.add_argument('--src-timecodes', default=None, dest='src_timecodes', metavar='', + help='Timecodes file to use instead of making one from the source (when possible)') + + parser.add_argument('--src', required=True, dest="source", metavar='', + help='Source audio/video') + parser.add_argument('--dst', required=True, dest="destination", metavar='', + help='Destination audio/video') + parser.add_argument('-o', '--output', default=None, dest='output_script', metavar='', + help='Output script') + + parser.add_argument('-v', '--verbose', default=False, dest='verbose', action='store_true', + help='Enable verbose logging') + parser.add_argument('--version', action='version', version=VERSION) + + return parser + + +def parse_args_and_run(cmd_keys): + def format_arg(arg): + return arg if ' ' not in arg else '"{0}"'.format(arg) + + args = create_arg_parser().parse_args(cmd_keys) + handler = logging.StreamHandler() + if console_colors_supported and os.isatty(sys.stderr.fileno()): + # enable colors + handler.setFormatter(ColoredLogFormatter()) + else: + handler.setFormatter(logging.Formatter(fmt=ColoredLogFormatter.default_format)) + logging.root.addHandler(handler) + logging.root.setLevel(logging.DEBUG if args.verbose else logging.INFO) + + logging.info("Sushi's running with arguments: {0}".format(' '.join(map(format_arg, cmd_keys)))) + start_time = time.time() + run(args) + logging.info('Done in {0}s'.format(time.time() - start_time)) + + +def main(): + try: + parse_args_and_run(sys.argv[1:]) + except SushiError as e: + logging.critical(e.args[0]) + sys.exit(2) + + +if __name__ == '__main__': + main() diff --git a/chapters.py b/sushi/chapters.py similarity index 91% rename from chapters.py rename to sushi/chapters.py index 4bf1de3..104fb2d 100644 --- a/chapters.py +++ b/sushi/chapters.py @@ -1,11 +1,11 @@ import re -import common +from . import common def parse_times(times): result = [] for t in times: - hours, minutes, seconds = map(float, t.split(':')) + hours, minutes, seconds = list(map(float, t.split(':'))) result.append(hours * 3600 + minutes * 60 + seconds) result.sort() diff --git a/common.py b/sushi/common.py similarity index 88% rename from common.py rename to sushi/common.py index 595fbe5..7a6d96e 100644 --- a/common.py +++ b/sushi/common.py @@ -22,7 +22,7 @@ def ensure_static_collection(value): def format_srt_time(seconds): cs = round(seconds * 1000) - return u'{0:02d}:{1:02d}:{2:02d},{3:03d}'.format( + return '{0:02d}:{1:02d}:{2:02d},{3:03d}'.format( int(cs // 3600000), int((cs // 60000) % 60), int((cs // 1000) % 60), @@ -31,7 +31,7 @@ def format_srt_time(seconds): def format_time(seconds): cs = round(seconds * 100) - return u'{0}:{1:02d}:{2:02d}.{3:02d}'.format( + return '{0}:{1:02d}:{2:02d}.{3:02d}'.format( int(cs // 360000), int((cs // 6000) % 60), int((cs // 100) % 60), diff --git a/demux.py b/sushi/demux.py similarity index 94% rename from demux.py rename to sushi/demux.py index d5a5e37..fd7192d 100644 --- a/demux.py +++ b/sushi/demux.py @@ -5,8 +5,8 @@ import logging import bisect -from common import SushiError, get_extension -import chapters +from .common import SushiError, get_extension +from . import chapters MediaStreamInfo = namedtuple('MediaStreamInfo', ['id', 'info', 'default', 'title']) SubtitlesStreamInfo = namedtuple('SubtitlesStreamInfo', ['id', 'info', 'type', 'default', 'title']) @@ -17,7 +17,9 @@ class FFmpeg(object): @staticmethod def get_info(path): try: - process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE) + # text=True is an alias for universal_newlines since 3.7 + process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE, + universal_newlines=True, encoding='utf-8') out, err = process.communicate() process.wait() return err @@ -75,7 +77,7 @@ def _get_video_streams(info): @staticmethod def _get_chapters_times(info): - return map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info)) + return list(map(float, re.findall(r'Chapter #0.\d+: start (\d+\.\d+)', info))) @staticmethod def _get_subtitles_streams(info): @@ -113,10 +115,11 @@ class SCXviD(object): def make_keyframes(cls, video_path, log_path): try: ffmpeg_process = subprocess.Popen(['ffmpeg', '-i', video_path, - '-f', 'yuv4mpegpipe', - '-vf', 'scale=640:360', - '-pix_fmt', 'yuv420p', - '-vsync', 'drop', '-'], stdout=subprocess.PIPE) + '-f', 'yuv4mpegpipe', + '-vf', 'scale=640:360', + '-pix_fmt', 'yuv420p', + '-vsync', 'drop', '-'], + stdout=subprocess.PIPE) except OSError as e: if e.errno == 2: raise SushiError("Couldn't invoke ffmpeg, check that it's installed") @@ -143,7 +146,7 @@ def get_frame_time(self, number): return self.times[number] except IndexError: if not self.default_frame_duration: - return self.get_frame_time(len(self.times)-1) + return self.get_frame_time(len(self.times) - 1) if self.times: return self.times[-1] + (self.default_frame_duration) * (number - len(self.times) + 1) else: @@ -157,7 +160,7 @@ def get_frame_number(self, timestamp): def get_frame_size(self, timestamp): try: number = bisect.bisect_left(self.times, timestamp) - except: + except Exception: return self.default_frame_duration c = self.get_frame_time(number) @@ -353,4 +356,3 @@ def _select_stream(self, streams, chosen_idx, name): raise SushiError("Stream with index {0} doesn't exist in {1}.\n" "Here are all that do:\n" "{2}".format(chosen_idx, self._path, self._format_streams_list(streams))) - diff --git a/keyframes.py b/sushi/keyframes.py similarity index 69% rename from keyframes.py rename to sushi/keyframes.py index f749393..f1b0282 100644 --- a/keyframes.py +++ b/sushi/keyframes.py @@ -1,8 +1,9 @@ -from common import SushiError, read_all_text +from .common import SushiError, read_all_text def parse_scxvid_keyframes(text): - return [i-3 for i,line in enumerate(text.splitlines()) if line and line[0] == 'i'] + return [i - 3 for i, line in enumerate(text.splitlines()) if line and line[0] == 'i'] + def parse_keyframes(path): text = read_all_text(path) diff --git a/regression-tests.py b/sushi/regression-tests.py similarity index 91% rename from regression-tests.py rename to sushi/regression-tests.py index f1161e8..825f787 100644 --- a/regression-tests.py +++ b/sushi/regression-tests.py @@ -9,10 +9,10 @@ import subprocess import argparse -from common import format_time -from demux import Timecodes -from subs import AssScript -from wav import WavStream +from .common import format_time +from .demux import Timecodes +from .subs import AssScript +from .wav import WavStream root_logger = logging.getLogger('') @@ -54,24 +54,24 @@ def compare_scripts(ideal_path, test_path, timecodes, test_name, expected_errors test_end_frame = timecodes.get_frame_number(test.end) if ideal_start_frame != test_start_frame and ideal_end_frame != test_end_frame: - logging.debug(u'{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format( + logging.debug('{0}: start and end time failed at "{1}". {2}-{3} vs {4}-{5}'.format( idx, strip_tags(ideal.text), ft(ideal.start), ft(ideal.end), ft(test.start), ft(test.end)) ) failed += 1 elif ideal_end_frame != test_end_frame: logging.debug( - u'{0}: end time failed at "{1}". {2} vs {3}'.format( + '{0}: end time failed at "{1}". {2} vs {3}'.format( idx, strip_tags(ideal.text), ft(ideal.end), ft(test.end)) ) failed += 1 elif ideal_start_frame != test_start_frame: logging.debug( - u'{0}: start time failed at "{1}". {2} vs {3}'.format( + '{0}: start time failed at "{1}". {2} vs {3}'.format( idx, strip_tags(ideal.text), ft(ideal.start), ft(test.start)) ) failed += 1 - logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events)-failed, failed)) + logging.info('Total lines: {0}, good: {1}, failed: {2}'.format(len(ideal_script.events), len(ideal_script.events) - failed, failed)) if failed > expected_errors: logging.critical('Got more failed lines than expected ({0} actual vs {1} expected)'.format(failed, expected_errors)) @@ -141,7 +141,7 @@ def run_wav_test(test_name, file_path, params): gc.collect(2) before = resource.getrusage(resource.RUSAGE_SELF) - loaded = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8')) + _ = WavStream(file_path, params.get('sample_rate', 12000), params.get('sample_type', 'uint8')) after = resource.getrusage(resource.RUSAGE_SELF) total_time = (after.ru_stime - before.ru_stime) + (after.ru_utime - before.ru_utime) @@ -189,7 +189,7 @@ def should_run(name): return not args.run_only or name in args.run_only failed = ran = 0 - for test_name, params in config.get('tests', {}).iteritems(): + for test_name, params in config.get('tests', {}).items(): if not should_run(test_name): continue if not params.get('disabled', False): @@ -201,7 +201,7 @@ def should_run(name): logging.warn('Test "{0}" disabled'.format(test_name)) if should_run("wavs"): - for test_name, params in config.get('wavs', {}).iteritems(): + for test_name, params in config.get('wavs', {}).items(): ran += 1 if not run_wav_test(test_name, os.path.join(config['basepath'], params['file']), params): failed += 1 diff --git a/subs.py b/sushi/subs.py similarity index 82% rename from subs.py rename to sushi/subs.py index 6b31d9a..7d8ab39 100644 --- a/subs.py +++ b/sushi/subs.py @@ -3,7 +3,7 @@ import re import collections -from common import SushiError, format_time, format_srt_time +from .common import SushiError, format_time, format_srt_time def _parse_ass_time(string): @@ -79,9 +79,6 @@ def adjust_shift(self, value): assert not self.linked, 'Cannot adjust time of linked events' self._shift += value - def __repr__(self): - return unicode(self) - class ScriptBase(object): def __init__(self, events): @@ -95,7 +92,7 @@ class SrtEvent(ScriptEventBase): is_comment = False style = None - EVENT_REGEX = re.compile(""" + EVENT_REGEX = re.compile(r""" (\d+?)\s+? # line-number (\d{1,2}:\d{1,2}:\d{1,2},\d+)\s-->\s(\d{1,2}:\d{1,2}:\d{1,2},\d+). # timestamp (.+?) # actual text @@ -112,9 +109,9 @@ def from_string(cls, text): end = cls.parse_time(match.group(3)) return SrtEvent(int(match.group(1)), start, end, match.group(4).strip()) - def __unicode__(self): - return u'{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start), - self._format_time(self.end), self.text) + def __str__(self): + return '{0}\n{1} --> {2}\n{3}'.format(self.source_index, self._format_time(self.start), + self._format_time(self.end), self.text) @staticmethod def parse_time(time_string): @@ -142,7 +139,7 @@ def from_file(cls, path): raise SushiError("Script {0} not found".format(path)) def save_to_file(self, path): - text = '\n\n'.join(map(unicode, self.events)) + text = '\n\n'.join(map(str, self.events)) with codecs.open(path, encoding='utf-8', mode='w') as script: script.write(text) @@ -168,14 +165,14 @@ def __init__(self, text, position=0): self.margin_vertical = split[7] self.effect = split[8] - def __unicode__(self): - return u'{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer, - self._format_time(self.start), - self._format_time(self.end), - self.style, self.name, - self.margin_left, self.margin_right, - self.margin_vertical, self.effect, - self.text) + def __str__(self): + return '{0}: {1},{2},{3},{4},{5},{6},{7},{8},{9},{10}'.format(self.kind, self.layer, + self._format_time(self.start), + self._format_time(self.end), + self.style, self.name, + self.margin_left, self.margin_right, + self.margin_vertical, self.effect, + self.text) @staticmethod def _format_time(seconds): @@ -195,19 +192,19 @@ def from_file(cls, path): other_sections = collections.OrderedDict() def parse_script_info_line(line): - if line.startswith(u'Format:'): + if line.startswith('Format:'): return script_info.append(line) def parse_styles_line(line): - if line.startswith(u'Format:'): + if line.startswith('Format:'): return styles.append(line) def parse_event_line(line): - if line.startswith(u'Format:'): + if line.startswith('Format:'): return - events.append(AssEvent(line, position=len(events)+1)) + events.append(AssEvent(line, position=len(events) + 1)) def create_generic_parse(section_name): if section_name in other_sections: @@ -224,11 +221,11 @@ def create_generic_parse(section_name): if not line: continue low = line.lower() - if low == u'[script info]': + if low == '[script info]': parse_function = parse_script_info_line - elif low == u'[v4+ styles]': + elif low == '[v4+ styles]': parse_function = parse_styles_line - elif low == u'[events]': + elif low == '[events]': parse_function = parse_event_line elif re.match(r'\[.+?\]', low): parse_function = create_generic_parse(line) @@ -248,27 +245,27 @@ def save_to_file(self, path): # raise RuntimeError('File %s already exists' % path) lines = [] if self.script_info: - lines.append(u'[Script Info]') + lines.append('[Script Info]') lines.extend(self.script_info) lines.append('') if self.styles: - lines.append(u'[V4+ Styles]') - lines.append(u'Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding') + lines.append('[V4+ Styles]') + lines.append('Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding') lines.extend(self.styles) lines.append('') if self.events: events = sorted(self.events, key=lambda x: x.source_index) - lines.append(u'[Events]') - lines.append(u'Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text') - lines.extend(map(unicode, events)) + lines.append('[Events]') + lines.append('Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text') + lines.extend(map(str, events)) if self.other: - for section_name, section_lines in self.other.iteritems(): + for section_name, section_lines in self.other.items(): lines.append('') lines.append(section_name) lines.extend(section_lines) with codecs.open(path, encoding='utf-8-sig', mode='w') as script: - script.write(unicode(os.linesep).join(lines)) + script.write(str(os.linesep).join(lines)) diff --git a/wav.py b/sushi/wav.py similarity index 94% rename from wav.py rename to sushi/wav.py index 4d4f785..e0707be 100644 --- a/wav.py +++ b/sushi/wav.py @@ -6,7 +6,9 @@ import math from time import time import os.path -from common import SushiError, clip + +from .common import SushiError, clip +from functools import reduce WAVE_FORMAT_PCM = 0x0001 WAVE_FORMAT_EXTENSIBLE = 0xFFFE @@ -20,9 +22,9 @@ def __init__(self, path): self._file = open(path, 'rb') try: riff = Chunk(self._file, bigendian=False) - if riff.getname() != 'RIFF': + if riff.getname() != b'RIFF': raise SushiError('File does not start with RIFF id') - if riff.read(4) != 'WAVE': + if riff.read(4) != b'WAVE': raise SushiError('Not a WAVE file') fmt_chunk_read = False @@ -35,10 +37,10 @@ def __init__(self, path): except EOFError: break - if chunk.getname() == 'fmt ': + if chunk.getname() == b'fmt ': self._read_fmt_chunk(chunk) fmt_chunk_read = True - elif chunk.getname() == 'data': + elif chunk.getname() == b'data': if file_size > 0xFFFFFFFF: # large broken wav self.frames_count = (file_size - self._file.tell()) // self.frame_size @@ -49,7 +51,7 @@ def __init__(self, path): chunk.skip() if not fmt_chunk_read or not data_chink_read: raise SushiError('Invalid WAV file') - except: + except Exception: self.close() raise @@ -85,8 +87,8 @@ def readframes(self, count): if min_length != real_length: logging.error("Length of audio channels didn't match. This might result in broken output") - channels = (unpacked[i::self.channels_count] for i in xrange(self.channels_count)) - data = reduce(lambda a, b: a[:min_length]+b[:min_length], channels) + channels = (unpacked[i::self.channels_count] for i in range(self.channels_count)) + data = reduce(lambda a, b: a[:min_length] + b[:min_length], channels) data /= float(self.channels_count) return data @@ -126,7 +128,7 @@ def __init__(self, path, sample_rate=12000, sample_type='uint8'): data = stream.readframes(int(self.READ_CHUNK_SIZE * stream.framerate)) new_length = int(round(len(data) * downsample_rate)) - dst_view = self.data[0][samples_read:samples_read+new_length] + dst_view = self.data[0][samples_read:samples_read + new_length] if downsample_rate != 1: data = data.reshape((1, len(data))) @@ -138,7 +140,7 @@ def __init__(self, path, sample_rate=12000, sample_type='uint8'): # padding the audio from both sides self.data[0][0:self.padding_size].fill(self.data[0][self.padding_size]) - self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size-1]) + self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size - 1]) # normalizing # also clipping the stream by 3*median value from both sides of zero diff --git a/tests/demuxing.py b/tests/demuxing.py index 8eb4d42..4d9a610 100644 --- a/tests/demuxing.py +++ b/tests/demuxing.py @@ -1,9 +1,9 @@ import unittest -import mock +from unittest import mock -from demux import FFmpeg, MkvToolnix, SCXviD -from common import SushiError -import chapters +from sushi.demux import FFmpeg, MkvToolnix, SCXviD +from sushi.common import SushiError +from sushi import chapters def create_popen_mock(): @@ -60,7 +60,7 @@ def test_parses_subtitles_stream(self): @mock.patch('subprocess.Popen', new_callable=create_popen_mock) def test_get_info_call_args(self, popen_mock): FFmpeg.get_info('random_file.mkv') - self.assertEquals(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv']) + self.assertEqual(popen_mock.call_args[0][0], ['ffmpeg', '-hide_banner', '-i', 'random_file.mkv']) @mock.patch('subprocess.Popen', new_callable=create_popen_mock) def test_get_info_fail_when_no_mmpeg(self, popen_mock): @@ -112,8 +112,8 @@ def raise_no_app(cmd_args, **kwargs): raise OSError(2, 'ignored') popen_mock.side_effect = raise_no_app - self.assertRaisesRegexp(SushiError, '[fF][fF][mM][pP][eE][gG]', - lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt')) + self.assertRaisesRegex(SushiError, '[fF][fF][mM][pP][eE][gG]', + lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt')) @mock.patch('subprocess.Popen') def test_no_scxvid(self, popen_mock): @@ -123,8 +123,8 @@ def raise_no_app(cmd_args, **kwargs): return mock.Mock() popen_mock.side_effect = raise_no_app - self.assertRaisesRegexp(SushiError, '[sS][cC][xX][vV][iI][dD]', - lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt')) + self.assertRaisesRegex(SushiError, '[sS][cC][xX][vV][iI][dD]', + lambda: SCXviD.make_keyframes('video.mkv', 'keyframes.txt')) class ExternalChaptersTestCase(unittest.TestCase): diff --git a/tests/main.py b/tests/main.py index 17b0a8c..ebf380a 100644 --- a/tests/main.py +++ b/tests/main.py @@ -1,10 +1,11 @@ -from collections import namedtuple import os import re import unittest -from mock import patch, ANY -from common import SushiError, format_time +from unittest.mock import patch, ANY + +from sushi.common import SushiError, format_time import sushi +from sushi import __main__ as main here = os.path.dirname(os.path.abspath(__file__)) @@ -33,7 +34,7 @@ def __eq__(self, other): class InterpolateNonesTestCase(unittest.TestCase): def test_returns_empty_array_when_passed_empty_array(self): - self.assertEquals(sushi.interpolate_nones([], []), []) + self.assertEqual(sushi.interpolate_nones([], []), []) def test_returns_false_when_no_valid_points(self): self.assertFalse(sushi.interpolate_nones([None, None, None], [1, 2, 3])) @@ -107,24 +108,24 @@ def test_events_in_two_groups_one_chapter(self): events = [FakeEvent(end=1), FakeEvent(end=2), FakeEvent(end=3)] groups = sushi.groups_from_chapters(events, [0.0, 1.5]) self.assertEqual(2, len(groups)) - self.assertItemsEqual([events[0]], groups[0]) - self.assertItemsEqual([events[1], events[2]], groups[1]) + self.assertEqual([events[0]], groups[0]) + self.assertEqual([events[1], events[2]], groups[1]) def test_multiple_groups_multiple_chapters(self): - events = [FakeEvent(end=x) for x in xrange(1, 10)] + events = [FakeEvent(end=x) for x in range(1, 10)] groups = sushi.groups_from_chapters(events, [0.0, 3.2, 4.4, 7.7]) self.assertEqual(4, len(groups)) - self.assertItemsEqual(events[0:3], groups[0]) - self.assertItemsEqual(events[3:4], groups[1]) - self.assertItemsEqual(events[4:7], groups[2]) - self.assertItemsEqual(events[7:9], groups[3]) + self.assertEqual(events[0:3], groups[0]) + self.assertEqual(events[3:4], groups[1]) + self.assertEqual(events[4:7], groups[2]) + self.assertEqual(events[7:9], groups[3]) class SplitBrokenGroupsTestCase(unittest.TestCase): def test_doing_nothing_on_correct_groups(self): groups = [[FakeEvent(0.5), FakeEvent(0.5)], [FakeEvent(10.0)]] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual(groups, fixed) + self.assertEqual(groups, fixed) def test_split_groups_without_merging(self): groups = [ @@ -132,7 +133,7 @@ def test_split_groups_without_merging(self): [FakeEvent(0.5)] * 10, ] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual([ + self.assertEqual([ [FakeEvent(0.5)] * 10, [FakeEvent(10.0)] * 5, [FakeEvent(0.5)] * 10 @@ -144,7 +145,7 @@ def test_split_groups_with_merging(self): [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(15.0)] ] fixed = sushi.split_broken_groups(groups) - self.assertItemsEqual([ + self.assertEqual([ [FakeEvent(0.5)], [FakeEvent(10.0), FakeEvent(10.0), FakeEvent(10.0)], [FakeEvent(15.0)] @@ -192,7 +193,7 @@ def test_checks_that_files_exist(self, mock_object): '--dst-keyframes', 'dst-keyframes', '--src-keyframes', 'src-keyframes', '--src-timecodes', 'src-tcs', '--dst-timecodes', 'dst-tcs'] try: - sushi.parse_args_and_run(keys) + main.parse_args_and_run(keys) except SushiError: pass mock_object.assert_any_call('src', ANY) @@ -206,16 +207,16 @@ def test_checks_that_files_exist(self, mock_object): def test_raises_on_unknown_script_type(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.mp4'] - self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type'), lambda: sushi.parse_args_and_run(keys)) + self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type'), lambda: main.parse_args_and_run(keys)) def test_raises_on_script_type_not_matching(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '-o', 'd.srt'] - self.assertRaisesRegexp(SushiError, self.any_case_regex(r'script.*type.*match'), - lambda: sushi.parse_args_and_run(keys)) + self.assertRaisesRegex(SushiError, self.any_case_regex(r'script.*type.*match'), + lambda: main.parse_args_and_run(keys)) def test_raises_on_timecodes_and_fps_being_defined_together(self, ignore): keys = ['--src', 's.wav', '--dst', 'd.wav', '--script', 's.ass', '--src-timecodes', 'tc.txt', '--src-fps', '25'] - self.assertRaisesRegexp(SushiError, self.any_case_regex(r'timecodes'), lambda: sushi.parse_args_and_run(keys)) + self.assertRaisesRegex(SushiError, self.any_case_regex(r'timecodes'), lambda: main.parse_args_and_run(keys)) class FormatTimeTestCase(unittest.TestCase): diff --git a/tests/subtitles.py b/tests/subtitles.py index e442756..fe21fec 100644 --- a/tests/subtitles.py +++ b/tests/subtitles.py @@ -2,7 +2,8 @@ import tempfile import os import codecs -from subs import AssEvent, AssScript, SrtEvent, SrtScript + +from sushi.subs import AssEvent, AssScript, SrtEvent, SrtScript SINGLE_LINE_SRT_EVENT = """1 00:14:21,960 --> 00:14:22,960 @@ -22,45 +23,45 @@ class SrtEventTestCase(unittest.TestCase): def test_simple_parsing(self): event = SrtEvent.from_string(SINGLE_LINE_SRT_EVENT) - self.assertEquals(14*60+21.960, event.start) - self.assertEquals(14*60+22.960, event.end) - self.assertEquals("HOW DID IT END UP LIKE THIS?", event.text) + self.assertEqual(14 * 60 + 21.960, event.start) + self.assertEqual(14 * 60 + 22.960, event.end) + self.assertEqual("HOW DID IT END UP LIKE THIS?", event.text) def test_multi_line_event_parsing(self): event = SrtEvent.from_string(MULTILINE_SRT_EVENT) - self.assertEquals(13*60+12.140, event.start) - self.assertEquals(13*60+14.100, event.end) - self.assertEquals("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text) + self.assertEqual(13 * 60 + 12.140, event.start) + self.assertEqual(13 * 60 + 14.100, event.end) + self.assertEqual("APPEARANCE!\nAppearrance (teisai)!\nNo wait, you're the worst (saitei)!", event.text) def test_parsing_and_printing(self): - self.assertEquals(SINGLE_LINE_SRT_EVENT, unicode(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT))) - self.assertEquals(MULTILINE_SRT_EVENT, unicode(SrtEvent.from_string(MULTILINE_SRT_EVENT))) + self.assertEqual(SINGLE_LINE_SRT_EVENT, str(SrtEvent.from_string(SINGLE_LINE_SRT_EVENT))) + self.assertEqual(MULTILINE_SRT_EVENT, str(SrtEvent.from_string(MULTILINE_SRT_EVENT))) class AssEventTestCase(unittest.TestCase): def test_simple_parsing(self): event = AssEvent(ASS_EVENT) self.assertFalse(event.is_comment) - self.assertEquals("Dialogue", event.kind) - self.assertEquals(18*60+50.98, event.start) - self.assertEquals(18*60+55.28, event.end) - self.assertEquals("0", event.layer) - self.assertEquals("Default", event.style) - self.assertEquals("", event.name) - self.assertEquals("0", event.margin_left) - self.assertEquals("0", event.margin_right) - self.assertEquals("0", event.margin_vertical) - self.assertEquals("", event.effect) - self.assertEquals("Are you trying to (ouch) crush it (ouch)\N like a (ouch) vise (ouch, ouch)?", event.text) + self.assertEqual("Dialogue", event.kind) + self.assertEqual(18 * 60 + 50.98, event.start) + self.assertEqual(18 * 60 + 55.28, event.end) + self.assertEqual("0", event.layer) + self.assertEqual("Default", event.style) + self.assertEqual("", event.name) + self.assertEqual("0", event.margin_left) + self.assertEqual("0", event.margin_right) + self.assertEqual("0", event.margin_vertical) + self.assertEqual("", event.effect) + self.assertEqual("Are you trying to (ouch) crush it (ouch)\\N like a (ouch) vise (ouch, ouch)?", event.text) def test_comment_parsing(self): event = AssEvent(ASS_COMMENT) self.assertTrue(event.is_comment) - self.assertEquals("Comment", event.kind) + self.assertEqual("Comment", event.kind) def test_parsing_and_printing(self): - self.assertEquals(ASS_EVENT, unicode(AssEvent(ASS_EVENT))) - self.assertEquals(ASS_COMMENT, unicode(AssEvent(ASS_COMMENT))) + self.assertEqual(ASS_EVENT, str(AssEvent(ASS_EVENT))) + self.assertEqual(ASS_COMMENT, str(AssEvent(ASS_COMMENT))) class ScriptTestBase(unittest.TestCase): @@ -77,10 +78,10 @@ def test_write_to_file(self): SrtScript(events).save_to_file(self.script_path) with open(self.script_path) as script: text = script.read() - self.assertEquals(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text) + self.assertEqual(SINGLE_LINE_SRT_EVENT + "\n\n" + MULTILINE_SRT_EVENT, text) def test_read_from_file(self): - os.write(self.script_description, """1 + os.write(self.script_description, b"""1 00:00:17,500 --> 00:00:18,870 Yeah, really! @@ -97,18 +98,18 @@ def test_read_from_file(self): 00:00:21,250 --> 00:00:22,750 Serves you right.""") parsed = SrtScript.from_file(self.script_path).events - self.assertEquals(17.5, parsed[0].start) - self.assertEquals(18.87, parsed[0].end) - self.assertEquals("Yeah, really!", parsed[0].text) - self.assertEquals(17.5, parsed[1].start) - self.assertEquals(18.87, parsed[1].end) - self.assertEquals("", parsed[1].text) - self.assertEquals(17.5, parsed[2].start) - self.assertEquals(18.87, parsed[2].end) - self.assertEquals("House number\n35", parsed[2].text) - self.assertEquals(21.25, parsed[3].start) - self.assertEquals(22.75, parsed[3].end) - self.assertEquals("Serves you right.", parsed[3].text) + self.assertEqual(17.5, parsed[0].start) + self.assertEqual(18.87, parsed[0].end) + self.assertEqual("Yeah, really!", parsed[0].text) + self.assertEqual(17.5, parsed[1].start) + self.assertEqual(18.87, parsed[1].end) + self.assertEqual("", parsed[1].text) + self.assertEqual(17.5, parsed[2].start) + self.assertEqual(18.87, parsed[2].end) + self.assertEqual("House number\n35", parsed[2].text) + self.assertEqual(21.25, parsed[3].start) + self.assertEqual(22.75, parsed[3].end) + self.assertEqual("Serves you right.", parsed[3].text) class AssScriptTestCase(ScriptTestBase): @@ -120,7 +121,7 @@ def test_write_to_file(self): with codecs.open(self.script_path, encoding='utf-8-sig') as script: text = script.read() - self.assertEquals("""[V4+ Styles] + self.assertEqual("""[V4+ Styles] Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1 @@ -130,7 +131,7 @@ def test_write_to_file(self): {0}""".format(ASS_EVENT), text) def test_read_from_file(self): - text = """[Script Info] + text = b"""[Script Info] ; Script generated by Aegisub 3.1.1 Title: script title @@ -146,10 +147,10 @@ def test_read_from_file(self): os.write(self.script_description, text) script = AssScript.from_file(self.script_path) - self.assertEquals(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info) - self.assertEquals(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1", - "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"], - script.styles) - self.assertEquals([1, 2], [x.source_index for x in script.events]) - self.assertEquals(u"Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", unicode(script.events[0])) - self.assertEquals(u"Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", unicode(script.events[1])) + self.assertEqual(["; Script generated by Aegisub 3.1.1", "Title: script title"], script.script_info) + self.assertEqual(["Style: Default,Open Sans Semibold,36,&H00FFFFFF,&H000000FF,&H00020713,&H00000000,-1,0,0,0,100,100,0,0,1,1.7,0,2,0,0,28,1", + "Style: Signs,Gentium Basic,40,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,0,0,0,0,100,100,0,0,1,0,0,2,10,10,10,1"], + script.styles) + self.assertEqual([1, 2], [x.source_index for x in script.events]) + self.assertEqual("Dialogue: 0,0:00:01.42,0:00:03.36,Default,,0000,0000,0000,,As you already know,", str(script.events[0])) + self.assertEqual("Dialogue: 0,0:00:03.36,0:00:05.93,Default,,0000,0000,0000,,I'm concerned about the hair on my nipples.", str(script.events[1])) diff --git a/tests/timecodes.py b/tests/timecodes.py index 335006c..dcc3f6f 100644 --- a/tests/timecodes.py +++ b/tests/timecodes.py @@ -1,5 +1,6 @@ import unittest -from demux import Timecodes + +from sushi.demux import Timecodes class CfrTimecodesTestCase(unittest.TestCase): @@ -11,25 +12,25 @@ def test_get_frame_time_zero(self): def test_get_frame_time_sane(self): tcs = Timecodes.cfr(23.976) t = tcs.get_frame_time(10) - self.assertAlmostEqual(10.0/23.976, t) + self.assertAlmostEqual(10.0 / 23.976, t) def test_get_frame_time_insane(self): tcs = Timecodes.cfr(23.976) t = tcs.get_frame_time(100000) - self.assertAlmostEqual(100000.0/23.976, t) + self.assertAlmostEqual(100000.0 / 23.976, t) def test_get_frame_size(self): tcs = Timecodes.cfr(23.976) t1 = tcs.get_frame_size(0) t2 = tcs.get_frame_size(1000) - self.assertAlmostEqual(1.0/23.976, t1) + self.assertAlmostEqual(1.0 / 23.976, t1) self.assertAlmostEqual(t1, t2) def test_get_frame_number(self): - tcs = Timecodes.cfr(24000.0/1001.0) + tcs = Timecodes.cfr(24000.0 / 1001.0) self.assertEqual(tcs.get_frame_number(0), 0) self.assertEqual(tcs.get_frame_number(1145.353), 27461) - self.assertEqual(tcs.get_frame_number(1001.0/24000.0 * 1234567), 1234567) + self.assertEqual(tcs.get_frame_number(1001.0 / 24000.0 * 1234567), 1234567) class TimecodesTestCase(unittest.TestCase): @@ -37,9 +38,9 @@ def test_cfr_timecodes_v2(self): text = '# timecode format v2\n' + '\n'.join(str(1000 * x / 23.976) for x in range(0, 30000)) parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976 * 100, parsed.get_frame_time(100)) self.assertEqual(0, parsed.get_frame_time(0)) self.assertEqual(0, parsed.get_frame_number(0)) self.assertEqual(27461, parsed.get_frame_number(1145.353)) @@ -47,9 +48,9 @@ def test_cfr_timecodes_v2(self): def test_cfr_timecodes_v1(self): text = '# timecode format v1\nAssume 23.976024' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976024, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976024*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976024, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976024, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976024 * 100, parsed.get_frame_time(100)) self.assertEqual(0, parsed.get_frame_time(0)) self.assertEqual(0, parsed.get_frame_number(0)) self.assertEqual(27461, parsed.get_frame_number(1145.353)) @@ -57,30 +58,30 @@ def test_cfr_timecodes_v1(self): def test_cfr_timecodes_v1_with_overrides(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,23.976000\n3000,5000,23.976000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(0)) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(25)) - self.assertAlmostEqual(1.0/23.976*100, parsed.get_frame_time(100)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(25)) + self.assertAlmostEqual(1.0 / 23.976 * 100, parsed.get_frame_time(100)) self.assertEqual(0, parsed.get_frame_time(0)) def test_vfr_timecodes_v1_frame_size_at_first_frame(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=0)) + self.assertAlmostEqual(1.0 / 29.97, parsed.get_frame_size(timestamp=0)) def test_vfr_timecodes_v1_frame_size_outside_of_defined_range(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=5000.0)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(timestamp=5000.0)) def test_vft_timecodes_v1_frame_size_inside_override_block(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/29.97, parsed.get_frame_size(timestamp=49.983)) + self.assertAlmostEqual(1.0 / 29.97, parsed.get_frame_size(timestamp=49.983)) def test_vft_timecodes_v1_frame_size_between_override_blocks(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) - self.assertAlmostEqual(1.0/23.976, parsed.get_frame_size(timestamp=87.496)) + self.assertAlmostEqual(1.0 / 23.976, parsed.get_frame_size(timestamp=87.496)) def test_vfr_timecodes_v1_frame_time_at_first_frame(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' @@ -101,5 +102,3 @@ def test_vft_timecodes_v1_frame_time_between_override_blocks(self): text = '# timecode format v1\nAssume 23.976000\n0,2000,29.970000\n3000,4000,59.940000' parsed = Timecodes.parse(text) self.assertAlmostEqual(87.579, parsed.get_frame_time(number=2500), places=3) - -