// xtts_inference.cpp - XTTS GGUF Inference Engine Implementation
#include "xtts_inference.h"
#include <ggml.h>
#include <ggml-alloc.h>
#include <ggml-backend.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <memory>
#include <random>
#include <vector>
#include <sys/mman.h>
#include <fcntl.h>
#include <unistd.h>
namespace xtts {
// Constructor
XTTSInference::XTTSInference() {
// Initialize GGML backend
ggml_backend_load_all();
}
// Destructor
XTTSInference::~XTTSInference() {
// Clean up model resources (graph allocator and buffers before the backend that owns them)
if (allocr) {
ggml_gallocr_free(allocr);
}
if (model.buffer) {
ggml_backend_buffer_free(model.buffer);
}
if (model.ctx) {
ggml_free(model.ctx);
}
if (model.backend) {
ggml_backend_free(model.backend);
}
// Unmap memory if using mmap
if (mapped_memory) {
munmap(mapped_memory, mapped_size);
}
}
XTTSModel::~XTTSModel() {
// Cleanup handled by parent XTTSInference
}
// Load model from GGUF file
bool XTTSInference::load_model(const std::string& model_path, bool use_mmap) {
std::cout << "Loading XTTS model from: " << model_path << std::endl;
if (!load_gguf_file(model_path, use_mmap)) {
return false;
}
// Create computation graph structure
create_computation_graph();
std::cout << "Model loaded successfully" << std::endl;
std::cout << " Vocab size: " << hparams.n_vocab << std::endl;
std::cout << " Embedding dim: " << hparams.n_embd << std::endl;
std::cout << " Layers: " << hparams.n_layer << std::endl;
std::cout << " Languages: " << hparams.n_languages << std::endl;
return true;
}
// Load GGUF file
bool XTTSInference::load_gguf_file(const std::string& path, bool use_mmap) {
// Read GGUF header
std::ifstream file(path, std::ios::binary);
if (!file) {
std::cerr << "Failed to open file: " << path << std::endl;
return false;
}
// Read magic and version
uint32_t magic, version;
file.read(reinterpret_cast<char*>(&magic), sizeof(magic));
file.read(reinterpret_cast<char*>(&version), sizeof(version));
if (magic != 0x46554747) { // "GGUF"
std::cerr << "Invalid GGUF magic number" << std::endl;
return false;
}
// Read metadata
uint64_t metadata_size;
file.read(reinterpret_cast<char*>(&metadata_size), sizeof(metadata_size));
std::vector<char> metadata_json(metadata_size);
file.read(metadata_json.data(), metadata_size);
// Parse metadata (simplified: the blob is read but not parsed, and default
// hyperparameters are used; the official GGUF format stores metadata as typed
// key-value pairs, which a full loader would decode here)
// Read tensor count
uint64_t n_tensors;
file.read(reinterpret_cast<char*>(&n_tensors), sizeof(n_tensors));
// Initialize GGML context
size_t ctx_size = ggml_tensor_overhead() * n_tensors + (1 << 20); // 1MB extra
struct ggml_init_params params = {
.mem_size = ctx_size,
.mem_buffer = nullptr,
.no_alloc = true,
};
model.ctx = ggml_init(params);
if (!model.ctx) {
std::cerr << "Failed to initialize GGML context" << std::endl;
return false;
}
// Initialize backend (CPU here; a GPU backend such as CUDA could be initialized instead when available)
model.backend = ggml_backend_cpu_init();
if (!model.backend) {
std::cerr << "Failed to initialize backend" << std::endl;
return false;
}
// Memory map the file if requested
if (use_mmap) {
int fd = open(path.c_str(), O_RDONLY);
if (fd < 0) {
std::cerr << "Failed to open file for mmap" << std::endl;
return false;
}
// Get file size
off_t file_size = lseek(fd, 0, SEEK_END);
lseek(fd, 0, SEEK_SET);
// Memory map the file
mapped_memory = mmap(nullptr, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
mapped_size = file_size;
close(fd);
if (mapped_memory == MAP_FAILED) {
std::cerr << "Failed to mmap file" << std::endl;
mapped_memory = nullptr;
return false;
}
std::cout << "Memory-mapped " << (file_size / (1024*1024)) << " MB" << std::endl;
}
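// NOTE: the mapped region is not consumed by the loader below; a zero-copy loader
// would point tensor data at offsets inside the mapping instead of skipping the file data.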
// Read and create tensors
for (size_t i = 0; i < n_tensors; ++i) {
// Read tensor name
uint32_t name_len;
file.read(reinterpret_cast<char*>(&name_len), sizeof(name_len));
std::string name(name_len, '\0');
file.read(&name[0], name_len);
// Read shape
uint32_t n_dims;
file.read(reinterpret_cast<char*>(&n_dims), sizeof(n_dims));
std::vector<int64_t> shape(n_dims);
for (uint32_t j = 0; j < n_dims; ++j) {
uint32_t dim;
file.read(reinterpret_cast<char*>(&dim), sizeof(dim));
shape[j] = dim;
}
// Read quantization type
uint32_t quant_type;
file.read(reinterpret_cast<char*>(&quant_type), sizeof(quant_type));
// Read data size
uint64_t data_size;
file.read(reinterpret_cast<char*>(&data_size), sizeof(data_size));
// Map GGML type
enum ggml_type type = GGML_TYPE_F32;
switch (quant_type) {
case 0: type = GGML_TYPE_F32; break;
case 1: type = GGML_TYPE_F16; break;
case 8: type = GGML_TYPE_Q8_0; break;
case 12: type = GGML_TYPE_Q4_K; break;
default: type = GGML_TYPE_F32; break;
}
// Create tensor
struct ggml_tensor* tensor = nullptr;
if (n_dims == 1) {
tensor = ggml_new_tensor_1d(model.ctx, type, shape[0]);
} else if (n_dims == 2) {
tensor = ggml_new_tensor_2d(model.ctx, type, shape[0], shape[1]);
} else if (n_dims == 3) {
tensor = ggml_new_tensor_3d(model.ctx, type, shape[0], shape[1], shape[2]);
} else if (n_dims == 4) {
tensor = ggml_new_tensor_4d(model.ctx, type, shape[0], shape[1], shape[2], shape[3]);
}
if (!tensor) {
std::cerr << "Failed to create tensor: " << name << std::endl;
file.seekg(data_size, std::ios::cur); // Skip data
continue;
}
// Set tensor name
ggml_set_name(tensor, name.c_str());
// Store tensor in model based on name
if (name.find("text_embedding") != std::string::npos) {
model.text_embedding = tensor;
} else if (name.find("language_embedding") != std::string::npos) {
model.language_embedding = tensor;
} else if (name.find("pos_encoding") != std::string::npos) {
model.pos_encoding = tensor;
} else if (name.find("audio_token_predictor") != std::string::npos) {
model.audio_token_predictor = tensor;
} else if (name.find("speaker_projection") != std::string::npos) {
model.speaker_projection = tensor;
} else if (name.find("vocoder_preconv") != std::string::npos) {
model.vocoder_preconv = tensor;
} else if (name.find("vocoder_postconv") != std::string::npos) {
model.vocoder_postconv = tensor;
}
// Add more tensor assignments as needed...
// Skip data for now (would load into tensor in real implementation)
file.seekg(data_size, std::ios::cur);
}
file.close();
// Allocate backend memory for every tensor created in the context
model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, model.backend);
if (!model.buffer) {
std::cerr << "Failed to allocate backend buffer for model tensors" << std::endl;
return false;
}
return true;
}
// Create computation graph
void XTTSInference::create_computation_graph() {
// Initialize graph allocator
allocr = ggml_gallocr_new_from_backend(model.backend);
// Initialize KV cache
kv_cache.k_cache = ggml_new_tensor_3d(
model.ctx,
GGML_TYPE_F32,
hparams.n_embd,
hparams.n_ctx_text + hparams.n_ctx_audio,
hparams.n_layer
);
kv_cache.v_cache = ggml_new_tensor_3d(
model.ctx,
GGML_TYPE_F32,
hparams.n_embd,
hparams.n_ctx_text + hparams.n_ctx_audio,
hparams.n_layer
);
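// NOTE: with a no_alloc context these cache tensors still need backing memory
// (e.g. allocated alongside the model tensors) before they can be read or written.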
}
// Tokenize text (simplified byte-level tokenization)
std::vector<int32_t> XTTSInference::tokenize(const std::string& text) {
std::vector<int32_t> tokens;
tokens.reserve(text.length());
for (char c : text) {
// Simple byte-level tokenization
tokens.push_back(static_cast<unsigned char>(c));
}
// Pad or truncate to max length
if (tokens.size() > hparams.n_ctx_text) {
tokens.resize(hparams.n_ctx_text);
} else {
while (tokens.size() < hparams.n_ctx_text) {
tokens.push_back(0); // Padding token
}
}
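// NOTE: padding with 0 reuses the byte value 0x00; a real tokenizer would reserve a dedicated pad id.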
return tokens;
}
// Create speaker embedding
std::vector<float> XTTSInference::create_speaker_embedding(int speaker_id) {
std::vector<float> embedding(hparams.speaker_emb_dim, 0.0f);
// Simple one-hot style encoding for demo
if (speaker_id >= 0 && speaker_id < hparams.speaker_emb_dim) {
embedding[speaker_id] = 1.0f;
}
// Add some random variation
std::mt19937 gen(speaker_id);
std::normal_distribution<float> dist(0.0f, 0.1f);
for (float& val : embedding) {
val += dist(gen);
}
return embedding;
}
// Encode text to features
struct ggml_tensor* XTTSInference::encode_text(
const std::vector<int32_t>& tokens,
Language language,
const std::vector<float>& speaker_embedding
) {
struct ggml_cgraph* gf = ggml_new_graph(model.ctx);
// Create input tensors
struct ggml_tensor* token_tensor = ggml_new_tensor_1d(
model.ctx, GGML_TYPE_I32, tokens.size()
);
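// NOTE: copying directly into tensor->data assumes host-visible memory (CPU backend);
// with a no_alloc context or a GPU backend, inputs should be uploaded with
// ggml_backend_tensor_set after graph allocation.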
memcpy(token_tensor->data, tokens.data(), tokens.size() * sizeof(int32_t));
// Get text embeddings
struct ggml_tensor* text_emb = ggml_get_rows(
model.ctx, model.text_embedding, token_tensor
);
// Add language embedding
struct ggml_tensor* lang_tensor = ggml_new_tensor_1d(
model.ctx, GGML_TYPE_I32, tokens.size()
);
for (size_t i = 0; i < tokens.size(); ++i) {
((int32_t*)lang_tensor->data)[i] = static_cast<int32_t>(language);
}
struct ggml_tensor* lang_emb = ggml_get_rows(
model.ctx, model.language_embedding, lang_tensor
);
// Combine embeddings
struct ggml_tensor* combined = ggml_add(model.ctx, text_emb, lang_emb);
// Add positional encoding
if (model.pos_encoding) {
struct ggml_tensor* pos = ggml_view_2d(
model.ctx, model.pos_encoding,
hparams.n_embd, tokens.size(),
hparams.n_embd * sizeof(float), 0
);
combined = ggml_add(model.ctx, combined, pos);
}
// Add speaker embedding if provided
if (!speaker_embedding.empty() && model.speaker_projection) {
struct ggml_tensor* spk_tensor = ggml_new_tensor_1d(
model.ctx, GGML_TYPE_F32, speaker_embedding.size()
);
memcpy(spk_tensor->data, speaker_embedding.data(),
speaker_embedding.size() * sizeof(float));
struct ggml_tensor* spk_proj = ggml_mul_mat(
model.ctx, model.speaker_projection, spk_tensor
);
// Broadcast and add to all positions
struct ggml_tensor* spk_expanded = ggml_repeat(
model.ctx, spk_proj,
ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, hparams.n_embd, tokens.size())
);
combined = ggml_add(model.ctx, combined, ggml_scale(model.ctx, spk_expanded, 0.1f));
}
// Process through transformer layers
struct ggml_tensor* hidden = combined;
for (int layer = 0; layer < hparams.n_layer; ++layer) {
// Self-attention
hidden = attention(hidden, layer, true);
// Feed-forward network
hidden = ffn(hidden, layer);
}
// Build and execute graph
ggml_build_forward_expand(gf, hidden);
ggml_gallocr_alloc_graph(allocr, gf);
// Run computation
ggml_backend_graph_compute(model.backend, gf);
return hidden;
}
// Attention mechanism
struct ggml_tensor* XTTSInference::attention(
struct ggml_tensor* x,
int layer_idx,
bool use_cache
) {
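// NOTE: use_cache and the kv_cache tensors are not wired into this path yet;
// each call recomputes attention over the full sequence.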
// Layer normalization
struct ggml_tensor* normalized = layer_norm(
x,
layer_idx < model.ln1_weight.size() ? model.ln1_weight[layer_idx] : nullptr,
layer_idx < model.ln1_bias.size() ? model.ln1_bias[layer_idx] : nullptr
);
// QKV projection
struct ggml_tensor* qkv = nullptr;
if (layer_idx < model.attn_qkv.size() && model.attn_qkv[layer_idx]) {
qkv = ggml_mul_mat(model.ctx, model.attn_qkv[layer_idx], normalized);
} else {
// Fallback if weights not loaded
qkv = normalized;
}
// Split into Q, K, V
int head_dim = hparams.n_embd / hparams.n_head;
struct ggml_tensor* q = ggml_view_3d(
model.ctx, qkv,
head_dim, hparams.n_head, x->ne[1],
head_dim * sizeof(float),
hparams.n_embd * sizeof(float),
0
);
struct ggml_tensor* k = ggml_view_3d(
model.ctx, qkv,
head_dim, hparams.n_head, x->ne[1],
head_dim * sizeof(float),
hparams.n_embd * sizeof(float),
hparams.n_embd * x->ne[1] * sizeof(float)
);
struct ggml_tensor* v = ggml_view_3d(
model.ctx, qkv,
head_dim, hparams.n_head, x->ne[1],
head_dim * sizeof(float),
hparams.n_embd * sizeof(float),
2 * hparams.n_embd * x->ne[1] * sizeof(float)
);
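// NOTE: these views assume Q, K and V are laid out as three contiguous
// [n_embd, n_tokens] blocks inside qkv; a fused per-row QKV projection would
// instead use offsets of n_embd * sizeof(float) and a row stride of
// 3 * n_embd * sizeof(float).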
// Scaled dot-product attention
float scale = 1.0f / sqrtf(static_cast<float>(head_dim));
struct ggml_tensor* scores = ggml_mul_mat(model.ctx, k, q);
scores = ggml_scale(model.ctx, scores, scale);
// Apply causal mask
scores = ggml_diag_mask_inf(model.ctx, scores, 0);
// Softmax
struct ggml_tensor* attn_weights = ggml_soft_max(model.ctx, scores);
// Apply attention to values
struct ggml_tensor* attn_output = ggml_mul_mat(model.ctx, v, attn_weights);
// Reshape and project output
attn_output = ggml_cont(model.ctx, ggml_permute(
model.ctx, attn_output, 0, 2, 1, 3
));
attn_output = ggml_reshape_2d(
model.ctx, attn_output,
hparams.n_embd, x->ne[1]
);
if (layer_idx < model.attn_out.size() && model.attn_out[layer_idx]) {
attn_output = ggml_mul_mat(model.ctx, model.attn_out[layer_idx], attn_output);
}
// Residual connection
return ggml_add(model.ctx, x, attn_output);
}
// Feed-forward network
struct ggml_tensor* XTTSInference::ffn(
struct ggml_tensor* x,
int layer_idx
) {
// Layer normalization
struct ggml_tensor* normalized = layer_norm(
x,
layer_idx < model.ln2_weight.size() ? model.ln2_weight[layer_idx] : nullptr,
layer_idx < model.ln2_bias.size() ? model.ln2_bias[layer_idx] : nullptr
);
// FFN up projection
struct ggml_tensor* up = normalized;
if (layer_idx < model.ffn_up.size() && model.ffn_up[layer_idx]) {
up = ggml_mul_mat(model.ctx, model.ffn_up[layer_idx], normalized);
}
// Activation (GELU)
up = ggml_gelu(model.ctx, up);
// FFN down projection
if (layer_idx < model.ffn_down.size() && model.ffn_down[layer_idx]) {
up = ggml_mul_mat(model.ctx, model.ffn_down[layer_idx], up);
}
// Residual connection
return ggml_add(model.ctx, x, up);
}
// Layer normalization
struct ggml_tensor* XTTSInference::layer_norm(
struct ggml_tensor* x,
struct ggml_tensor* weight,
struct ggml_tensor* bias,
float eps
) {
struct ggml_tensor* normalized = ggml_norm(model.ctx, x, eps);
if (weight) {
normalized = ggml_mul(model.ctx, normalized, weight);
}
if (bias) {
normalized = ggml_add(model.ctx, normalized, bias);
}
return normalized;
}
// Generate audio tokens autoregressively
std::vector<int32_t> XTTSInference::generate_audio_tokens(
struct ggml_tensor* text_features,
float temperature
) {
std::vector<int32_t> audio_tokens;
audio_tokens.reserve(hparams.n_ctx_audio);
// Start with special start token
audio_tokens.push_back(0);
// Generate tokens autoregressively
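// NOTE: placeholder loop; generated tokens are not fed back through the transformer,
// so decoding is not yet truly autoregressive.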
for (int i = 0; i < hparams.n_ctx_audio; ++i) {
// Get logits for next token
struct ggml_tensor* logits = nullptr;
if (model.audio_token_predictor) {
// Use the last hidden state
struct ggml_tensor* last_hidden = ggml_view_1d(
model.ctx, text_features,
hparams.n_embd,
(text_features->ne[1] - 1) * hparams.n_embd * sizeof(float)
);
logits = ggml_mul_mat(model.ctx, model.audio_token_predictor, last_hidden);
} else {
// Fallback: random generation
logits = ggml_new_tensor_1d(model.ctx, GGML_TYPE_F32, hparams.n_audio_tokens);
for (int j = 0; j < hparams.n_audio_tokens; ++j) {
((float*)logits->data)[j] = static_cast<float>(rand()) / RAND_MAX;
}
}
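// NOTE: logits produced by ggml_mul_mat above are only graph nodes at this point;
// a complete implementation would build and compute a graph before sampling from logits->data.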
// Sample next token
int32_t next_token = sample_token(logits, temperature);
audio_tokens.push_back(next_token);
// Check for end token
if (next_token == 1) { // Assuming 1 is end token
break;
}
}
return audio_tokens;
}
// Sample token from logits
int32_t XTTSInference::sample_token(
struct ggml_tensor* logits,
float temperature,
float top_p
) {
int n_vocab = logits->ne[0];
std::vector<float> probs(n_vocab);
// Greedy decode when temperature is non-positive (avoids division by zero)
if (temperature <= 0.0f) {
const float* data = (const float*)logits->data;
return static_cast<int32_t>(std::max_element(data, data + n_vocab) - data);
}
// Apply temperature
for (int i = 0; i < n_vocab; ++i) {
probs[i] = ((float*)logits->data)[i] / temperature;
}
// Softmax
float max_logit = *std::max_element(probs.begin(), probs.end());
float sum = 0.0f;
for (float& p : probs) {
p = expf(p - max_logit);
sum += p;
}
for (float& p : probs) {
p /= sum;
}
// Top-p sampling
std::vector<std::pair<float, int>> prob_indices;
for (int i = 0; i < n_vocab; ++i) {
prob_indices.push_back({probs[i], i});
}
std::sort(prob_indices.begin(), prob_indices.end(), std::greater<>());
float cum_prob = 0.0f;
size_t cutoff = prob_indices.size(); // fall back to the full distribution if top_p is never reached
for (size_t i = 0; i < prob_indices.size(); ++i) {
cum_prob += prob_indices[i].first;
if (cum_prob >= top_p) {
cutoff = i + 1;
break;
}
}
// Renormalize
float norm_sum = 0.0f;
for (size_t i = 0; i < cutoff; ++i) {
norm_sum += prob_indices[i].first;
}
// Sample
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dist(0.0f, norm_sum);
float sample = dist(gen);
cum_prob = 0.0f;
for (size_t i = 0; i < cutoff; ++i) {
cum_prob += prob_indices[i].first;
if (cum_prob >= sample) {
return prob_indices[i].second;
}
}
return prob_indices[0].second;
}
// Vocoder forward pass
std::vector<float> XTTSInference::vocoder_forward(
const std::vector<int32_t>& audio_tokens
) {
// Convert tokens to mel spectrogram (simplified)
// In practice, would use learned codebook
size_t mel_frames = audio_tokens.size() / 2;
struct ggml_tensor* mel = ggml_new_tensor_3d(
model.ctx, GGML_TYPE_F32,
hparams.n_mel_channels, mel_frames, 1
);
// Fill with dummy mel values (would be from codebook in real implementation)
for (size_t i = 0; i < mel_frames; ++i) {
for (int j = 0; j < hparams.n_mel_channels; ++j) {
float value = (audio_tokens[i * 2] + audio_tokens[i * 2 + 1] * 256) / 65536.0f;
((float*)mel->data)[i * hparams.n_mel_channels + j] = value;
}
}
// Apply vocoder
struct ggml_tensor* audio = mel;
// Initial convolution
if (model.vocoder_preconv) {
audio = ggml_conv_1d(model.ctx, model.vocoder_preconv, audio, 1, 1, 1);
}
// Upsampling blocks
for (auto& layer : model.vocoder_ups) {
if (layer) {
audio = ggml_conv_transpose_1d(model.ctx, layer, audio, 2, 0, 1);
audio = ggml_leaky_relu(model.ctx, audio, 0.1f, true);
}
}
// Final convolution
if (model.vocoder_postconv) {
audio = ggml_conv_1d(model.ctx, model.vocoder_postconv, audio, 1, 1, 1);
audio = ggml_tanh(model.ctx, audio);
}
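// NOTE: as in encode_text, these ops only describe a graph; a complete implementation
// would build a cgraph, allocate it, and call ggml_backend_graph_compute before
// reading audio->data.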
// Extract audio samples
size_t n_samples = audio->ne[0] * audio->ne[1];
std::vector<float> samples(n_samples);
memcpy(samples.data(), audio->data, n_samples * sizeof(float));
return samples;
}
// Main generation function
std::vector<float> XTTSInference::generate(
const std::string& text,
Language language,
int speaker_id,
float temperature,
float speed
) {
// Tokenize text
std::vector<int32_t> tokens = tokenize(text);
// Create speaker embedding
std::vector<float> speaker_embedding = create_speaker_embedding(speaker_id);
// Encode text to features
struct ggml_tensor* text_features = encode_text(
tokens, language, speaker_embedding
);
// Generate audio tokens
std::vector<int32_t> audio_tokens = generate_audio_tokens(
text_features, temperature
);
// Convert to audio waveform
std::vector<float> audio = vocoder_forward(audio_tokens);
// Apply speed adjustment
if (speed != 1.0f && speed > 0.0f) {
// Simple resampling for speed adjustment
size_t new_size = static_cast<size_t>(audio.size() / speed);
std::vector<float> resampled(new_size);
for (size_t i = 0; i < new_size; ++i) {
float src_idx = i * speed;
size_t idx0 = static_cast<size_t>(src_idx);
size_t idx1 = std::min(idx0 + 1, audio.size() - 1);
float frac = src_idx - idx0;
resampled[i] = audio[idx0] * (1.0f - frac) + audio[idx1] * frac;
}
audio = std::move(resampled);
}
return audio;
}
// Stream generator implementation
XTTSInference::StreamGenerator::StreamGenerator(
XTTSInference* parent,
const std::string& text,
Language lang
) : parent_model(parent), language(lang), done(false) {
// Tokenize text
text_tokens = parent_model->tokenize(text);
}
XTTSInference::StreamGenerator::~StreamGenerator() {
// Cleanup
}
void XTTSInference::StreamGenerator::generate_next_tokens(size_t n_tokens) {
// Generate next batch of audio tokens
// This would be implemented with proper streaming logic
for (size_t i = 0; i < n_tokens && audio_tokens.size() < parent_model->hparams.n_ctx_audio; ++i) {
audio_tokens.push_back(rand() % parent_model->hparams.n_audio_tokens);
}
}
std::vector<float> XTTSInference::StreamGenerator::get_next_chunk(size_t chunk_samples) {
if (done) {
return {};
}
// Generate more tokens if needed
if (current_token >= audio_tokens.size()) {
generate_next_tokens(50); // Generate 50 tokens at a time
}
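// NOTE: chunk_samples is currently ignored; chunks are produced in fixed 50-token
// batches regardless of the requested sample count.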
// Convert tokens to audio
size_t tokens_for_chunk = std::min(
static_cast<size_t>(50),
audio_tokens.size() - current_token
);
if (tokens_for_chunk == 0) {
done = true;
return {};
}
std::vector<int32_t> chunk_tokens(
audio_tokens.begin() + current_token,
audio_tokens.begin() + current_token + tokens_for_chunk
);
current_token += tokens_for_chunk;
// Use vocoder to convert to audio
std::vector<float> audio_chunk = parent_model->vocoder_forward(chunk_tokens);
// Check if we're done
if (current_token >= parent_model->hparams.n_ctx_audio ||
current_token >= audio_tokens.size()) {
done = true;
}
return audio_chunk;
}
std::unique_ptr<XTTSInference::StreamGenerator> XTTSInference::create_stream(
const std::string& text,
Language language
) {
return std::make_unique<StreamGenerator>(this, text, language);
}
size_t XTTSInference::get_memory_usage() const {
size_t total = 0;
// Add context memory
if (model.ctx) {
total += ggml_used_mem(model.ctx);
}
// Add KV cache memory
if (kv_cache.k_cache) {
total += ggml_nbytes(kv_cache.k_cache);
}
if (kv_cache.v_cache) {
total += ggml_nbytes(kv_cache.v_cache);
}
// Add mapped memory (though it's not in RAM if properly mmap'd)
if (mapped_memory) {
// Only count as overhead, actual memory is demand-paged
total += sizeof(*this) + (1 << 20); // 1MB overhead estimate
}
return total;
}
// C API implementation
extern "C" {
void* xtts_init(const char* model_path, bool use_mmap) {
auto* model = new XTTSInference();
if (!model->load_model(model_path, use_mmap)) {
delete model;
return nullptr;
}
return model;
}
float* xtts_generate(
void* model_ptr,
const char* text,
int language,
int speaker_id,
float temperature,
float speed,
size_t* out_length
) {
if (!model_ptr || !text || !out_length) {
return nullptr;
}
auto* model = static_cast<XTTSInference*>(model_ptr);
auto audio = model->generate(
text,
static_cast<Language>(language),
speaker_id,
temperature,
speed
);
*out_length = audio.size();
float* result = new float[audio.size()];
memcpy(result, audio.data(), audio.size() * sizeof(float));
return result;
}
void* xtts_stream_init(
void* model_ptr,
const char* text,
int language
) {
if (!model_ptr || !text) {
return nullptr;
}
auto* model = static_cast<XTTSInference*>(model_ptr);
auto stream = model->create_stream(text, static_cast<Language>(language));
return stream.release();
}
float* xtts_stream_next(
void* stream_ptr,
size_t chunk_size,
size_t* out_length
) {
if (!stream_ptr || !out_length) {
return nullptr;
}
auto* stream = static_cast<XTTSInference::StreamGenerator*>(stream_ptr);
auto chunk = stream->get_next_chunk(chunk_size);
if (chunk.empty()) {
*out_length = 0;
return nullptr;
}
*out_length = chunk.size();
float* result = new float[chunk.size()];
memcpy(result, chunk.data(), chunk.size() * sizeof(float));
return result;
}
void xtts_stream_free(void* stream_ptr) {
if (stream_ptr) {
delete static_cast<XTTSInference::StreamGenerator*>(stream_ptr);
}
}
void xtts_free(void* model_ptr) {
if (model_ptr) {
delete static_cast<XTTSInference*>(model_ptr);
}
}
void xtts_free_audio(float* audio_ptr) {
delete[] audio_ptr;
}
} // extern "C"
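// Example usage of the C API (illustrative sketch; the model path and parameter
// values are placeholders, and error handling is omitted):
//
//   size_t n_samples = 0;
//   void*  model = xtts_init("model.gguf", /*use_mmap=*/true);
//   float* audio = xtts_generate(model, "Hello, world.", /*language=*/0,
//                                /*speaker_id=*/0, /*temperature=*/0.7f,
//                                /*speed=*/1.0f, &n_samples);
//   // ... write the n_samples float samples in `audio` to a WAV file ...
//   xtts_free_audio(audio);
//   xtts_free(model);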
} // namespace xtts