Add better voice clones and prepare for fine-tuning

This commit is contained in:
Francis LaBounty
2023-05-25 16:24:41 -06:00
parent 0b16a49fe2
commit 40afeec9c0
12 changed files with 1050 additions and 37 deletions

View File

@@ -8,6 +8,8 @@ from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat, reduce
SEMANTIC_PAD_TOKEN = 10_000
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -165,7 +167,7 @@ class GPT(nn.Module):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
device = idx.device
b, t = idx.size()
if past_kv is not None:
@@ -212,6 +214,21 @@ class GPT(nn.Module):
x = self.transformer.ln_f(x)
if labels is not None:
logits = self.lm_head(x)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return logits, loss
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim