diff --git a/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb b/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb index 98276726..475fc48e 100644 --- a/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb +++ b/examples/pytorch/application/qwen_doc_search_QA_based_on_dashscope.ipynb @@ -30,32 +30,6 @@ "!git clone https://github.com/shijiebei2009/CEC-Corpus.git" ] }, - { - "cell_type": "code", - "execution_count": 2, - "id": "0b53db17-7d6d-4192-a145-e470d6d2a6ec", - "metadata": { - "ExecutionIndicator": { - "show": true - }, - "execution": { - "iopub.execute_input": "2023-08-10T10:31:46.520109Z", - "iopub.status.busy": "2023-08-10T10:31:46.519793Z", - "iopub.status.idle": "2023-08-10T10:31:46.894428Z", - "shell.execute_reply": "2023-08-10T10:31:46.893761Z", - "shell.execute_reply.started": "2023-08-10T10:31:46.520085Z" - }, - "tags": [] - }, - "outputs": [], - "source": [ - "import dashscope\n", - "import os\n", - "from dotenv import load_dotenv\n", - "from dashscope import TextEmbedding\n", - "from dashvector import Client, Doc" - ] - }, { "cell_type": "code", "execution_count": 3, @@ -75,7 +49,13 @@ }, "outputs": [], "source": [ - "# get env varible from .env, please make sure add DASHSCOPE_KEY in .env\n", + "import dashscope\n", + "import os\n", + "from dotenv import load_dotenv\n", + "from dashscope import TextEmbedding\n", + "from dashvector import Client, Doc\n", + "\n", + "# get env variable from .env, please make sure add DASHSCOPE_KEY in .env\n", "load_dotenv()\n", "api_key = os.getenv('DASHSCOPE_KEY')\n", "dashscope.api_key = api_key\n", @@ -149,15 +129,15 @@ "def prepare_data_from_file(path, size):\n", " # prepare the data from file in order to upsert to dashvector with a reasonable doc's size.\n", " batch_docs = []\n", - " batch_size = 12\n", + " chunk_size = 12\n", " with open(path, 'r', encoding='utf-8') as f:\n", " doc = ''\n", " count = 0\n", " for line in f:\n", - " if count < 
batch_size and line.strip() != '':\n", + "            if count < chunk_size and line.strip() != '':\n", "                doc += line\n", "                count += 1\n", -    "            if count == batch_size:\n", +    "            if count == chunk_size:\n", "                batch_docs.append(doc)\n", "                if len(batch_docs) == size:\n", "                    yield batch_docs[:]\n", @@ -188,11 +168,10 @@ }, "outputs": [], "source": [ -    "def generate_embeddings(news):\n", +    "def generate_embeddings(docs):\n", "    # create embeddings via DashScope's TextEmbedding model API\n", "    rsp = TextEmbedding.call(model=TextEmbedding.Models.text_embedding_v1,\n", -    "                             input=news)\n", -    "    #print(rsp)\n", +    "                             input=docs)\n", "    embeddings = [record['embedding'] for record in rsp.output['embeddings']]\n", -    "    return embeddings if isinstance(news, list) else embeddings[0]" +    "    return embeddings if isinstance(docs, list) else embeddings[0]" ] }, @@ -210,19 +189,20 @@ "outputs": [], "source": [ "id = 0\n", -    "# file_name = '天龙八部.txt'\n", "dir_name = 'CEC-Corpus/raw corpus/allSourceText'\n", "\n", "# indexing the raw docs with index to dashvector\n", "collection = ds_client.get(collection_name)\n", -    "batch_size = 4  # embedding api max batch size\n", +    "\n", +    "# embedding api max batch size\n", +    "batch_size = 4\n", +    "\n", "for news in list(prepare_data_from_dir(dir_name, batch_size)):\n", "    ids = [id + i for i, _ in enumerate(news)]\n", "    id += len(news)\n", "    # generate embedding from raw docs\n", "    vectors = generate_embeddings(news)\n", -    "    print(news)\n", -    "    # upsert and indexing\n", +    "    # upsert and index\n", "    ret = collection.upsert(\n", "        [\n", "            Doc(id=str(id), vector=vector, fields={\"raw\": doc})\n", @@ -302,9 +282,7 @@ "output_type": "stream", "text": [ "2006-08-26 10:41:45\n", -    "8月23日上午9时40分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。\n", -    "\n", -    "\n" +    "8月23日上午9时40分,京沪高速公路沧州服务区附近,一辆由北向南行驶的金杯面包车撞到高速公路护栏上,车上5名清华大学博士后研究人员及1名司机受伤,被紧急送往沧州二医院抢救。截至发稿时,仍有一名张姓博士后研究人员尚未脱离危险。\n" ] } ],