diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py
index c84022bd..da2aeba0 100644
--- a/TTS/tts/layers/bark/model.py
+++ b/TTS/tts/layers/bark/model.py
@@ -175,7 +175,37 @@ class GPT(nn.Module):
         n_params -= self.transformer.wpe.weight.numel()
         return n_params

-    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
+    def forward(self, idx):
+        device = idx.device
+        _, t = idx.size()
+        assert (
+            t <= self.config.block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+
+        past_length = 0
+        past_kv = tuple([None] * len(self.transformer.h))
+
+        position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0)  # shape (1, t)
+        assert position_ids.shape == (1, t)
+        pos_emb = self.transformer.wpe(position_ids)  # position embeddings of shape (1, t, n_embd)
+
+        x = self.transformer.drop(tok_emb + pos_emb)
+
+        for block, past_layer_kv in zip(self.transformer.h, past_kv):
+            x, _ = block(x, past_kv=past_layer_kv, use_cache=False)
+
+        x = self.transformer.ln_f(x)
+
+        # forward the lm_head over every position so logits cover the full sequence
+        logits = self.lm_head(x)  # shape (b, t, vocab)
+
+        return logits
+
+    def inference(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         _, t = idx.size()
         if past_kv is not None:
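
A minimal usage sketch of the split introduced above: forward() becomes a cache-free path that returns logits for every position, while the merge-context/KV-cache behaviour moves to inference(). The helper names training_step and decode_step are hypothetical, and the (logits, kv_cache) return of inference() is assumed from the forward() it replaces.

import torch
import torch.nn.functional as F

def training_step(model, idx, targets):
    # forward() now takes only token ids and returns logits of shape
    # (batch, seq_len, vocab), so a standard next-token loss can be applied.
    logits = model(idx)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

@torch.no_grad()
def decode_step(model, idx, past_kv=None):
    # The cached autoregressive path goes through inference() instead.
    logits, past_kv = model.inference(idx, past_kv=past_kv, use_cache=True)  # tuple return assumed
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_token, past_kv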