diff --git a/application/controllers/openai_controller.py b/application/controllers/openai_controller.py index 0f893406ac9c6b06599685328b489562770f72f5..55c3e0aa28751abf40e4a0884176b8371dc0814e 100644 --- a/application/controllers/openai_controller.py +++ b/application/controllers/openai_controller.py @@ -1,4 +1,7 @@ from flask import Blueprint, request, Response, current_app +from application.utils.general_tool import gcfnpdap, filter_docs, test_chat_history_length +from application.utils.vector import initialize_faiss_database +from application.utils.openai_api import gpt_three_stream openai_bp = Blueprint('openai', __name__, url_prefix='/openai') @@ -10,10 +13,12 @@ def openai_chat(): { "query": str, "history": List, - "model_param": { + "model_config": { "temperature": int, # openai 0-2. 其于模型未知 "max_token": int, #最大值1024 - "system_role": str, #最大值500 + "system_prompt_code": int | None + "system_prompt_content": str | None 最大值1500 + "knowledge": str | None # 允许选定知识库,如果选择知识库为None } } :return: @@ -29,8 +34,92 @@ def openai_chat(): pass else: + current_app.vkb = None + current_app.train_prompt = None + current_app.no_train_prompt = None + current_app.max_chat_length = 8000 + + # 相似性阀值,选用余弦相似度,1表示完全相关,0表示完全不相关 + current_app.similarity = 0.8 + data = request.json; new_question, chat_history = data["query"], data["history"] + # 相关的语料 + related_docs_content = None + + # 相关的doc + related_docs = None + + if data["model_config"]["knowledge"]: + vf_apath = gcfnpdap(__name__, 2) + "/static/store_vector_knowledge_dictory/" + "1716811846_2" + current_app.vkb = initialize_faiss_database(vf_apath) + + if len(chat_history) > 2: + # 历史相似片段 + history_related_docs = current_app.vkb.similarity_search_with_relevance_scores(chat_history[-3], 1) + + # 最新问题相似片段 + new_question_related_docs = current_app.vkb.similarity_search_with_relevance_scores(new_question, 2) + + related_docs = new_question_related_docs + history_related_docs + + else: + # 最新问题相似片段 + new_question_related_docs = 
current_app.vkb.similarity_search_with_relevance_scores(new_question, 2) + + related_docs = new_question_related_docs + + # 相关的doc + related_docs = filter_docs(related_docs, current_app.similarity) + + # related_docs 经过过滤后不为空,则查找最相似的doc的scores + most_related_doc_scores = max([i[1] for i in related_docs], default=0) + + related_docs_content = [doc[0].page_content for doc in related_docs] + related_docs_content = list(set(related_docs_content)) + related_docs_content = "\n\n".join(related_docs_content) + + # 最相关预料 + if related_docs: + + # 数据库操作 + # ------ + # ------ + # 这里data中传了代码过来,则用他的,使用数据库 + + trained_prompt_parent_dir = gcfnpdap(__name__, 2) + "/static/trained/" + "file_name.txt" + + with open(trained_prompt_parent_dir, "r", encoding="UTF-8") as f: + current_app.train_prompt = f.read() + + trained_prompt = current_app.train_prompt.format(doc=related_docs_content) + + chat_history.insert(0, trained_prompt) + + chat_history = test_chat_history_length(chat_history, current_app.max_chat_length) + + data["message"] = chat_history + + return Response(gpt_three_stream(data), mimetype="text/event-stream") + + else: + + # 数据库操作 + # ------ + # ------ + + no_trained_prompt_parent_dir = gcfnpdap(__name__, 2) + "/static/no_trained/" + "file_name.txt" + + with open(no_trained_prompt_parent_dir, "r", encoding="UTF-8") as f: + current_app.no_train_prompt = f.read() + + chat_history.insert(0, current_app.no_train_prompt) + + chat_history = test_chat_history_length(chat_history, current_app.max_chat_length) + + data["message"] = chat_history + + return Response(gpt_three_stream(data), mimetype="text/event-stream") + diff --git a/application/utils/__init__.py b/application/utils/__init__.py index 880f412afda6ed39d41e32e1be84af5e024d17e3..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/application/utils/__init__.py +++ b/application/utils/__init__.py @@ -1,7 +0,0 @@ -from application.utils.moonshot import all_func as moonshot -from application.utils.openai_api import all_func 
as openai_api -from application.utils.qianfan import all_func as qianfan -from application.utils.general_tool import all_func as tool -from application.utils.vector import all_func as vector - -__all__ = moonshot + openai_api + qianfan + tool + vector diff --git a/application/utils/general_tool/__init__.py b/application/utils/general_tool/__init__.py index c384539078c24eccf3ab8cac266b931705c68b2a..18a6391d78a446aecccb44f88e85faa36b98fea5 100644 --- a/application/utils/general_tool/__init__.py +++ b/application/utils/general_tool/__init__.py @@ -19,7 +19,38 @@ def gcfnpdap(file_path, n): return dir_absolute_path -all_func = [gcfnpdap] +def filter_docs(list_, value_): + """ + desc: 过滤相似性较低的语料 + :param list_: 向量数据库查询的相似的docs + :param value_: 自定义相似性阀值 + :return: + """ + return [i for i in list_ if i[1] >= value_] + + +def measure_chat_history_length(chat_history): + """判断字符列表中所有元素总长度""" + total_length = 0 + for i in chat_history: + total_length = total_length + len(i) + + return total_length + + +def test_chat_history_length(chat_history, max_chat_length): + """检验字符列表总长度,如果长度过长,删去历史记录""" + if measure_chat_history_length(chat_history) <= max_chat_length: + return chat_history + + else: + chat_history.pop(1) + chat_history.pop(1) + + return test_chat_history_length(chat_history, max_chat_length) + + +__all__ = ["gcfnpdap", "filter_docs", "test_chat_history_length"] diff --git a/application/utils/moonshot/__init__.py b/application/utils/moonshot/__init__.py index 1fa0b450603038ecf14752204c745dfb71294916..9fcffeef8acc89f34721082153ab5d76ec5e753d 100644 --- a/application/utils/moonshot/__init__.py +++ b/application/utils/moonshot/__init__.py @@ -1,3 +1,3 @@ from application.utils.moonshot.chat import moonshot_one, moonshot_one_stream -all_func = [moonshot_one, moonshot_one_stream] \ No newline at end of file +__all__ = ["moonshot_one", "moonshot_one_stream"] diff --git a/application/utils/openai_api/__init__.py b/application/utils/openai_api/__init__.py index 
c2affe0188eeae6296e8f2ab5a777897b9f3d170..18c6a98ec9a0ef3432ab7449aae44c4151b1ce60 100644 --- a/application/utils/openai_api/__init__.py +++ b/application/utils/openai_api/__init__.py @@ -1,3 +1,3 @@ from application.utils.openai_api.chat import gpt_three, gpt_three_stream -all_func = [gpt_three, gpt_three_stream] +__all__ = ["gpt_three", "gpt_three_stream"] diff --git a/application/utils/openai_api/chat.py b/application/utils/openai_api/chat.py index 04f4b2f9b159625de4bea5f5306cfc59b3c0a757..8e966aff05e68ffbd8c1572e563f20f5209dea6b 100644 --- a/application/utils/openai_api/chat.py +++ b/application/utils/openai_api/chat.py @@ -13,6 +13,17 @@ def gpt_three(data): } """ + # 处理聊天数据 + for i in range(len(data["message"])): + if i == 0: + data["message"][i] = {"role": "system", "content": data["message"][i]} + + elif i % 2 != 0: + data["message"][i] = {"role": "user", "content": data["message"][i]} + + else: + data["message"][i] = {"role": "assistant", "content": data["message"][i]} + client = OpenAI() completion = client.chat.completions.create( model="gpt-4o", @@ -33,6 +44,18 @@ def gpt_three(data): def gpt_three_stream(data): + + # 处理聊天数据 + for i in range(len(data["message"])): + if i == 0: + data["message"][i] = {"role": "system", "content": data["message"][i]} + + elif i % 2 != 0: + data["message"][i] = {"role": "user", "content": data["message"][i]} + + else: + data["message"][i] = {"role": "assistant", "content": data["message"][i]} + client = OpenAI() completion = client.chat.completions.create( model='gpt-3.5-turbo', diff --git a/application/utils/qianfan/__init__.py b/application/utils/qianfan/__init__.py index d8a67ff5f6422cf6cb16c186d6bb5d6527c9cff8..ab24b7ad773ef1146462303660be8ebe61145357 100644 --- a/application/utils/qianfan/__init__.py +++ b/application/utils/qianfan/__init__.py @@ -3,5 +3,5 @@ import json from application.utils.qianfan.chat import ernie_three, ernie_three_stream -all_func = [ernie_three, ernie_three_stream] +__all__ = ["ernie_three", 
"ernie_three_stream"] diff --git a/application/utils/vector/__init__.py b/application/utils/vector/__init__.py index 815c7e363c2325e25a63e3bac6bf8492b8f7523c..6764e7fa9fec60188327b35a25c112f83b3a20f3 100644 --- a/application/utils/vector/__init__.py +++ b/application/utils/vector/__init__.py @@ -1,4 +1,5 @@ from application.utils.vector.excel_to_vetor import xlsx_excel_to_vector +from application.utils.vector.hang_vector_database import initialize_faiss_database import os # store openai-api-key to internal storage @@ -8,4 +9,4 @@ os.environ["OPENAI_API_KEY"] = "sk-proj-wOXa1EvcrzPxGvVwjbp3T3BlbkFJSZNHhSdxIjtB os.environ["http_proxy"] = "http://localhost:7890" os.environ["https_proxy"] = "http://localhost:7890" -all_func = [xlsx_excel_to_vector] +__all__ = ["xlsx_excel_to_vector", "initialize_faiss_database"] diff --git a/application/utils/vector/hang_vector_database.py b/application/utils/vector/hang_vector_database.py new file mode 100644 index 0000000000000000000000000000000000000000..16da623211b9323cc91e6c4ef56d85d1d56bd99e --- /dev/null +++ b/application/utils/vector/hang_vector_database.py @@ -0,0 +1,11 @@ +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import FAISS + + +def initialize_faiss_database(vector_data_dir_absolute_path): + # openai embedding + embeddings = OpenAIEmbeddings() + + vector_kb_object = FAISS.load_local(vector_data_dir_absolute_path, embeddings) + + return vector_kb_object