mirror of
https://github.com/AIGC-Audio/AudioGPT.git
synced 2025-12-16 11:57:58 +01:00
50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
|
|
import json
|
||
|
|
import random
|
||
|
|
import argparse
|
||
|
|
import numpy as np
|
||
|
|
from tqdm import tqdm
|
||
|
|
from h5py import File
|
||
|
|
import sklearn.metrics
|
||
|
|
|
||
|
|
random.seed(1)
|
||
|
|
|
||
|
|
parser = argparse.ArgumentParser()
|
||
|
|
parser.add_argument("train_feature", type=str)
|
||
|
|
parser.add_argument("train_corpus", type=str)
|
||
|
|
parser.add_argument("pred_feature", type=str)
|
||
|
|
parser.add_argument("output_json", type=str)
|
||
|
|
|
||
|
|
args = parser.parse_args()
|
||
|
|
train_embs = []
|
||
|
|
train_idx_to_audioid = []
|
||
|
|
with File(args.train_feature, "r") as store:
|
||
|
|
for audio_id, embedding in tqdm(store.items(), ascii=True):
|
||
|
|
train_embs.append(embedding[()])
|
||
|
|
train_idx_to_audioid.append(audio_id)
|
||
|
|
|
||
|
|
train_annotation = json.load(open(args.train_corpus, "r"))["audios"]
|
||
|
|
train_audioid_to_tokens = {}
|
||
|
|
for item in train_annotation:
|
||
|
|
audio_id = item["audio_id"]
|
||
|
|
train_audioid_to_tokens[audio_id] = [cap_item["tokens"] for cap_item in item["captions"]]
|
||
|
|
train_embs = np.stack(train_embs)
|
||
|
|
|
||
|
|
|
||
|
|
pred_data = []
|
||
|
|
pred_embs = []
|
||
|
|
pred_idx_to_audioids = []
|
||
|
|
with File(args.pred_feature, "r") as store:
|
||
|
|
for audio_id, embedding in tqdm(store.items(), ascii=True):
|
||
|
|
pred_embs.append(embedding[()])
|
||
|
|
pred_idx_to_audioids.append(audio_id)
|
||
|
|
pred_embs = np.stack(pred_embs)
|
||
|
|
|
||
|
|
similarity = sklearn.metrics.pairwise.cosine_similarity(pred_embs, train_embs)
|
||
|
|
for idx, audio_id in enumerate(pred_idx_to_audioids):
|
||
|
|
train_idx = similarity[idx].argmax()
|
||
|
|
pred_data.append({
|
||
|
|
"filename": audio_id,
|
||
|
|
"tokens": random.choice(train_audioid_to_tokens[train_idx_to_audioid[train_idx]])
|
||
|
|
})
|
||
|
|
json.dump({"predictions": pred_data}, open(args.output_json, "w"), ensure_ascii=False, indent=4)
|