# Mirror of https://github.com/AIGC-Audio/AudioGPT.git
# Synced 2025-12-16 11:57:58 +01:00
# 111 lines, 3.5 KiB, Python
import os
|
|
import sys
|
|
import copy
|
|
import pickle
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import fire
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
|
|
def coco_score(refs, pred, scorer):
    """Return the all-reference score minus the mean leave-one-out score.

    ``pred`` is scored against ``refs`` once per leave-one-out round (one
    reference caption withheld from every key) and once with every
    reference present. The leave-one-out scores are averaged and subtracted
    from the all-reference score.

    Args:
        refs: Mapping from key to a non-empty list of reference captions.
            The mapping and its lists are not modified.
        pred: Predictions; passed through unchanged to
            ``scorer.compute_score``.
        scorer: Object exposing ``method()`` and
            ``compute_score(refs, pred)``, the latter returning a
            ``(score, extra)`` pair.

    Returns:
        ``score_allref - mean_leave_one_out_score``. When
        ``scorer.method()`` is ``"Bleu"`` the result is a NumPy array
        (scores are accumulated in a 4-element array); otherwise it has
        whatever scalar type the scorer's scores produce.

    Raises:
        ValueError: If ``refs`` is empty or any key has no captions.
    """
    if not refs:
        raise ValueError("refs must contain at least one key")
    for key, captions in refs.items():
        if not captions:
            raise ValueError("no reference captions for key {!r}".format(key))

    # NOTE: the number of rounds comes from the first key only; a key with a
    # different number of captions is cycled through rather than rejected.
    num_cap_per_audio = len(next(iter(refs.values())))

    is_bleu = scorer.method() == "Bleu"
    scores = np.zeros(4) if is_bleu else 0

    for i in range(num_cap_per_audio):
        held_out_refs = {}
        for key, captions in refs.items():
            count = len(captions)
            shift = i % count
            # Equivalent to rotating the list right by `shift` and dropping
            # its last item, but done on a fresh list so the caller's data
            # is never touched.
            held_out_refs[key] = (
                captions[count - shift:] + captions[:count - 1 - shift]
            )
        score, _ = scorer.compute_score(held_out_refs, pred)
        scores += np.array(score) if is_bleu else score

    mean_score = scores / num_cap_per_audio

    score_allref, _ = scorer.compute_score(refs, pred)
    if is_bleu:
        # Make the element-wise subtraction explicit instead of relying on
        # NumPy's reflected operator for a plain sequence.
        score_allref = np.array(score_allref)
    return score_allref - mean_score
|
|
|
|
def embedding_score(refs, pred, scorer):
    """Return the all-reference score minus the mean leave-one-out score.

    For each index ``i`` below the reference count of the first key, entry
    ``i`` is removed from every key that has exactly that many entries and
    ``pred`` is scored against the remainder. Those scores are averaged and
    subtracted from the score obtained against the full ``refs``.

    Args:
        refs: Mapping from key to a sequence of reference entries that
            supports ``len`` and slicing and can be joined with
            ``np.concatenate`` (presumably per-caption embedding arrays;
            confirm against the caller).
        pred: Predictions; passed through unchanged to
            ``scorer.compute_score``.
        scorer: Object exposing ``compute_score(refs, pred)`` returning a
            ``(score, extra)`` pair.

    Returns:
        ``score_allref - mean_leave_one_out_score``.

    Raises:
        ValueError: If ``refs`` is empty or its first key has no entries.
    """
    if not refs:
        raise ValueError("refs must contain at least one key")
    num_cap_per_audio = len(next(iter(refs.values())))
    if num_cap_per_audio == 0:
        raise ValueError("the first key in refs has no reference entries")

    # Keys whose reference count differs from the first key's are left out of
    # the leave-one-out rounds (they still count in the all-reference score).
    # The filter does not depend on the round, so compute it once.
    complete_keys = [
        key for key in refs if len(refs[key]) == num_cap_per_audio
    ]

    scores = 0
    for i in range(num_cap_per_audio):
        refs_i = {
            key: np.concatenate([refs[key][:i], refs[key][i + 1:]])
            for key in complete_keys
        }
        score, _ = scorer.compute_score(refs_i, pred)
        scores += score

    mean_score = scores / num_cap_per_audio

    score_allref, _ = scorer.compute_score(refs, pred)
    return score_allref - mean_score
|
|
|
|
def main(output_file, eval_caption_file, eval_embedding_file, output, zh=False):
    """Compute metric differences with ``coco_score`` and write them to a file.

    Args:
        output_file: JSON file readable by ``pd.read_json`` with ``filename``
            and ``tokens`` columns; predictions are grouped by the file
            name's stem.
        eval_caption_file: JSON file readable by ``pd.read_json`` with a
            ``key`` column and a ``tokens`` column (used when ``zh`` is true)
            or a ``caption`` column (used otherwise).
        eval_embedding_file: Currently unused; it is only referenced by the
            disabled SentenceBert section below.
        output: Path of the text report to write.
        zh: Forwarded to the Bleu, Cider and Rouge scorers. When true,
            Meteor and SPICE are skipped.
    """

    def _stem(path):
        # File name without its directory and extension.
        return os.path.splitext(os.path.basename(path))[0]

    prediction_df = pd.read_json(output_file)
    prediction_df["key"] = prediction_df["filename"].apply(_stem)
    pred = prediction_df.groupby("key")["tokens"].apply(list).to_dict()

    caption_df = pd.read_json(eval_caption_file)
    if zh:
        grouped_refs = caption_df.groupby("key")["tokens"]
    else:
        grouped_refs = caption_df.groupby("key")["caption"]
    refs = grouped_refs.apply(list).to_dict()

    def _diff(metric_scorer):
        # Each scorer gets its own deep copy of the references.
        return coco_score(copy.deepcopy(refs), pred, metric_scorer)

    from pycocoevalcap.bleu.bleu import Bleu
    from pycocoevalcap.cider.cider import Cider
    from pycocoevalcap.rouge.rouge import Rouge

    bleu_scores = _diff(Bleu(zh=zh))
    # (format template, value) pairs, in the order they are written out.
    report = [
        ("CIDEr: {:6.3f}\n", _diff(Cider(zh=zh))),
        ("ROUGE: {:6.3f}\n", _diff(Rouge(zh=zh))),
    ]

    if not zh:
        from pycocoevalcap.meteor.meteor import Meteor
        report.append(("Meteor: {:6.3f}\n", _diff(Meteor())))

        from pycocoevalcap.spice.spice import Spice
        report.append(("SPICE: {:6.3f}\n", _diff(Spice())))

    # Disabled SentenceBert scoring, kept for reference:
    # from audiocaptioneval.sentbert.sentencebert import SentenceBert
    # scorer = SentenceBert(zh=zh)
    # with open(eval_embedding_file, "rb") as f:
    #     ref_embeddings = pickle.load(f)
    # sent_bert = embedding_score(ref_embeddings, pred, scorer)

    with open(output, "w") as f:
        f.write("Diff:\n")
        for order in (1, 2, 3, 4):
            f.write("BLEU-{}: {:6.3f}\n".format(order, bleu_scores[order - 1]))
        for template, value in report:
            f.write(template.format(value))
        # f.write("SentenceBert: {:6.3f}\n".format(sent_bert))
|
|
|
|
|
|
|
|
# Script entry point: expose ``main`` as a command-line interface via Fire.
if __name__ == "__main__":
    fire.Fire(component=main)
|