mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-23 07:09:35 +01:00
22 lines
760 B
Python
22 lines
760 B
Python
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from collections import defaultdict
|
||
|
|
|
||
|
|
|
||
|
|
def make_positions(tensor, padding_idx):
    """Return position ids for ``tensor``, leaving padding entries untouched.

    Every entry that differs from ``padding_idx`` is replaced by its running
    count of non-padding entries along dim 1, offset by ``padding_idx``. The
    first real token in a row therefore gets ``padding_idx + 1``. Entries equal
    to ``padding_idx`` come out as ``padding_idx`` itself.

    Args:
        tensor: token-id tensor with at least two dimensions; positions are
            counted along dim 1.
        padding_idx: value marking padding entries in ``tensor``.

    Returns:
        A ``torch.long`` tensor with the same shape as ``tensor``.
    """
    # The exact cast sequence below is intentional: it must survive both ONNX
    # export and XLA. XLA prefers ints, cumsum defaults to producing longs, and
    # ONNX cannot handle cumsum's ``dtype`` keyword, so the int->long
    # conversions are done explicitly rather than via that keyword.
    is_token = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(is_token, dim=1).type_as(is_token)
    # Multiplying by the mask zeroes out padding slots before the offset is added.
    return (running_count * is_token).long() + padding_idx
|
||
|
|
|
||
|
|
|
||
|
|
def softmax(x, dim):
    """Apply softmax to ``x`` along ``dim``, computing in float32.

    Passing ``dtype=torch.float32`` makes ``F.softmax`` cast the input to
    float32 before the reduction. Lower-precision inputs (e.g. float16) are
    therefore normalized in full single precision, and the result is float32.

    Args:
        x: input tensor.
        dim: dimension along which the probabilities sum to one.

    Returns:
        A float32 tensor with the same shape as ``x``.
    """
    compute_dtype = torch.float32
    return F.softmax(x, dim=dim, dtype=compute_dtype)
|