import os
import psutil
import torch
import gc
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration, MusicgenMelodyConfig
import scipy
# https://huggingface.co/docs/transformers/main/model_doc/musicgen_melody
# Function to log memory usage
def log_memory(stage=""):
    """Print the resident set size (RSS) of the current process in MB.

    Args:
        stage: Label interpolated into the message right after the words
            "Memory Usage after".
    """
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    rss_mb = rss_bytes / 1024 ** 2  # bytes -> MB (1 MB = 1024 * 1024 bytes)
    print(f"Memory Usage after {stage}: {rss_mb} MB")
log_memory("initial load")
# Hugging Face access token. It is read from the HF_TOKEN environment variable
# so that no credential is stored in the source file. When the variable is
# unset, `token` is None and it is passed as such to the from_pretrained calls.
# NOTE(review): a literal token used to be committed on this line; treat that
# token as compromised and revoke it in the Hugging Face account settings.
token = os.environ.get("HF_TOKEN")
# Checkpoint whose configuration (and, later, weights) is loaded.
# The original author also considered "facebook/musicgen-small" as a lighter
# option. NOTE(review): that checkpoint may require the non-melody Musicgen
# classes -- confirm before switching.
model_name = "facebook/musicgen-melody"
config = MusicgenMelodyConfig.from_pretrained(model_name, token=token)

# Attribute overrides applied to the loaded config, in this order.
# NOTE(review): these are assigned on the top-level config object itself.
# Whether the model reads them from here (rather than from sub-configs such as
# config.audio_encoder) is not visible in this file -- confirm that the
# size-related values have the intended effect, given that pretrained weights
# are loaded with this config afterwards.
config_overrides = {
    # Added by the original author to work around an AttributeError.
    "use_cache": False,
    # Added in case the initializer factor is required.
    "initializer_factor": 1.0,
    # Debugging values chosen by the original author.
    "dropout": 0.1,
    "layerdrop": 0.1,
    "max_position_embeddings": 512,  # reduced
    "hidden_size": 128,  # smaller hidden size
    "num_codebooks": 128,
    "scale_embedding": True,
    "vocab_size": 50257,
    "num_hidden_layers": 2,  # fewer layers
    "num_attention_heads": 4,  # fewer attention heads
    "attention_dropout": 0.1,
    "activation_function": "gelu",
    "activation_dropout": 0.1,
    "ffn_dim": 1024,
}
for attr_name, attr_value in config_overrides.items():
    setattr(config, attr_name, attr_value)

log_memory("after config")
# Load the checkpoint with the customised config and switch the model to
# evaluation mode.
model = MusicgenMelodyForConditionalGeneration.from_pretrained(
    model_name, config=config, token=token
).eval()
log_memory("after model loaded")

# Load the processor for the same checkpoint. The same token is passed as for
# the config and model loads, so all three downloads authenticate consistently.
processor = AutoProcessor.from_pretrained(model_name, token=token)
# Text prompt describing the music to generate.
prompt = "A relaxing jazz track with piano and bass."

# Run the prompt through the processor (batch of one, padding enabled,
# PyTorch tensors), then move the result to the model's device.
# NOTE: despite its name, `input_ids` holds the whole processor output; the
# token ids themselves are the entry stored under the "input_ids" key.
encoded_prompt = processor(text=[prompt], padding=True, return_tensors="pt")
input_ids = encoded_prompt.to(model.device)

# Report the shape of the token-id tensor (no reshaping is done here; the
# message text is kept as-is).
token_id_shape = input_ids["input_ids"].shape
print(f"Input tensor shape after reshaping: {token_id_shape}")

# Generate with gradient tracking disabled to save memory; generation is
# limited to 1024 new tokens.
with torch.no_grad():
    generated_audio = model.generate(**input_ids, max_new_tokens=1024)
    print(generated_audio)
log_memory("after generation")
# Report what generate() returned before trying to save it.
print(f"Type of generated audio: {type(generated_audio)}")

# Save the generated audio to a WAV file when it is a tensor.
if isinstance(generated_audio, torch.Tensor):
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.io.wavfile` is loaded on every SciPy version.
    from scipy.io import wavfile

    sampling_rate = model.config.audio_encoder.sampling_rate
    # [0, 0] takes the first entry along the first two dimensions.
    # NOTE(review): presumably (batch, channel, samples) -- confirm.
    waveform = generated_audio.to("cpu")[0, 0].numpy()
    wavfile.write("generated_music.wav", rate=sampling_rate, data=waveform)
else:
    print("Unexpected audio format, unable to save.")

# Cleanup: drop the reference to the generated audio, run the garbage
# collector, and log memory once more.
del generated_audio
gc.collect()
log_memory("after cleanup")