update QA example

This commit is contained in:
Zhicheng Zhang
2023-08-10 19:03:20 +08:00
committed by wenmeng.zwm
parent c6df118593
commit 348d6d04e4

View File

@@ -0,0 +1,477 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "04d4165c-fab2-4f54-9b50-11d53917d785",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"# install required packages\n",
"!pip install dashvector dashscope\n",
"!pip install transformers_stream_generator python-dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0ca135ac-b1b0-47b9-ad25-a0d11ac884f3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# prepare news corpus as knowledge source\n",
"!git clone https://github.com/shijiebei2009/CEC-Corpus.git"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0b53db17-7d6d-4192-a145-e470d6d2a6ec",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2023-08-10T10:31:46.520109Z",
"iopub.status.busy": "2023-08-10T10:31:46.519793Z",
"iopub.status.idle": "2023-08-10T10:31:46.894428Z",
"shell.execute_reply": "2023-08-10T10:31:46.893761Z",
"shell.execute_reply.started": "2023-08-10T10:31:46.520085Z"
},
"tags": []
},
"outputs": [],
"source": [
"import dashscope\n",
"import os\n",
"from dotenv import load_dotenv\n",
"from dashscope import TextEmbedding\n",
"from dashvector import Client, Doc"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "728a2bf5-905c-48ef-b70a-be53d4f8fcc0",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2023-08-10T10:32:15.429699Z",
"iopub.status.busy": "2023-08-10T10:32:15.429291Z",
"iopub.status.idle": "2023-08-10T10:32:16.076518Z",
"shell.execute_reply": "2023-08-10T10:32:16.075949Z",
"shell.execute_reply.started": "2023-08-10T10:32:15.429679Z"
},
"tags": []
},
"outputs": [],
"source": [
"# get env varible from .env, please make sure add DASHSCOPE_KEY in .env\n",
"load_dotenv()\n",
"api_key = os.getenv('DASHSCOPE_KEY')\n",
"dashscope.api_key = api_key\n",
"\n",
"# initialize dashvector for embedding's indexing and searching\n",
"ds_client = Client(api_key=api_key)\n",
"\n",
"# define collection name\n",
"collection_name = 'news_embeddings'\n",
"\n",
"# delete if already exist\n",
"ds_client.delete(collection_name)\n",
"\n",
"# create a collection with embedding size 1536\n",
"rsp = ds_client.create(collection_name, 1536)\n",
"collection = ds_client.get(collection_name)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "558b64ab-1fdf-4339-8368-97e67bef8159",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2023-08-10T10:57:43.451192Z",
"iopub.status.busy": "2023-08-10T10:57:43.450893Z",
"iopub.status.idle": "2023-08-10T10:57:43.454858Z",
"shell.execute_reply": "2023-08-10T10:57:43.454244Z",
"shell.execute_reply.started": "2023-08-10T10:57:43.451173Z"
},
"tags": []
},
"outputs": [],
"source": [
"def prepare_data_from_dir(path, size):\n",
" # prepare the data from a file folder in order to upsert to dashvector with a reasonable doc's size.\n",
" batch_docs = []\n",
" for file in os.listdir(path):\n",
" with open(path + '/' + file, 'r', encoding='utf-8') as f:\n",
" batch_docs.append(f.read())\n",
" if len(batch_docs) == size:\n",
" yield batch_docs[:]\n",
" batch_docs.clear()\n",
"\n",
" if batch_docs:\n",
" yield batch_docs"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "d65c0f3f-a080-4803-b5ed-f4e641a96db2",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2023-08-10T10:57:44.615001Z",
"iopub.status.busy": "2023-08-10T10:57:44.614690Z",
"iopub.status.idle": "2023-08-10T10:57:44.618899Z",
"shell.execute_reply": "2023-08-10T10:57:44.618418Z",
"shell.execute_reply.started": "2023-08-10T10:57:44.614979Z"
},
"tags": []
},
"outputs": [],
"source": [
"def prepare_data_from_file(path, size):\n",
" # prepare the data from file in order to upsert to dashvector with a reasonable doc's size.\n",
" batch_docs = []\n",
" batch_size = 12\n",
" with open(path, 'r', encoding='utf-8') as f:\n",
" doc = ''\n",
" count = 0\n",
" for line in f:\n",
" if count < batch_size and line.strip() != '':\n",
" doc += line\n",
" count += 1\n",
" if count == batch_size:\n",
" batch_docs.append(doc)\n",
" if len(batch_docs) == size:\n",
" yield batch_docs[:]\n",
" batch_docs.clear()\n",
" doc = ''\n",
" count = 0\n",
"\n",
" if batch_docs:\n",
" yield batch_docs"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "aded6eec-1f05-479e-9f0e-3ce63872a07b",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2023-08-10T10:57:46.210192Z",
"iopub.status.busy": "2023-08-10T10:57:46.209870Z",
"iopub.status.idle": "2023-08-10T10:57:46.214412Z",
"shell.execute_reply": "2023-08-10T10:57:46.213625Z",
"shell.execute_reply.started": "2023-08-10T10:57:46.210172Z"
},
"tags": []
},
"outputs": [],
"source": [
"def generate_embeddings(news):\n",
" # create embeddings via DashScope's TextEmbedding model API\n",
" rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,\n",
" input=news)\n",
" #print(rsp)\n",
" embeddings = [record['embedding'] for record in rsp.output['embeddings']]\n",
" return embeddings if isinstance(news, list) else embeddings[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c0ba7e1-001f-4bb9-9bdb-7eb318bc3550",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"id = 0\n",
"# file_name = '天龙八部.txt'\n",
"dir_name = 'CEC-Corpus/raw corpus/allSourceText'\n",
"\n",
"# indexing the raw docs with index to dashvector\n",
"collection = ds_client.get(collection_name)\n",
"batch_size = 4 # embedding api max batch size\n",
"for news in list(prepare_data_from_dir(dir_name, batch_size)):\n",
" ids = [id + i for i, _ in enumerate(news)]\n",
" id += len(news)\n",
" # generate embedding from raw docs\n",
" vectors = generate_embeddings(news)\n",
" print(news)\n",
" # upsert and indexing\n",
" ret = collection.upsert(\n",
" [\n",
" Doc(id=str(id), vector=vector, fields={\"raw\": doc})\n",
" for id, doc, vector in zip(ids, news, vectors)\n",
" ]\n",
" )\n",
" print(ret)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53bed7e4-35be-4df6-8775-7d62fcdb6457",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"# check the collection status\n",
"collection = ds_client.get(collection_name)\n",
"rsp = collection.stats()\n",
"print(rsp)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "41e54ddd-145d-49c3-ade4-4a46dc34e07b",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2023-08-10T10:57:54.368540Z",
"iopub.status.busy": "2023-08-10T10:57:54.368215Z",
"iopub.status.idle": "2023-08-10T10:57:54.371879Z",
"shell.execute_reply": "2023-08-10T10:57:54.371364Z",
"shell.execute_reply.started": "2023-08-10T10:57:54.368521Z"
},
"tags": []
},
"outputs": [],
"source": [
"def search_relevant_context(question, topk=1, client=ds_client):\n",
" # query and recall the relevant information\n",
" collection = client.get(collection_name)\n",
"\n",
" # recall the top k similiar results from dashvector\n",
" rsp = collection.query(generate_embeddings(question), output_fields=['raw'],\n",
" topk=topk)\n",
" return \"\".join([item.fields['raw'] for item in rsp.output])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "409236b9-87d4-4df0-8ee6-486d3c0e5fb6",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2023-08-10T10:57:56.141848Z",
"iopub.status.busy": "2023-08-10T10:57:56.141502Z",
"iopub.status.idle": "2023-08-10T10:57:56.387965Z",
"shell.execute_reply": "2023-08-10T10:57:56.387379Z",
"shell.execute_reply.started": "2023-08-10T10:57:56.141830Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2006-08-26 10:41:45\n",
"8月23日上午9时40分京沪高速公路沧州服务区附近一辆由北向南行驶的金杯面包车撞到高速公路护栏上车上5名清华大学博士后研究人员及1名司机受伤被紧急送往沧州二医院抢救。截至发稿时仍有一名张姓博士后研究人员尚未脱离危险。\n",
"\n",
"\n"
]
}
],
"source": [
"# query the top 1 results\n",
"question = '清华博士发生了什么?'\n",
"context = search_relevant_context(question, topk=1)\n",
"print(context)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "730abebb-1f5a-4fb9-b035-fb2ae09a31c9",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"tags": []
},
"outputs": [],
"source": [
"# initialize qwen 7B model\n",
"from modelscope import AutoModelForCausalLM, AutoTokenizer\n",
"from modelscope import GenerationConfig\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',trust_remote_code=True)\n",
"model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',device_map=\"auto\", trust_remote_code=True, fp16=True).eval()\n",
"model.generation_config = GenerationConfig.from_pretrained(\"Qwen/Qwen-7B-Chat\",revision = 'v1.0.5', trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "2f5a1bcb-e83a-44d3-bbe4-f97437782a3b",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2023-08-10T10:41:01.761863Z",
"iopub.status.busy": "2023-08-10T10:41:01.761502Z",
"iopub.status.idle": "2023-08-10T10:41:01.765849Z",
"shell.execute_reply": "2023-08-10T10:41:01.765318Z",
"shell.execute_reply.started": "2023-08-10T10:41:01.761842Z"
},
"tags": []
},
"outputs": [],
"source": [
"# define a prompt template for the knowledge enhanced LLM generation\n",
"def answer_question(question, context):\n",
" prompt = f'''请基于```内的内容回答问题。\"\n",
"\t```\n",
"\t{context}\n",
"\t```\n",
"\t我的问题是{question}。\n",
" '''\n",
" history = None\n",
" print(prompt)\n",
" response, history = model.chat(tokenizer, prompt, history=None)\n",
" return response"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "75ac8f4a-a861-4376-9e55-ebefef9a9cd6",
"metadata": {
"ExecutionIndicator": {
"show": true
},
"execution": {
"iopub.execute_input": "2023-08-10T10:41:29.070090Z",
"iopub.status.busy": "2023-08-10T10:41:29.069778Z",
"iopub.status.idle": "2023-08-10T10:41:31.613198Z",
"shell.execute_reply": "2023-08-10T10:41:31.612421Z",
"shell.execute_reply.started": "2023-08-10T10:41:29.070073Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"请基于```内的内容回答问题。\"\n",
"\t```\n",
"\t\n",
"\t```\n",
"\t我的问题是清华博士发生了什么。\n",
" \n",
"question: 清华博士发生了什么?\n",
"answer: 清华博士是指清华大学的博士研究生。作为一名AI语言模型我无法获取个人的身份信息或具体事件因此无法回答清华博士发生了什么。如果您需要了解更多相关信息建议您查询相关媒体或官方网站。\n"
]
}
],
"source": [
"# test the case without knowledge\n",
"question = '清华博士发生了什么?'\n",
"answer = answer_question(question, '')\n",
"print(f'question: {question}\\n' f'answer: {answer}')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "eca328fc-cd69-4e12-8448-f426f3314414",
"metadata": {
"ExecutionIndicator": {
"show": false
},
"execution": {
"iopub.execute_input": "2023-08-10T10:41:34.268896Z",
"iopub.status.busy": "2023-08-10T10:41:34.268585Z",
"iopub.status.idle": "2023-08-10T10:41:37.750128Z",
"shell.execute_reply": "2023-08-10T10:41:37.749414Z",
"shell.execute_reply.started": "2023-08-10T10:41:34.268878Z"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"请基于```内的内容回答问题。\"\n",
"\t```\n",
"\t2006-08-26 10:41:45\n",
"8月23日上午9时40分京沪高速公路沧州服务区附近一辆由北向南行驶的金杯面包车撞到高速公路护栏上车上5名清华大学博士后研究人员及1名司机受伤被紧急送往沧州二医院抢救。截至发稿时仍有一名张姓博士后研究人员尚未脱离危险。\n",
"\n",
"\n",
"\t```\n",
"\t我的问题是清华博士发生了什么。\n",
" \n",
"question: 清华博士发生了什么?\n",
"answer: 8月23日上午9时40分一辆由北向南行驶的金杯面包车撞到高速公路护栏上车上5名清华大学博士后研究人员及1名司机受伤被紧急送往沧州二医院抢救。\n"
]
}
],
"source": [
"# test the case with knowledge\n",
"context = search_relevant_context(question, topk=1)\n",
"answer = answer_question(question, context)\n",
"print(f'question: {question}\\n' f'answer: {answer}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}