diff --git a/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb b/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb new file mode 100644 index 00000000..98276726 --- /dev/null +++ b/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb @@ -0,0 +1,477 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "04d4165c-fab2-4f54-9b50-11d53917d785", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# install required packages\n", + "!pip install dashvector dashscope\n", + "!pip install transformers_stream_generator python-dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ca135ac-b1b0-47b9-ad25-a0d11ac884f3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# prepare news corpus as knowledge source\n", + "!git clone https://github.com/shijiebei2009/CEC-Corpus.git" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "0b53db17-7d6d-4192-a145-e470d6d2a6ec", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:31:46.520109Z", + "iopub.status.busy": "2023-08-10T10:31:46.519793Z", + "iopub.status.idle": "2023-08-10T10:31:46.894428Z", + "shell.execute_reply": "2023-08-10T10:31:46.893761Z", + "shell.execute_reply.started": "2023-08-10T10:31:46.520085Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import dashscope\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from dashscope import TextEmbedding\n", + "from dashvector import Client, Doc" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "728a2bf5-905c-48ef-b70a-be53d4f8fcc0", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:32:15.429699Z", + "iopub.status.busy": "2023-08-10T10:32:15.429291Z", + "iopub.status.idle": "2023-08-10T10:32:16.076518Z", + "shell.execute_reply": "2023-08-10T10:32:16.075949Z", + "shell.execute_reply.started": "2023-08-10T10:32:15.429679Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# get env varible from .env, please make sure add DASHSCOPE_KEY in .env\n", + "load_dotenv()\n", + "api_key = os.getenv('DASHSCOPE_KEY')\n", + "dashscope.api_key = api_key\n", + "\n", + "# initialize dashvector for embedding's indexing and searching\n", + "ds_client = Client(api_key=api_key)\n", + "\n", + "# define collection name\n", + "collection_name = 'news_embeddings'\n", + "\n", + "# delete if already exist\n", + "ds_client.delete(collection_name)\n", + "\n", + "# create a collection with embedding size 1536\n", + "rsp = ds_client.create(collection_name, 1536)\n", + "collection = ds_client.get(collection_name)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "558b64ab-1fdf-4339-8368-97e67bef8159", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:57:43.451192Z", + "iopub.status.busy": "2023-08-10T10:57:43.450893Z", + "iopub.status.idle": "2023-08-10T10:57:43.454858Z", + "shell.execute_reply": "2023-08-10T10:57:43.454244Z", + "shell.execute_reply.started": "2023-08-10T10:57:43.451173Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def prepare_data_from_dir(path, size):\n", + " # prepare the data from a file folder in order to upsert to dashvector with a reasonable doc's size.\n", + " batch_docs = []\n", + " for file in os.listdir(path):\n", + " with open(path + '/' + file, 'r', encoding='utf-8') as f:\n", + " batch_docs.append(f.read())\n", + " if len(batch_docs) == size:\n", + " yield batch_docs[:]\n", + " batch_docs.clear()\n", + "\n", + " if batch_docs:\n", + " yield batch_docs" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "d65c0f3f-a080-4803-b5ed-f4e641a96db2", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:57:44.615001Z", + "iopub.status.busy": "2023-08-10T10:57:44.614690Z", + "iopub.status.idle": "2023-08-10T10:57:44.618899Z", + "shell.execute_reply": "2023-08-10T10:57:44.618418Z", + "shell.execute_reply.started": "2023-08-10T10:57:44.614979Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def prepare_data_from_file(path, size):\n", + " # prepare the data from file in order to upsert to dashvector with a reasonable doc's size.\n", + " batch_docs = []\n", + " batch_size = 12\n", + " with open(path, 'r', encoding='utf-8') as f:\n", + " doc = ''\n", + " count = 0\n", + " for line in f:\n", + " if count < batch_size and line.strip() != '':\n", + " doc += line\n", + " count += 1\n", + " if count == batch_size:\n", + " batch_docs.append(doc)\n", + " if len(batch_docs) == size:\n", + " yield batch_docs[:]\n", + " batch_docs.clear()\n", + " doc = ''\n", + " count = 0\n", + "\n", + " if batch_docs:\n", + " yield batch_docs" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "aded6eec-1f05-479e-9f0e-3ce63872a07b", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:57:46.210192Z", + "iopub.status.busy": "2023-08-10T10:57:46.209870Z", + "iopub.status.idle": "2023-08-10T10:57:46.214412Z", + "shell.execute_reply": "2023-08-10T10:57:46.213625Z", + "shell.execute_reply.started": "2023-08-10T10:57:46.210172Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def generate_embeddings(news):\n", + " # create embeddings via DashScope's TextEmbedding model API\n", + " rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,\n", + " input=news)\n", + " #print(rsp)\n", + " embeddings = [record['embedding'] for record in rsp.output['embeddings']]\n", + " return embeddings if isinstance(news, list) else embeddings[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5c0ba7e1-001f-4bb9-9bdb-7eb318bc3550", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "id = 0\n", + "# file_name = '天龙八部.txt'\n", + "dir_name = 'CEC-Corpus/raw corpus/allSourceText'\n", + "\n", + "# indexing the raw docs with index to dashvector\n", + "collection = ds_client.get(collection_name)\n", + "batch_size = 4 # embedding api max batch size\n", + "for news in list(prepare_data_from_dir(dir_name, batch_size)):\n", + " ids = [id + i for i, _ in enumerate(news)]\n", + " id += len(news)\n", + " # generate embedding from raw docs\n", + " vectors = generate_embeddings(news)\n", + " print(news)\n", + " # upsert and indexing\n", + " ret = collection.upsert(\n", + " [\n", + " Doc(id=str(id), vector=vector, fields={\"raw\": doc})\n", + " for id, doc, vector in zip(ids, news, vectors)\n", + " ]\n", + " )\n", + " print(ret)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53bed7e4-35be-4df6-8775-7d62fcdb6457", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# check the collection status\n", + "collection = ds_client.get(collection_name)\n", + "rsp = collection.stats()\n", + "print(rsp)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "41e54ddd-145d-49c3-ade4-4a46dc34e07b", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:57:54.368540Z", + "iopub.status.busy": "2023-08-10T10:57:54.368215Z", + "iopub.status.idle": "2023-08-10T10:57:54.371879Z", + "shell.execute_reply": "2023-08-10T10:57:54.371364Z", + "shell.execute_reply.started": "2023-08-10T10:57:54.368521Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def search_relevant_context(question, topk=1, client=ds_client):\n", + " # query and recall the relevant information\n", + " collection = client.get(collection_name)\n", + "\n", + " # recall the top k similiar results from dashvector\n", + " rsp = collection.query(generate_embeddings(question), output_fields=['raw'],\n", + " topk=topk)\n", + " return \"\".join([item.fields['raw'] for item in rsp.output])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "409236b9-87d4-4df0-8ee6-486d3c0e5fb6", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:57:56.141848Z", + "iopub.status.busy": "2023-08-10T10:57:56.141502Z", + "iopub.status.idle": "2023-08-10T10:57:56.387965Z", + "shell.execute_reply": "2023-08-10T10:57:56.387379Z", + "shell.execute_reply.started": "2023-08-10T10:57:56.141830Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2006-08-26 10:41:45\n", + "8月23日上午9时40分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。\n", + "\n", + "\n" + ] + } + ], + "source": [ + "# query the top 1 results\n", + "question = '清华博士发生了什么?'\n", + "context = search_relevant_context(question, topk=1)\n", + "print(context)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "730abebb-1f5a-4fb9-b035-fb2ae09a31c9", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# initialize qwen 7B model\n", + "from modelscope import AutoModelForCausalLM, AutoTokenizer\n", + "from modelscope import GenerationConfig\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',trust_remote_code=True)\n", + "model = AutoModelForCausalLM.from_pretrained(\"qwen/Qwen-7B-Chat\", revision = 'v1.0.5',device_map=\"auto\", trust_remote_code=True, fp16=True).eval()\n", + "model.generation_config = GenerationConfig.from_pretrained(\"Qwen/Qwen-7B-Chat\",revision = 'v1.0.5', trust_remote_code=True) # 可指定不同的生成长度、top_p等相关超参" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2f5a1bcb-e83a-44d3-bbe4-f97437782a3b", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:41:01.761863Z", + "iopub.status.busy": "2023-08-10T10:41:01.761502Z", + "iopub.status.idle": "2023-08-10T10:41:01.765849Z", + "shell.execute_reply": "2023-08-10T10:41:01.765318Z", + "shell.execute_reply.started": "2023-08-10T10:41:01.761842Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "# define a prompt template for the knowledge enhanced LLM generation\n", + "def answer_question(question, context):\n", + " prompt = f'''请基于```内的内容回答问题。\"\n", + "\t```\n", + "\t{context}\n", + "\t```\n", + "\t我的问题是:{question}。\n", + " '''\n", + " history = None\n", + " print(prompt)\n", + " response, history = model.chat(tokenizer, prompt, history=None)\n", + " return response" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "75ac8f4a-a861-4376-9e55-ebefef9a9cd6", + "metadata": { + "ExecutionIndicator": { + "show": true + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:41:29.070090Z", + "iopub.status.busy": "2023-08-10T10:41:29.069778Z", + "iopub.status.idle": "2023-08-10T10:41:31.613198Z", + "shell.execute_reply": "2023-08-10T10:41:31.612421Z", + "shell.execute_reply.started": "2023-08-10T10:41:29.070073Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "请基于```内的内容回答问题。\"\n", + "\t```\n", + "\t\n", + "\t```\n", + "\t我的问题是:清华博士发生了什么?。\n", + " \n", + "question: 清华博士发生了什么?\n", + "answer: 清华博士是指清华大学的博士研究生。作为一名AI语言模型,我无法获取个人的身份信息或具体事件,因此无法回答清华博士发生了什么。如果您需要了解更多相关信息,建议您查询相关媒体或官方网站。\n" + ] + } + ], + "source": [ + "# test the case without knowledge\n", + "question = '清华博士发生了什么?'\n", + "answer = answer_question(question, '')\n", + "print(f'question: {question}\\n' f'answer: {answer}')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "eca328fc-cd69-4e12-8448-f426f3314414", + "metadata": { + "ExecutionIndicator": { + "show": false + }, + "execution": { + "iopub.execute_input": "2023-08-10T10:41:34.268896Z", + "iopub.status.busy": "2023-08-10T10:41:34.268585Z", + "iopub.status.idle": "2023-08-10T10:41:37.750128Z", + "shell.execute_reply": "2023-08-10T10:41:37.749414Z", + "shell.execute_reply.started": "2023-08-10T10:41:34.268878Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "请基于```内的内容回答问题。\"\n", + "\t```\n", + "\t2006-08-26 10:41:45\n", + "8月23日上午9时40分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。\n", + "\n", + "\n", + "\t```\n", + "\t我的问题是:清华博士发生了什么?。\n", + " \n", + "question: 清华博士发生了什么?\n", + "answer: 8月23日上午9时40分,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。\n" + ] + } + ], + "source": [ + "# test the case with knowledge\n", + "context = search_relevant_context(question, topk=1)\n", + "answer = answer_question(question, context)\n", + "print(f'question: {question}\\n' f'answer: {answer}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}