mirror of https://github.com/serp-ai/bark-with-voice-clone.git
synced 2026-04-03 09:46:24 +02:00
Add better voice clones and prepare for finetuning
@@ -8,6 +8,8 @@ from dataclasses import dataclass
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
+from einops import rearrange, repeat, reduce
+SEMANTIC_PAD_TOKEN = 10_000
 
 class LayerNorm(nn.Module):
     """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
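The first hunk pulls in einops and defines SEMANTIC_PAD_TOKEN, presumably so variable-length semantic token sequences can be padded to a common length when batched for fine-tuning. The loss added below uses nn.CrossEntropyLoss() with its default ignore_index of -100, so a caller can mask pad positions out of the labels and they will contribute nothing to the loss. A minimal sketch of that batching convention (the pad_semantic_batch helper is an assumption for illustration, not part of this commit):

import torch

SEMANTIC_PAD_TOKEN = 10_000  # same constant the commit adds

def pad_semantic_batch(seqs, pad_token=SEMANTIC_PAD_TOKEN):
    # Hypothetical helper: right-pad variable-length token sequences into
    # one (batch, max_len) input tensor plus a matching labels tensor.
    max_len = max(len(s) for s in seqs)
    idx = torch.full((len(seqs), max_len), pad_token, dtype=torch.long)
    # -100 is nn.CrossEntropyLoss's default ignore_index, so padded
    # positions are excluded from the loss computed in forward().
    labels = torch.full((len(seqs), max_len), -100, dtype=torch.long)
    for i, s in enumerate(seqs):
        t = torch.as_tensor(s, dtype=torch.long)
        idx[i, : len(s)] = t
        labels[i, : len(s)] = t  # supervise only the real tokens
    return idx, labels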
@@ -165,7 +167,7 @@ class GPT(nn.Module):
         n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
         device = idx.device
         b, t = idx.size()
         if past_kv is not None:
@@ -212,6 +214,21 @@ class GPT(nn.Module):
 
         x = self.transformer.ln_f(x)
 
+
+        if labels is not None:
+            logits = self.lm_head(x)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = nn.CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+            return logits, loss
+
         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
 
||||