LangChainのRAGチュートリアルをやってみた(後編) #langchain #langgraph #rag #azure #openai

はじめに
本稿では LangChain のチュートリアル Build a Retrieval Augmented Generation (RAG) App: Part 2 を Azure OpenAI を使って実施してみた記録です。
LLMのみに基づいた生成AIでは、LLMモデルが知らない内容を無理に回答しようとする場合、おかしな結果を生成してしまいがちです。例えば、社内のみで共有されている文章について知らないと回答するならまだしも、まったく間違った内容を生成されては困ります。
そこで、RAG (Retrieval Augmented Generation; 検索拡張生成) という検索技術と生成AIを組み合わせたアプローチの出番です。先の例で言えば、検索で得た社内文章をもとに、生成AIで回答を作成するので、正しい答えである可能性が高まります。
RAGについては他記事も参照してください。
- 【RAGがわかる】社内勉強会の内容を特別公開!
- 【CLくんブログ】Tokyo RAG user group Meetup で弊社エンジニアが講演しました!
- 【エンタープライズLLM】社内データを元に回答してくれるChatGPTを作るには? RAG・LLM技術を利用して価値ある企業独自のAIを作るためのテクニック
- Azure AI SearchでRAGしてみよう! チャットプレイグラウンドとWebアプリ編
LLMを使用したアプリケーションを開発するためのフレームワークであるLangChainのチュートリアルには簡単なRAGを実装する方法がありますので、このチュートリアルに本記事に独自の内容も含めてやってみましょう。
前提
本稿では次の環境を前提として進めます。
多少のバージョン違いは読み替えてください。また、次の有料リソースを利用します。
- Azure OpenAI
- gpt-4o (2024-08-06)
- text-embedding-ada-002 (2)
こちらも多少のバージョン違いは読み替えてください。
実施
もう一つのチュートリアルからも実際のコードを作成していきましょう。
このチュートリアルでは、特定のウェブページの内容に基いて、質疑応答を2回継続するLangGraphアプリケーションを作成しています。本記事の独自要素として、弊社の「スキルアップ支援制度」に対する質問に答えられるようにしてみましょう。回答は弊社 Tech blog の記事「みんなで育むスキルアップ支援制度・正式導入を決定しました」に記載された内容に基づくようにします。
Part 1 と重複する部分
冒頭部分は Part 1 と同じコードなので、解説を省きます。詳細は前編をご覧ください。
import os from dotenv import load_dotenv load_dotenv() from langchain.chat_models import init_chat_model llm = init_chat_model( model=os.getenv("AZURE_OPENAI_MODEL"), model_provider="azure_openai", azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) from langchain_openai import AzureOpenAIEmbeddings embeddings = AzureOpenAIEmbeddings( model=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) from langchain_core.vectorstores import InMemoryVectorStore vector_store = InMemoryVectorStore(embeddings) import bs4 from langchain_community.document_loaders import WebBaseLoader loader = WebBaseLoader( web_paths=("https://www.creationline.com/tech-blog/hr/76913",), bs_kwargs=dict( parse_only=bs4.SoupStrainer( class_=("content_post", "tech_header") ) ), ) docs = loader.load() from langchain_text_splitters import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( separators=["\n\n", "\n", "。"], chunk_size=200, chunk_overlap=100 ) all_splits = text_splitter.split_documents(docs) _ = vector_store.add_documents(documents=all_splits)
一点だけ Part 1 と変更しています。
chunk_size=200, chunk_overlap=100
チャンクサイズとオーバーラップをそれぞれ2倍の値にしています。理由は後述します。
グラフの構成(1)
from langgraph.graph import MessagesState, StateGraph graph_builder = StateGraph(MessagesState)
LangGraphについては過去ブログ「LangGraphをLLMなしでちょっと触ってみよう」「LangGraphとAzure OpenAIを組み合わせてみよう」を参照してください。
ツールの作成
from langchain_core.tools import tool @tool(response_format="content_and_artifact") def retrieve(query: str): """Retrieve information related to a query.""" retrieved_docs = vector_store.similarity_search(query, k=2) serialized = "\n\n".join( (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}") for doc in retrieved_docs ) return serialized, retrieved_docs
ここではLangGraphのTool callingを使って検索「ツール」を作成しています。Tool callingのしくみについては過去記事「LangGraphのTool callingでOpenAI APIのFunction callingを試してみよう」をご覧ください。
ベクトルストアから簡単な類似検索 (similarity_search)を用いてチャンクを取り出します。チャンクの個数は k=2
でtop_kの指定がある通り、2個です。そしてチャンクの内容だけでなく、メタデータも連結して返しています。
なお、コメント部分の「 """Retrieve information related to a query."""
」もツールに必要な部分なので削除しないでください。削除すると ValueError: Function must have a docstring if description not provided.
というエラーになります。
分岐ノード
from langchain_core.messages import SystemMessage from langgraph.prebuilt import ToolNode def query_or_respond(state: MessagesState): llm_with_tools = llm.bind_tools([retrieve]) response = llm_with_tools.invoke(state["messages"]) return {"messages": [response]}
ここではユーザの入力 state["messages"]
をTool callingの仕組みで、先のツール retrieve
でベクトルストアの問い合わせを行うか、直接応答するかで分岐します。
ツールノード
tools = ToolNode([retrieve])
先のツール retrieve
をツールノードとしています。
生成ノード
generate
関数は大きいので、少しずつ見ていきます。
def generate(state: MessagesState): recent_tool_messages = [] for message in reversed(state["messages"]): if message.type == "tool": recent_tool_messages.append(message) else: break tool_messages = recent_tool_messages[::-1]
ツールが生成したメッセージを recent_tool_messages
リストに追加していきます。最終的に tool_messages
リストに逆順にしています。ここで言うツールは retrieve
なので、ベクトル検索して得られたチャンクがツール生成メッセージとなるでしょう。
docs_content = "\n\n".join(doc.content for doc in tool_messages) print(docs_content) system_message_content = ( "You are an assistant for question-answering tasks. " "Use the following pieces of retrieved context to answer " "the question. If you don't know the answer, say that you " "don't know. Use three sentences maximum and keep the " "answer concise." "\n\n" f"{docs_content}" )
この文を訳すと次のようになります。
あなたは質問応答タスクのアシスタントです。 質問に答えるために、次の収集された文脈を活用してください。 答えが分からない場合は、分からないと言いましょう。 答えは最大三文で簡潔に述べてください。
この後に取得したチャンクが続き、システムメッセージとなります。
conversation_messages = [ message for message in state["messages"] if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls) ]
ここでは state["messages"]
に含まれているメッセージから、特定の条件に合致したメッセージを取り出して conversation_messages
リストを新しく作成しています。具体的には、
- 人間(
human
)かシステム(system
)のメッセージ - AI(
ai
)かつTool callingを行っていないメッセージ
です。
prompt = [SystemMessage(system_message_content)] + conversation_messages
生成AIに渡すプロンプトとして、先に生成したシステムメッセージと会話履歴を連結したものを使います。
response = llm.invoke(prompt) return {"messages": [response]}
生成AIにプロンプトを渡して、結果を得ます。
グラフの構成(2)
from langgraph.graph import END from langgraph.prebuilt import ToolNode, tools_condition graph_builder.add_node(query_or_respond) graph_builder.add_node(tools) graph_builder.add_node(generate) graph_builder.set_entry_point("query_or_respond") graph_builder.add_conditional_edges( "query_or_respond", tools_condition, {END: END, "tools": "tools"}, ) graph_builder.add_edge("tools", "generate") graph_builder.add_edge("generate", END) from langgraph.checkpoint.memory import MemorySaver memory = MemorySaver() graph = graph_builder.compile(checkpointer=memory) config = {"configurable": {"thread_id": "abc123"}}
グラフを構成します。この際、会話履歴をメモリ保持するようにしています。詳しくは過去記事「LangGraphの会話履歴をメモリ保持しよう」をご覧ください。
質疑応答
input_messages = [ "クリエーションラインの「スキルアップ支援制度」とは?", "費用の上限はいくらですか?", ] for input_message in input_messages: for step in graph.stream( {"messages": [{"role": "user", "content": input_message}]}, stream_mode="values", config=config, ): step["messages"][-1].pretty_print()
まず「スキルアップ支援制度」とは何かを聞いて、さらに補助される費用の上限を聞いています。
結果
次のようになりました。
(langgraph) % ./61_rag_tutorial_2.py ================================ Human Message ================================= クリエーションラインの「スキルアップ支援制度」とは? ================================== Ai Message ================================== Tool Calls: retrieve (call_uY598LFs8TtliXjHKze5hmq3) Call ID: call_uY598LFs8TtliXjHKze5hmq3 Args: query: クリエーションライン スキルアップ支援制度 ================================= Tool Message ================================= Name: retrieve Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: 【本題に入る前に】スキルアップ支援制度とは何か? Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: 2024年7月からトライアル運用を始めた「スキルアップ支援制度」は、メンバー一人ひとりのスキル向上やチャレンジを後押しする制度として、有効に機能していることが確認できました。また、トライアル期間中に発生したさまざまなケースにも、リーダー陣が会社全体の視点から議論を重ね、柔軟に対応することができました。 Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: 【本題に入る前に】スキルアップ支援制度とは何か? Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: 2024年7月からトライアル運用を始めた「スキルアップ支援制度」は、メンバー一人ひとりのスキル向上やチャレンジを後押しする制度として、有効に機能していることが確認できました。また、トライアル期間中に発生したさまざまなケースにも、リーダー陣が会社全体の視点から議論を重ね、柔軟に対応することができました。 ================================== Ai Message ================================== クリエーションラインの「スキルアップ支援制度」は、メンバー一人ひとりのスキル向上やチャレンジを後押しするための制度です。2024年7月からトライアル運用が始まり、トライアル期間中にはリーダー陣が柔軟に対応しながら運用の効果を確認しています。 ================================ Human Message ================================= 費用の上限はいくらですか? ================================== Ai Message ================================== Tool Calls: retrieve (call_WjV7MEfSJmwDZXYk2WduZlGh) Call ID: call_WjV7MEfSJmwDZXYk2WduZlGh Args: query: クリエーションライン スキルアップ支援制度 費用 上限 ================================= Tool Message ================================= Name: retrieve Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度とは、スキル開発費用の一部を会社が負担することで、メンバーの皆さんのスキル開発を奨励することを主目的とした制度です。1人年間10万円を上限とし、上長であるチームリーダーが承認すればいろいろな人材育成施策で使用可能にしています Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度は「受講費用」にのみ適用 交通費・宿泊費は、別途経費として申請可能 Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度とは、スキル開発費用の一部を会社が負担することで、メンバーの皆さんのスキル開発を奨励することを主目的とした制度です。1人年間10万円を上限とし、上長であるチームリーダーが承認すればいろいろな人材育成施策で使用可能にしています Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度は「受講費用」にのみ適用 交通費・宿泊費は、別途経費として申請可能 ================================== Ai Message ================================== 費用の上限は、1人年間10万円です。
必要な部分だけ抜き出すと、
- Q. クリエーションラインの「スキルアップ支援制度」とは?
- A. クリエーションラインの「スキルアップ支援制度」は、メンバー一人ひとりのスキル向上やチャレンジを後押しするための制度です。2024年7月からトライアル運用が始まり、トライアル期間中にはリーダー陣が柔軟に対応しながら運用の効果を確認しています。
- Q. 費用の上限はいくらですか?
- A. 費用の上限は、1人年間10万円です。
となります。「みんなで育むスキルアップ支援制度・正式導入を決定しました」に記載された内容に基づいて回答していることがわかります。
先にチャンクサイズとオーバーラップをそれぞれ前回の2倍の値にしたと書きました。というのも、元の値だとベクトル検索で得られたチャンクが少なく、費用の情報が得られなかったためです。
なお、Tool Messageがダブって表示されていますが、原因はよくわかりませんでした。
費用の上限はいくらですか? ================================== Ai Message ================================== Tool Calls: retrieve (call_N0fUwPtif0DCWo64fcmn748Y) Call ID: call_N0fUwPtif0DCWo64fcmn748Y Args: query: クリエーションライン スキルアップ支援制度 費用の上限 ================================= Tool Message ================================= Name: retrieve Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度について Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度とは、スキル開発費用の一部を会社が負担することで、メンバーの皆さんのスキル開発を奨励することを主目的とした制度です Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度について Source: {'source': 'https://www.creationline.com/tech-blog/hr/76913'} Content: スキルアップ支援制度とは、スキル開発費用の一部を会社が負担することで、メンバーの皆さんのスキル開発を奨励することを主目的とした制度です ================================== Ai Message ================================== その情報はわかりません。
このように、同じ質問・同じデータソースであっても、チャンク化(チャンキング)の値によって結果が大きく変わってしまうため、この設定のチューニングが重要になります。広く情報をカバーするために大きくするとかえって不要な情報まで増えてしまったり、生成AIに投入するコストもかかってしまうため難しいところです。
全ソースコード
# https://python.langchain.com/docs/tutorials/qa_chat_history/ import os from dotenv import load_dotenv load_dotenv() from langchain.chat_models import init_chat_model llm = init_chat_model( model=os.getenv("AZURE_OPENAI_MODEL"), model_provider="azure_openai", azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) from langchain_openai import AzureOpenAIEmbeddings embeddings = AzureOpenAIEmbeddings( model=os.getenv("AZURE_OPENAI_EMBEDDING_MODEL"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), api_key=os.getenv("AZURE_OPENAI_API_KEY"), openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"), ) from langchain_core.vectorstores import InMemoryVectorStore vector_store = InMemoryVectorStore(embeddings) import bs4 from langchain_community.document_loaders import WebBaseLoader loader = WebBaseLoader( web_paths=("https://www.creationline.com/tech-blog/hr/76913",), bs_kwargs=dict( parse_only=bs4.SoupStrainer( class_=("content_post", "tech_header") ) ), ) docs = loader.load() from langchain_text_splitters import RecursiveCharacterTextSplitter text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( separators=["\n\n", "\n", "。"], chunk_size=200, chunk_overlap=100 ) all_splits = text_splitter.split_documents(docs) _ = vector_store.add_documents(documents=all_splits) #----- from langgraph.graph import MessagesState, StateGraph graph_builder = StateGraph(MessagesState) #----- from langchain_core.tools import tool @tool(response_format="content_and_artifact") def retrieve(query: str): """Retrieve information related to a query.""" retrieved_docs = vector_store.similarity_search(query, k=2) serialized = "\n\n".join( (f"Source: {doc.metadata}\n" f"Content: {doc.page_content}") for doc in retrieved_docs ) return serialized, retrieved_docs #----- from langchain_core.messages import SystemMessage from langgraph.prebuilt import ToolNode def query_or_respond(state: MessagesState): llm_with_tools = llm.bind_tools([retrieve]) response = llm_with_tools.invoke(state["messages"]) return {"messages": [response]} #----- tools = ToolNode([retrieve]) #----- def generate(state: MessagesState): recent_tool_messages = [] for message in reversed(state["messages"]): if message.type == "tool": recent_tool_messages.append(message) else: break tool_messages = recent_tool_messages[::-1] #----- docs_content = "\n\n".join(doc.content for doc in tool_messages) print(docs_content) system_message_content = ( "You are an assistant for question-answering tasks. " "Use the following pieces of retrieved context to answer " "the question. If you don't know the answer, say that you " "don't know. Use three sentences maximum and keep the " "answer concise." "\n\n" f"{docs_content}" ) #----- conversation_messages = [ message for message in state["messages"] if message.type in ("human", "system") or (message.type == "ai" and not message.tool_calls) ] prompt = [SystemMessage(system_message_content)] + conversation_messages #----- response = llm.invoke(prompt) return {"messages": [response]} #----- from langgraph.graph import END from langgraph.prebuilt import ToolNode, tools_condition graph_builder.add_node(query_or_respond) graph_builder.add_node(tools) graph_builder.add_node(generate) graph_builder.set_entry_point("query_or_respond") graph_builder.add_conditional_edges( "query_or_respond", tools_condition, {END: END, "tools": "tools"}, ) graph_builder.add_edge("tools", "generate") graph_builder.add_edge("generate", END) #----- from langgraph.checkpoint.memory import MemorySaver memory = MemorySaver() graph = graph_builder.compile(checkpointer=memory) config = {"configurable": {"thread_id": "abc123"}} #----- input_messages = [ "クリエーションラインの「スキルアップ支援制度」とは?", "費用の上限はいくらですか?", ] for input_message in input_messages: for step in graph.stream( {"messages": [{"role": "user", "content": input_message}]}, stream_mode="values", config=config, ): step["messages"][-1].pretty_print()
まとめ
本稿では LangChain のチュートリアル Build a Retrieval Augmented Generation (RAG) App: Part 2 を Azure OpenAI を使って実施してみた記録です。チュートリアルの内容をそのまま使うのではなく、弊社 Tech blog の記事「みんなで育むスキルアップ支援制度・正式導入を決定しました」に記載された内容に基づいて、「スキルアップ支援制度」について継続的な質疑応答ができるようにしました。
チュートリアルでは固定の質問でしたが、過去記事「LangGraphの会話履歴をメモリ保持しよう」「LangGraphの会話履歴をSQLiteに保持しよう」を参照すれば、質問を自由に入力できるように改良するのは容易だと思います。
また本稿では、チャンク化パラメータによってベクトル検索結果が左右されることも見てみました。回答の確からしさを向上させるためには、設定チューニングとテストが重要になると考えられます。
本ブログでは引き続き生成AIによる会話精度向上や機能について取り上げていきたいと思います。