|
@@ -219,6 +219,8 @@ def convert(
|
|
|
output_path: str,
|
|
output_path: str,
|
|
|
config_path: str,
|
|
config_path: str,
|
|
|
dtype: str = "fp32",
|
|
dtype: str = "fp32",
|
|
|
|
|
+ name: str = None,
|
|
|
|
|
+ description: str = None,
|
|
|
):
|
|
):
|
|
|
"""
|
|
"""
|
|
|
Convert PyTorch checkpoint to GGUF format.
|
|
Convert PyTorch checkpoint to GGUF format.
|
|
@@ -253,8 +255,10 @@ def convert(
|
|
|
print("Writing metadata...")
|
|
print("Writing metadata...")
|
|
|
|
|
|
|
|
# General metadata
|
|
# General metadata
|
|
|
- gguf_writer.add_name("Mel-Band-Roformer Vocal Separator")
|
|
|
|
|
- gguf_writer.add_description("Audio source separation model for vocal extraction")
|
|
|
|
|
|
|
+ model_name = name if name else "Mel-Band-Roformer Separator"
|
|
|
|
|
+ model_description = description if description else "Music source separation model"
|
|
|
|
|
+ gguf_writer.add_name(model_name)
|
|
|
|
|
+ gguf_writer.add_description(model_description)
|
|
|
|
|
|
|
|
# Determine types
|
|
# Determine types
|
|
|
target_qtype = get_target_quantization_type(dtype)
|
|
target_qtype = get_target_quantization_type(dtype)
|
|
@@ -464,6 +468,18 @@ Examples:
|
|
|
],
|
|
],
|
|
|
help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
|
|
help="Target quantization type. Norms/Biases will be kept as F32. (K-Quants not supported due to dim=384)",
|
|
|
)
|
|
)
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--name",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="Model name (default: 'Mel-Band-Roformer Vocal Separator')",
|
|
|
|
|
+ )
|
|
|
|
|
+ parser.add_argument(
|
|
|
|
|
+ "--description",
|
|
|
|
|
+ type=str,
|
|
|
|
|
+ default=None,
|
|
|
|
|
+ help="Model description (default: 'Audio source separation model for vocal extraction')",
|
|
|
|
|
+ )
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
- convert(args.ckpt, args.out, args.config, args.dtype)
|
|
|
|
|
|
|
+ convert(args.ckpt, args.out, args.config, args.dtype, args.name, args.description)
|