diff --git a/examples/run_music_generation.py b/examples/run_music_generation.py
index 84924e7..b8a7adc 100644
--- a/examples/run_music_generation.py
+++ b/examples/run_music_generation.py
@@ -1,6 +1,7 @@
 from heartlib import HeartMuLaGenPipeline
 import argparse
 import torch
+import time
 
 
 def str2bool(value):
@@ -32,12 +33,13 @@ def str2device(value):
 
 
 def parse_args():
+    ts = time.time_ns()
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--version", type=str, default="3B")
     parser.add_argument("--lyrics", type=str, default="./assets/lyrics.txt")
     parser.add_argument("--tags", type=str, default="./assets/tags.txt")
-    parser.add_argument("--save_path", type=str, default="./assets/output.mp3")
+    parser.add_argument("--save_path", type=str, default=f"./assets/output__SEED__{ts}.mp3")
     parser.add_argument("--max_audio_length_ms", type=int, default=240_000)
     parser.add_argument("--topk", type=int, default=50)
@@ -47,7 +49,10 @@ def parse_args():
     parser.add_argument("--codec_device", type=str2device, default="cuda")
     parser.add_argument("--mula_dtype", type=str2dtype, default="bfloat16")
     parser.add_argument("--codec_dtype", type=str2dtype, default="float32")
-    parser.add_argument("--lazy_load", type=str2bool, default=False)
+    parser.add_argument("--lazy_load", type=str2bool, default=True)
+    parser.add_argument("--quantize", type=str2bool, default=True)
+    parser.add_argument("--seed", type=int, default=None,
+                        help="Random seed for reproducibility (default: random)")
     return parser.parse_args()
@@ -67,7 +72,7 @@ def parse_args():
         lazy_load=args.lazy_load,
     )
     with torch.no_grad():
-        pipe(
+        result = pipe(
             {
                 "lyrics": args.lyrics,
                 "tags": args.tags,
@@ -77,5 +82,7 @@
             topk=args.topk,
             temperature=args.temperature,
             cfg_scale=args.cfg_scale,
+            seed=args.seed,
         )
     print(f"Generated music saved to {args.save_path}")
+    print(f"Seed used: {result['seed']}")
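Note on the new `--save_path` default: the doubled underscores are deliberate. `postprocess` later substitutes the literal `_SEED_` (single underscores on each side), so exactly one underscore survives on each side of the seed. A minimal sketch of the filename mechanics, with 12345 standing in for a real seed:

```python
import time

ts = time.time_ns()
default_path = f"./assets/output__SEED__{ts}.mp3"

# postprocess() replaces the literal "_SEED_" with the seed that was used,
# so "output__SEED__<ts>" collapses to "output_<seed>_<ts>":
print(default_path.replace("_SEED_", str(12345)))
# -> ./assets/output_12345_<ts>.mp3
```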
readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.12" license = {text = "Apache-2.0"} -authors = [ - {name = "HeartMuLa Team", email = "heartmula.ai@gmail.com"} -] dependencies = [ - "numpy==2.0.2", - "torch==2.4.1", - "torchaudio==2.4.1", - "torchtune==0.4.0", - "torchao==0.9.0", - "torchvision==0.19.1", + "numpy==2.2.6", + "torch==2.8.0+cu128", + "torchaudio==2.8.0+cu128", + "torchtune==0.6.1", + "torchao==0.15.0", + "torchvision==0.23.0+cu128", "tqdm==4.67.1", - "traitlets==5.7.1", - "traittypes==0.2.3", - "transformers==4.57.0", - "tokenizers==0.22.1", - "ipykernel==6.17.1", + "traitlets==5.14.3", + #"traittypes==0.2.3", + "transformers==4.56.0", + "tokenizers==0.22.0", + #"ipykernel==6.17.1", "einops==0.8.1", - "accelerate==1.12.0", - "bitsandbytes==0.49.0", - "vector-quantize-pytorch==1.27.15", - "modelscope==1.33.0", - "soundfile" + "accelerate==1.10.1", + "bitsandbytes==0.49.1", + "vector-quantize-pytorch==1.27.19", + #"modelscope==1.33.0", + "soundfile==0.13.1" ] -urls = { "homepage" = "https://heartmula.github.io/" } +urls = { "homepage" = "https://github.com/nalexand/HeartMula-OPTIMIZED-8GB" } classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent" diff --git a/src/heartlib/heartcodec/modeling_heartcodec.py b/src/heartlib/heartcodec/modeling_heartcodec.py index df79245..b0ca93e 100644 --- a/src/heartlib/heartcodec/modeling_heartcodec.py +++ b/src/heartlib/heartcodec/modeling_heartcodec.py @@ -141,6 +141,11 @@ def detokenize( ) latent_list.append(latents) + torch.cuda.synchronize() + del self.flow_matching + del codes + torch.cuda.empty_cache() + # latent_list = [l.float() for l in latent_list] latent_list[0] = latent_list[0][:, first_latent_length:, :] min_samples = int(duration * self.sample_rate) diff --git a/src/heartlib/heartmula/modeling_heartmula.py b/src/heartlib/heartmula/modeling_heartmula.py index 7079efa..b664e7c 100644 --- a/src/heartlib/heartmula/modeling_heartmula.py +++ b/src/heartlib/heartmula/modeling_heartmula.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +import torch.nn.functional as F from .configuration_heartmula import HeartMuLaConfig from transformers.modeling_utils import PreTrainedModel import torch @@ -8,6 +9,89 @@ from torchtune.models import llama3_2 +class FP8Linear(nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + + if not hasattr(torch, 'float8_e4m3fn'): + raise ImportError("PyTorch 2.1+ required for Float8 support.") + + factory_kwargs = {'device': device, 'dtype': dtype or torch.bfloat16} + init_weight = torch.empty((out_features, in_features), **factory_kwargs) + nn.init.kaiming_uniform_(init_weight, a=5 ** 0.5) + + self.weight = nn.Parameter(init_weight.to(torch.float8_e4m3fn), requires_grad=True) + + if bias: + self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(init_weight) + bound = 1 / (fan_in ** 0.5) if fan_in > 0 else 0 + nn.init.uniform_(self.bias, -bound, bound) + else: + self.register_parameter('bias', None) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + comp_dtype = input.dtype + weight_bf16 = self.weight.to(comp_dtype) + return F.linear(input, weight_bf16, self.bias) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias 
diff --git a/src/heartlib/heartmula/modeling_heartmula.py b/src/heartlib/heartmula/modeling_heartmula.py
index 7079efa..b664e7c 100644
--- a/src/heartlib/heartmula/modeling_heartmula.py
+++ b/src/heartlib/heartmula/modeling_heartmula.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from .configuration_heartmula import HeartMuLaConfig
 from transformers.modeling_utils import PreTrainedModel
 import torch
@@ -8,6 +9,89 @@ from torchtune.models import llama3_2
 
 
+class FP8Linear(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = True,
+        device=None,
+        dtype=None
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+        if not hasattr(torch, 'float8_e4m3fn'):
+            raise ImportError("PyTorch 2.1+ required for Float8 support.")
+
+        factory_kwargs = {'device': device, 'dtype': dtype or torch.bfloat16}
+        init_weight = torch.empty((out_features, in_features), **factory_kwargs)
+        nn.init.kaiming_uniform_(init_weight, a=5 ** 0.5)
+
+        # weights are stored in fp8 and upcast on the fly in forward()
+        self.weight = nn.Parameter(init_weight.to(torch.float8_e4m3fn), requires_grad=True)
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(init_weight)
+            bound = 1 / (fan_in ** 0.5) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+        else:
+            self.register_parameter('bias', None)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        comp_dtype = input.dtype
+        weight_bf16 = self.weight.to(comp_dtype)  # upcast fp8 weights to the activation dtype
+        return F.linear(input, weight_bf16, self.bias)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}'.format(
+            self.in_features, self.out_features, self.bias is not None
+        )
+
+    def reset_parameters(self):
+        # Helper to re-init if needed.
+        # We must init in high precision, then cast down.
+        temp_weight = torch.empty((self.out_features, self.in_features), dtype=torch.bfloat16)
+        nn.init.kaiming_uniform_(temp_weight, a=5 ** 0.5)
+        with torch.no_grad():
+            self.weight.copy_(temp_weight.to(torch.float8_e4m3fn))
+
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(temp_weight)
+            bound = 1 / (fan_in ** 0.5) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+
+    @property
+    def shape(self):
+        """Allows accessing .shape like a tensor."""
+        return self.weight.shape
+
+
+def replace_linear_with_fp8(model):
+    """
+    Recursively replaces nn.Linear with FP8Linear.
+    Transfers existing weights (if any) to the new layers.
+    """
+    for name, module in model.named_children():
+        if isinstance(module, nn.Linear):
+            new_layer = FP8Linear(
+                in_features=module.in_features,
+                out_features=module.out_features,
+                bias=(module.bias is not None),
+                device=module.weight.device
+            )
+            with torch.no_grad():
+                new_layer.weight.data.copy_(module.weight.data.to(torch.float8_e4m3fn))
+                if module.bias is not None:
+                    new_layer.bias.data.copy_(module.bias.data)
+
+            setattr(model, name, new_layer)
+        else:
+            replace_linear_with_fp8(module)
+
+
 def llama3_2_3B() -> torchtune.modules.transformer.TransformerDecoder:
     return llama3_2.llama3_2(
         vocab_size=128_256,
@@ -84,6 +168,9 @@ def _prepare_transformer(model):
     embed_dim = model.tok_embeddings.embedding_dim
     model.tok_embeddings = nn.Identity()
     model.output = nn.Identity()
+
+    replace_linear_with_fp8(model)  # quantize every nn.Linear to fp8 weight storage
+
     return model, embed_dim
@@ -178,6 +265,7 @@ def setup_caches(self, max_batch_size: int):
             _create_causal_mask(self.config.audio_num_codebooks, device),
         )
 
+    @torch.inference_mode()
     def generate_frame(
         self,
         tokens: torch.Tensor,
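Since `FP8Linear` stores weights as `float8_e4m3fn` (one byte per element) and upcasts to the activation dtype inside `forward`, weight memory is roughly halved relative to bfloat16. A back-of-the-envelope check, assuming `torch.dtype.itemsize` (available in the PyTorch 2.1+ this code already requires) and a made-up layer shape:

```python
import torch

# Hypothetical layer shape, chosen only for the arithmetic.
out_features, in_features = 3072, 8192
n = out_features * in_features

bf16_bytes = n * torch.bfloat16.itemsize        # 2 bytes per weight
fp8_bytes = n * torch.float8_e4m3fn.itemsize    # 1 byte per weight
print(bf16_bytes / 2**20, fp8_bytes / 2**20)    # 48.0 MiB vs 24.0 MiB
```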
diff --git a/src/heartlib/pipelines/music_generation.py b/src/heartlib/pipelines/music_generation.py
index c9111ff..cd47981 100644
--- a/src/heartlib/pipelines/music_generation.py
+++ b/src/heartlib/pipelines/music_generation.py
@@ -187,6 +187,7 @@ def _sanitize_parameters(self, **kwargs):
             "temperature": kwargs.get("temperature", 1.0),
             "topk": kwargs.get("topk", 50),
             "cfg_scale": kwargs.get("cfg_scale", 1.5),
+            "seed": kwargs.get("seed", None),
         }
         postprocess_kwargs = {
             "save_path": kwargs.get("save_path", "output.mp3"),
@@ -264,6 +265,7 @@ def _cfg_cat(tensor: torch.Tensor, cfg_scale: float):
             "pos": _cfg_cat(torch.arange(prompt_len, dtype=torch.long), cfg_scale),
         }
 
+    @torch.inference_mode()
     def _forward(
         self,
         model_inputs: Dict[str, Any],
@@ -271,7 +273,14 @@ def _forward(
         temperature: float,
         topk: int,
         cfg_scale: float,
+        seed: Optional[int] = None,
     ):
+        if seed is None:
+            seed = torch.randint(0, 2 ** 32, (1,)).item()
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
+
         prompt_tokens = model_inputs["tokens"].to(self.mula_device)
         prompt_tokens_mask = model_inputs["tokens_mask"].to(self.mula_device)
         continuous_segment = model_inputs["muq_embed"].to(self.mula_device)
@@ -333,13 +342,16 @@ def _pad_audio_token(token: torch.Tensor):
             frames.append(curr_token[0:1,])
         frames = torch.stack(frames).permute(1, 2, 0).squeeze(0)
         self._unload()
-        return {"frames": frames}
+        return {"frames": frames, "seed": seed}
 
     def postprocess(self, model_outputs: Dict[str, Any], save_path: str):
+        seed = model_outputs.get("seed")
         frames = model_outputs["frames"].to(self.codec_device)
         wav = self.codec.detokenize(frames)
         self._unload()
+        save_path = save_path.replace("_SEED_", str(seed))
         torchaudio.save(save_path, wav.to(torch.float32).cpu(), 48000)
+        return {"save_path": save_path, "seed": seed}
 
     def __call__(self, inputs: Dict[str, Any], **kwargs):
         preprocess_kwargs, forward_kwargs, postprocess_kwargs = (
@@ -347,7 +359,7 @@ def __call__(self, inputs: Dict[str, Any], **kwargs):
         )
         model_inputs = self.preprocess(inputs, **preprocess_kwargs)
         model_outputs = self._forward(model_inputs, **forward_kwargs)
-        self.postprocess(model_outputs, **postprocess_kwargs)
+        return self.postprocess(model_outputs, **postprocess_kwargs)
 
     @classmethod
     def from_pretrained(
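Because `__call__` now returns the `postprocess` result, callers can recover both the resolved output path and the seed, and replay a run. A sketch assuming `pipe` was already built via `HeartMuLaGenPipeline.from_pretrained(...)` as in the example script (paths and inputs are placeholders):

```python
# `pipe` is an already-constructed HeartMuLaGenPipeline; inputs mirror
# examples/run_music_generation.py.
result = pipe(
    {"lyrics": "./assets/lyrics.txt", "tags": "./assets/tags.txt"},
    save_path="./assets/take__SEED__demo.mp3",
)
print(result["save_path"])  # "_SEED_" resolved to the actual seed
print(result["seed"])

# Re-running with the returned seed should reproduce the same token frames.
rerun = pipe(
    {"lyrics": "./assets/lyrics.txt", "tags": "./assets/tags.txt"},
    seed=result["seed"],
)
```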