Skip to content

Commit b3c56b2

Browse files
committed
Merge branch 'dev' into sd3
2 parents aaa26bb + 583ab27 commit b3c56b2

File tree

2 files changed

+191
-1
lines changed

2 files changed

+191
-1
lines changed

library/jpeg_xl_util.py

Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
# Modified from https://github.com/Fraetor/jxl_decode Original license: MIT
2+
# Added partial read support for up to 200x speedup
3+
4+
import os
5+
from typing import List, Tuple
6+
7+
class JXLBitstream:
    """
    A stream of bits with methods for easy handling.

    Bits are served least-significant-bit first from bytes that are pulled
    lazily out of ``file``; only as many bytes as the requested bits need are
    ever read.

    Two layouts are supported:

    * contiguous - the data starts at ``offset`` and is read sequentially;
    * segmented  - ``offsets`` is a non-empty list of ``[index, start, size]``
      entries; the payloads are concatenated in list order, each one being
      ``size`` bytes located at absolute file position ``start``. When
      ``offsets`` is given it takes precedence over ``offset``.

    Bits requested beyond the end of the available data read as zero.
    """

    def __init__(self, file, offset: int = 0, offsets: List[List[int]] = None):
        self.shift = 0  # number of bits already handed out
        self.bitstream = bytearray()  # every payload byte fetched so far
        self.file = file
        self.offset = offset
        self.offsets = offsets
        self.previous_data_len = 0  # total size of the segments already finished
        self.index = 0  # position of the current segment within ``offsets``
        if self.offsets:
            self.offset = self.offsets[0][1]
        self.file.seek(self.offset)

    def get_bits(self, length: int = 1) -> int:
        """Return the next ``length`` bits as an unsigned integer."""
        # Bytes required to cover the last requested bit, minus what is buffered.
        missing = (self.shift + length + 7) // 8 - len(self.bitstream)
        if missing > 0:
            self._read_bytes(missing)
        bitmask = (1 << length) - 1
        bits = (int.from_bytes(self.bitstream, "little") >> self.shift) & bitmask
        self.shift += length
        return bits

    def _read_bytes(self, count: int) -> None:
        """Append up to ``count`` payload bytes to the buffer, hopping segments as needed."""
        if not self.offsets:
            self.bitstream.extend(self.file.read(count))
            return

        while count > 0 and self.index < len(self.offsets):
            segment_end = self.previous_data_len + self.offsets[self.index][2]
            available = segment_end - len(self.bitstream)
            if available <= 0:
                # Current segment is used up (this includes sitting exactly on
                # its end): move on to the next payload, if there is one.
                self.previous_data_len = segment_end
                self.index += 1
                if self.index < len(self.offsets):
                    self.file.seek(self.offsets[self.index][1])
                continue
            chunk = self.file.read(min(count, available))
            if not chunk:
                # File ended before the segment did; stop instead of spinning.
                break
            self.bitstream.extend(chunk)
            count -= len(chunk)
47+
48+
49+
def decode_codestream(file, offset: int = 0, offsets: List[List[int]] = None) -> Tuple[int,int]:
    """
    Reads the image dimensions from the SizeHeader of a JXL codestream.
    JXL codestream specification: http://www-internal/2022/18181-1

    ``file``, ``offset`` and ``offsets`` are handed unchanged to JXLBitstream.
    Returns ``(width, height)``.
    """

    # Bit widths of the four explicit size encodings, picked by a 2-bit selector.
    size_bits = (9, 13, 18, 30)
    # Non-zero ratio code -> (numerator, denominator); width = height * num // den.
    aspect_ratios = {
        1: (1, 1),
        2: (12, 10),
        3: (4, 3),
        4: (3, 2),
        5: (16, 9),
        6: (5, 4),
        7: (2, 1),
    }

    # Wrap the file in a bit reader to get the handy get_bits method.
    reader = JXLBitstream(file, offset=offset, offsets=offsets)
    read = reader.get_bits

    # The first 16 bits are the signature; they are consumed, not inspected.
    read(16)

    # Height: either a compact multiple of 8 or an explicitly sized field.
    small = read(1)
    if small:
        height = 8 * (1 + read(5))
    else:
        selector = read(2)
        height = 1 + read(size_bits[selector])

    # Width: derived from the height when a ratio code is present, otherwise
    # stored using the same scheme as the height.
    ratio = read(3)
    if ratio:
        numerator, denominator = aspect_ratios[ratio]
        width = (height * numerator) // denominator
    elif small:
        width = 8 * (1 + read(5))
    else:
        selector = read(2)
        width = 1 + read(size_bits[selector])

    return width, height
