update example

This commit is contained in:
Yingda Chen
2023-08-13 11:06:41 +08:00
committed by wenmeng.zwm
parent 54fdc76fdb
commit bbde919a64

View File

@@ -55,23 +55,24 @@
"from dashscope import TextEmbedding\n",
"from dashvector import Client, Doc\n",
"\n",
"# get env variable from .env, please make sure add DASHSCOPE_KEY in .env\n",
"# get env variable from .env\n",
"# please make sure DASHSCOPE_KEY is defined in .env\n",
"load_dotenv()\n",
"api_key = os.getenv('DASHSCOPE_KEY')\n",
"dashscope.api_key = api_key\n",
"dashscope.api_key = os.getenv('DASHSCOPE_KEY')\n",
"\n",
"# initialize dashvector for embedding's indexing and searching\n",
"ds_client = Client(api_key=api_key)\n",
"\n",
"# initialize DashVector for embedding's indexing and searching\n",
"dashvector_client = Client(api_key='{your-dashvector-api-key}')\n",
"\n",
"# define collection name\n",
"collection_name = 'news_embeddings'\n",
"\n",
"# delete if already exist\n",
"ds_client.delete(collection_name)\n",
"dashvector_client.delete(collection_name)\n",
"\n",
"# create a collection with embedding size 1536\n",
"rsp = ds_client.create(collection_name, 1536)\n",
"collection = ds_client.get(collection_name)\n"
"# create a collection with embedding size of 1536\n",
"rsp = dashvector_client.create(collection_name, 1536)\n",
"collection = dashvector_client.get(collection_name)\n"
]
},
{
@@ -94,7 +95,7 @@
"outputs": [],
"source": [
"def prepare_data_from_dir(path, size):\n",
" # prepare the data from a file folder in order to upsert to dashvector with a reasonable doc's size.\n",
" # prepare the data from a file folder in order to upsert to DashVector with a reasonable doc's size.\n",
" batch_docs = []\n",
" for file in os.listdir(path):\n",
" with open(path + '/' + file, 'r', encoding='utf-8') as f:\n",
@@ -127,7 +128,7 @@
"outputs": [],
"source": [
"def prepare_data_from_file(path, size):\n",
" # prepare the data from file in order to upsert to dashvector with a reasonable doc's size.\n",
" # prepare the data from file in order to upsert to DashVector with a reasonable doc's size.\n",
" batch_docs = []\n",
" chunk_size = 12\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
@@ -191,8 +192,8 @@
"id = 0\n",
"dir_name = 'CEC-Corpus/raw corpus/allSourceText'\n",
"\n",
"# indexing the raw docs with index to dashvector\n",
"collection = ds_client.get(collection_name)\n",
"# indexing the raw docs with index to DashVector\n",
"collection = dashvector_client.get(collection_name)\n",
"\n",
"# embedding api max batch size\n",
"batch_size = 4 \n",
@@ -225,7 +226,7 @@
"outputs": [],
"source": [
"# check the collection status\n",
"collection = ds_client.get(collection_name)\n",
"collection = dashvector_client.get(collection_name)\n",
"rsp = collection.stats()\n",
"print(rsp)"
]
@@ -249,11 +250,11 @@
},
"outputs": [],
"source": [
"def search_relevant_context(question, topk=1, client=ds_client):\n",
"def search_relevant_context(question, topk=1, client=dashvector_client):\n",
" # query and recall the relevant information\n",
" collection = client.get(collection_name)\n",
"\n",
" # recall the top k similiar results from dashvector\n",
"    # recall the top k most similar results from DashVector\n",
" rsp = collection.query(generate_embeddings(question), output_fields=['raw'],\n",
" topk=topk)\n",
" return \"\".join([item.fields['raw'] for item in rsp.output])"
@@ -333,7 +334,7 @@
},
"outputs": [],
"source": [
"# define a prompt template for the knowledge enhanced LLM generation\n",
"# define a prompt template for the vectorDB-enhanced LLM generation\n",
"def answer_question(question, context):\n",
" prompt = f'''请基于```内的内容回答问题。\"\n",
"\t```\n",
@@ -381,7 +382,7 @@
}
],
"source": [
"# test the case without knowledge\n",
"# test the case on a plain LLM without vectorDB enhancement\n",
"question = '清华博士发生了什么?'\n",
"answer = answer_question(question, '')\n",
"print(f'question: {question}\\n' f'answer: {answer}')"