SurrealDB
SurrealDB Docs Logo

Enter a search query

Navigation

Pydantic AI

Pydantic AI is a Python agent framework designed to help you quickly, confidently, and painlessly build production-grade applications and workflows with Generative AI.

Example

This is a simple RAG application that uses Pydantic AI and embedded SurrealDB. The integration is done by providing the agent with a custom retrieval tool, which takes a search query, executes a SurrealDB vector-search query, and returns the results.

To run the example:

Set up your OpenAI API key:

export OPENAI_API_KEY=your-api-key

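Or store it in a .env file and add --env-file .env to your uv run commands. The commands below assume the .env approach; if you exported the variable instead, omit --env-file .env.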

Build the vector store:

uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb build

Ask the agent a question:

uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb search "How do I register a function as a custom tool for my agent?"

Or use the web UI:

uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb web

Code

from __future__ import annotations as _annotations import asyncio import re import sys import unicodedata from collections.abc import Sequence from contextlib import asynccontextmanager from dataclasses import dataclass from pathlib import Path from typing import TypeVar import httpx import logfire import uvicorn from pydantic import BaseModel, TypeAdapter from surrealdb import ( AsyncEmbeddedSurrealConnection, AsyncHttpSurrealConnection, AsyncSurreal, AsyncWsSurrealConnection, RecordID, Value, ) from typing_extensions import AsyncGenerator from pydantic_ai import Agent, Embedder SurrealConn = ( AsyncWsSurrealConnection | AsyncHttpSurrealConnection | AsyncEmbeddedSurrealConnection ) # 'if-token-present' means nothing will be sent (and the example will work) if you don't have logfire configured logfire.configure(send_to_logfire='if-token-present') logfire.instrument_pydantic_ai() logfire.instrument_surrealdb() THIS_DIR = Path(__file__).parent SURREALDB_NS = 'pydantic_ai_examples' SURREALDB_DB = 'rag_surrealdb' SURREALDB_USER = 'root' SURREALDB_PASS = 'root' embedder = Embedder('openai:text-embedding-3-small') agent = Agent('openai:gpt-5.2') RecordType = TypeVar('RecordType') class RetrievalQueryResult(BaseModel): url: str title: str content: str dist: float async def query( conn: SurrealConn, query_: str, vars_: dict[str, Value], record_type: type[RecordType], ) -> list[RecordType]: result = await conn.query(query_, vars_) result_ta = TypeAdapter(list[record_type]) rows = result_ta.validate_python(result) return rows @agent.tool_plain async def retrieve(search_query: str) -> str: """Retrieve documentation sections based on a search query. Args: search_query: The search query. 
""" with logfire.span( 'create embedding for {search_query=}', search_query=search_query ): result = await embedder.embed_query(search_query) embedding = result.embeddings # Embedder method guarantees there's one item here embedding_vector = embedding[0] # SurrealDB vector search using HNSW index async with database_connect(False) as db: rows = await query( db, """ SELECT url, title, content, vector::distance::knn() AS dist FROM doc_sections WHERE embedding <|8, 40|> $vector ORDER BY dist ASC """, {'vector': list(embedding_vector)}, RetrievalQueryResult, ) return '\n\n'.join( f'# {row.title}\nDocumentation URL:{row.url}\n\n{row.content}' for row in rows ) async def run_agent(question: str): """Entry point to run the agent and perform RAG based question answering.""" logfire.info('Asking "{question}"', question=question) answer = await agent.run(question) print(answer.output) # Web chat UI app = agent.to_web() ####################################################### # The rest of this file is dedicated to preparing the # # search database, and some utilities. 
# ####################################################### # JSON document from # https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992 DOCS_JSON = ( 'https://gist.githubusercontent.com/' 'samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992/raw/' '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json' ) def build_doc_rec_id(url: str) -> RecordID: return RecordID('doc_sections', slugify(url, '_')) async def build_search_db(): """Build the search database.""" async with httpx.AsyncClient() as client: response = await client.get(DOCS_JSON) response.raise_for_status() sections = sections_ta.validate_json(response.content) async with database_connect(True) as db: missing_sections: list[DocsSection] = [] for section in sections: url = section.url() record_id = build_doc_rec_id(url) existing = await db.select(record_id) if existing: logfire.info('Skipping {url=}', url=url) continue missing_sections.append(section) if missing_sections: with logfire.span('create embeddings'): result = await embedder.embed_documents( [section.embedding_content() for section in missing_sections] ) embeddings = result.embeddings for section, embedding_vector in zip( missing_sections, embeddings, strict=True ): await insert_doc_section(db, section, embedding_vector) else: logfire.info('All documents already exist; skipping embedding generation') async def insert_doc_section( db: SurrealConn, section: DocsSection, embedding_vector: Sequence[float], ) -> None: url = section.url() record_id = build_doc_rec_id(url) # Create record with embedding, using record ID directly res = await db.create( record_id, { 'url': url, 'title': section.title, 'content': section.content, 'embedding': list(embedding_vector), }, ) if not isinstance(res, dict): raise ValueError(f'Unexpected response from database: {res}') @dataclass class DocsSection: id: int parent: int | None path: str level: int title: str content: str def url(self) -> str: url_path = re.sub(r'\.md$', '', self.path) return ( 
f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}' ) def embedding_content(self) -> str: return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content)) sections_ta = TypeAdapter(list[DocsSection]) @asynccontextmanager async def database_connect( create_db: bool = False, ) -> AsyncGenerator[SurrealConn, None]: # Running SurrealDB embedded db_path = THIS_DIR / f'.{SURREALDB_DB}' db_url = f'file://{db_path}' requires_auth = False # Running SurrealDB in a separate process, connect with URL # db_url = 'ws://localhost:8000/rpc' # requires_auth = True async with AsyncSurreal(db_url) as db: # Sign in to the database if requires_auth: await db.signin({'username': SURREALDB_USER, 'password': SURREALDB_PASS}) # Set namespace and database await db.use(SURREALDB_NS, SURREALDB_DB) # Initialize schema if creating database if create_db: with logfire.span('create schema'): await db.query(DB_SCHEMA) yield db DB_SCHEMA = """ DEFINE TABLE doc_sections SCHEMALESS; DEFINE FIELD embedding ON doc_sections TYPE array<float>; DEFINE INDEX hnsw_idx_doc_sections ON doc_sections FIELDS embedding HNSW DIMENSION 1536 DIST COSINE TYPE F32; """ def slugify(value: str, separator: str, unicode: bool = False) -> str: """Slugify a string, to make it URL friendly.""" # Taken unchanged from https://github.com/Python-Markdown/markdown/blob/3.7/markdown/extensions/toc.py#L38 if not unicode: # Replace Extended Latin characters with ASCII, i.e. `žlutý` => `zluty` value = unicodedata.normalize('NFKD', value) value = value.encode('ascii', 'ignore').decode('ascii') value = re.sub(r'[^\w\s-]', '', value).strip().lower() return re.sub(rf'[{separator}\s]+', separator, value) if __name__ == '__main__': action = sys.argv[1] if len(sys.argv) > 1 else None if action == 'build': asyncio.run(build_search_db()) elif action == 'search': if len(sys.argv) == 3: q = sys.argv[2] else: q = 'How do I configure logfire to work with FastAPI?' 
asyncio.run(run_agent(q)) elif action == 'web': uvicorn.run(app, host='127.0.0.1', port=7932) else: print( 'uv run --extra examples -m pydantic_ai_examples.rag_surrealdb build|search|web', file=sys.stderr, ) sys.exit(1)