107+
108+
109+
def decode_container(file) -> Tuple[int,int]:
    """
    Parses the ISOBMFF container, extracts the codestream, and decodes it.
    JXL container specification: http://www-internal/2022/18181-2

    ``file`` must be a seekable binary stream positioned anywhere; it is
    rewound here. The codestream is located either as a single ``jxlc`` box
    or as a series of ``jxlp`` boxes, which are ordered by their index field.

    Returns the result of decode_codestream for the located codestream.
    Raises ValueError if the signature or file type box is wrong, if a box
    header carries an invalid length, or if no codestream box is present.
    """

    # Measure the stream with seek/tell rather than os.fstat so that seekable
    # objects without a file descriptor are accepted as well.
    file.seek(0, os.SEEK_END)
    file_size = file.tell()

    def parse_box(file, file_start: int) -> dict:
        """Read the box header at ``file_start``: total length, type and header size."""
        file.seek(file_start)
        LBox = int.from_bytes(file.read(4), "big")
        box_type = file.read(4)
        if 1 < LBox <= 8:
            raise ValueError(f"Invalid LBox at byte {file_start}.")
        if LBox == 1:
            # Extended form: the real length follows the type as a 64-bit value.
            XLBox = int.from_bytes(file.read(8), "big")
            if XLBox <= 16:
                raise ValueError(f"Invalid XLBox at byte {file_start}.")
            header_length = 16
            box_length = XLBox
        elif LBox == 0:
            # A zero length means the box runs to the end of the file.
            header_length = 8
            box_length = file_size - file_start
        else:
            header_length = 8
            box_length = LBox
        return {
            "length": box_length,
            "type": box_type,
            "offset": header_length,
        }

    file.seek(0)
    # Reject files missing required boxes. These two boxes are required to be at
    # the start and contain no values, so we can manually check their presence.
    # Signature box. (Redundant as has already been checked.)
    if file.read(12) != bytes.fromhex("0000000C 4A584C20 0D0A870A"):
        raise ValueError("Invalid signature box.")
    # File Type box.
    if file.read(20) != bytes.fromhex(
        "00000014 66747970 6A786C20 00000000 6A786C20"
    ):
        raise ValueError("Invalid file type box.")

    offset = None  # payload start of a jxlc box, once found
    offsets = []  # [index, payload start, payload size] per jxlp box
    container_pointer = 32  # first byte after the two fixed boxes (12 + 20)
    while container_pointer < file_size:
        box = parse_box(file, container_pointer)
        payload_start = container_pointer + box["offset"]
        if box["type"] == b"jxlc":
            offset = payload_start
            break
        if box["type"] == b"jxlp":
            # The payload starts with a 4-byte index, followed by codestream data.
            file.seek(payload_start)
            index = int.from_bytes(file.read(4), "big")
            offsets.append([index, payload_start + 4, box["length"] - box["offset"] - 4])
        container_pointer += box["length"]

    if offset is None and not offsets:
        # Without this check the container header itself would be decoded as
        # if it were a codestream, yielding meaningless dimensions.
        raise ValueError("No codestream box (jxlc or jxlp) found in container.")

    offsets.sort(key=lambda entry: entry[0])
    return decode_codestream(file, offset=offset or 0, offsets=offsets)
180+
181+
182+
def get_jxl_size(path: str) -> Tuple[int,int]:
    """
    Opens the JPEG XL file at ``path`` and returns the decoded size.

    A file whose first two bytes are FF 0A is treated as a bare codestream;
    anything else is handed to the container parser.
    """
    with open(path, "rb") as stream:
        leading_bytes = stream.read(2)
        is_bare_codestream = leading_bytes == bytes.fromhex("FF0A")
        decoder = decode_codestream if is_bare_codestream else decode_container
        return decoder(stream)

library/train_util.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,16 @@
113113
# JPEG-XL on Linux
114114
try:
115115
from jxlpy import JXLImagePlugin
116+
from library.jpeg_xl_util import get_jxl_size
116117

117118
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
118119
except:
119120
pass
120121

121-
# JPEG-XL on Windows
122+
# JPEG-XL on Linux and Windows
122123
try:
123124
import pillow_jxl
125+
from library.jpeg_xl_util import get_jxl_size
124126

125127
IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
126128
except:
@@ -1463,6 +1465,8 @@ def cache_text_encoder_outputs_common(
14631465
)
14641466

14651467
def get_image_size(self, image_path):
1468+
if image_path.endswith(".jxl") or image_path.endswith(".JXL"):
1469+
return get_jxl_size(image_path)
14661470
# return imagesize.get(image_path)
14671471
image_size = imagesize.get(image_path)
14681472
if image_size[0] <= 0:

0 commit comments

Comments
 (0)