simplify device placement

This commit is contained in:
Georg Kucsko
2023-04-25 17:49:35 -04:00
parent 97c6019ecd
commit 6c26fb7b34
2 changed files with 82 additions and 42 deletions

View File

@@ -200,7 +200,6 @@ class GPT(nn.Module):
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
x = self.transformer.drop(tok_emb + pos_emb)
new_kv = () if use_cache else None