Pydantic AI
Pydantic AI is a Python agent framework designed to help you quickly, confidently, and painlessly build production-grade applications and workflows with Generative AI.
Example
This is a simple RAG application that uses Pydantic AI and embedded SurrealDB. The integration is done by providing the agent with a custom retrieval tool, which takes a search query, executes a SurrealDB vector-search query, and returns the results.
To run the example:
Set up your OpenAI API key:
export OPENAI_API_KEY=your-api-key
Or store it in a .env file and pass --env-file .env to your uv run commands. The commands below already include this flag; drop it if you exported the key instead.
Build the vector store:
uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb build
Ask the agent a question:
uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb search "How do I register a function as a custom tool for my agent?"
Or use the web UI:
uv run --env-file .env -m pydantic_ai_examples.rag_surrealdb web
Code
from __future__ import annotations as _annotations
import asyncio
import re
import sys
import unicodedata
from collections.abc import Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import TypeVar
import httpx
import logfire
import uvicorn
from pydantic import BaseModel, TypeAdapter
from surrealdb import (
AsyncEmbeddedSurrealConnection,
AsyncHttpSurrealConnection,
AsyncSurreal,
AsyncWsSurrealConnection,
RecordID,
Value,
)
from typing_extensions import AsyncGenerator
from pydantic_ai import Agent, Embedder
# Union of the SurrealDB SDK's async connection types (WebSocket, HTTP and
# embedded); used to annotate the connection parameters in this module.
SurrealConn = (
    AsyncWsSurrealConnection | AsyncHttpSurrealConnection | AsyncEmbeddedSurrealConnection
)
# Observability: only ship data to Logfire when a token is configured, and
# instrument both pydantic-ai and the SurrealDB client.
logfire.configure(send_to_logfire='if-token-present')
logfire.instrument_pydantic_ai()
logfire.instrument_surrealdb()

# Location of this example on disk (the database is stored next to it).
THIS_DIR = Path(__file__).parent

# SurrealDB namespace / database selection and credentials.
SURREALDB_NS = 'pydantic_ai_examples'
SURREALDB_DB = 'rag_surrealdb'
SURREALDB_USER, SURREALDB_PASS = 'root', 'root'

# Models: one for embedding documents/queries, one for answering questions.
embedder = Embedder('openai:text-embedding-3-small')
agent = Agent('openai:gpt-5.2')

# Row type produced by the generic `query` helper.
RecordType = TypeVar('RecordType')
class RetrievalQueryResult(BaseModel):
    """One row returned by the vector-search query in ``retrieve``."""

    # Public documentation URL of the section.
    url: str
    # Section heading.
    title: str
    # Section body text.
    content: str
    # Distance to the query embedding, as reported by the KNN search.
    dist: float
async def query(
    conn: SurrealConn,
    query_: str,
    vars_: dict[str, Value],
    record_type: type[RecordType],
) -> list[RecordType]:
    """Run a SurrealQL query and validate the returned rows.

    Args:
        conn: An open SurrealDB connection.
        query_: SurrealQL text; ``$name`` parameters are bound from ``vars_``.
        vars_: Values for the query parameters.
        record_type: Type that every returned row is validated into with pydantic.

    Returns:
        The validated rows, in the order the database returned them.
    """
    raw_rows = await conn.query(query_, vars_)
    return TypeAdapter(list[record_type]).validate_python(raw_rows)
@agent.tool_plain
async def retrieve(search_query: str) -> str:
    """Retrieve documentation sections based on a search query.

    Args:
        search_query: The search query.
    """
    # NOTE: pydantic-ai turns the docstring above into the tool description the
    # model sees, so its wording is intentionally kept unchanged.
    with logfire.span(
        'create embedding for {search_query=}', search_query=search_query
    ):
        embedding_result = await embedder.embed_query(search_query)
    query_vector = embedding_result.embeddings[0]

    async with database_connect(False) as db:
        # `<|8, 40|>` runs an HNSW nearest-neighbour search for 8 results;
        # `vector::distance::knn()` exposes each hit's distance for ordering.
        hits = await query(
            db,
            """
            SELECT url, title, content, vector::distance::knn() AS dist
            FROM doc_sections
            WHERE embedding <|8, 40|> $vector
            ORDER BY dist ASC
            """,
            {'vector': list(query_vector)},
            RetrievalQueryResult,
        )

    # One markdown-ish chunk per section, separated by blank lines.
    chunks = (
        f'# {hit.title}\nDocumentation URL:{hit.url}\n\n{hit.content}' for hit in hits
    )
    return '\n\n'.join(chunks)
async def run_agent(question: str):
    """Ask the agent ``question`` and print its final answer.

    The agent can call the ``retrieve`` tool to look up documentation first.
    """
    logfire.info('Asking "{question}"', question=question)
    run_result = await agent.run(question)
    final_answer = run_result.output
    print(final_answer)
# Web chat UI for the agent; served with uvicorn by the `web` command.
app = agent.to_web()
# JSON dump of the Logfire documentation, already split into sections. The
# gist URL includes a revision hash, so the downloaded content is fixed.
DOCS_JSON = (
    'https://gist.githubusercontent.com/samuelcolvin/'
    '4b5bb9bb163b1122ff17e29e48c10992/raw/'
    '80c5925c42f1442c24963aaf5eb1a324d47afe95/logfire_docs.json'
)
def build_doc_rec_id(url: str) -> RecordID:
    """Return the ``doc_sections`` record ID for a section URL.

    The ID is the URL slugified with ``_`` as the word separator.
    """
    slug = slugify(url, '_')
    return RecordID('doc_sections', slug)
async def build_search_db():
    """Build the search database, embedding only sections not stored yet.

    Downloads the documentation dump from ``DOCS_JSON`` and applies the schema.
    Every section whose record already exists is skipped, so re-running
    ``build`` is incremental. The remaining sections are embedded with a
    single ``embed_documents`` call and inserted one by one.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(DOCS_JSON)
        response.raise_for_status()
    sections = sections_ta.validate_json(response.content)

    async with database_connect(True) as db:
        missing_sections: list[DocsSection] = []
        # Record-ID slugs already queued in this run. Sections with the same
        # path and title produce the same URL, and so the same record ID.
        # Both would pass the `db.select` check, and the second `create` would
        # collide with the first and abort the build. Only the first such
        # section is kept, since the table holds one record per ID.
        queued_slugs: set[str] = set()
        for section in sections:
            url = section.url()
            slug = slugify(url, '_')  # same slug build_doc_rec_id uses
            if slug in queued_slugs:
                logfire.info('Skipping duplicate {url=}', url=url)
                continue
            record_id = build_doc_rec_id(url)
            existing = await db.select(record_id)
            if existing:
                logfire.info('Skipping {url=}', url=url)
                continue
            queued_slugs.add(slug)
            missing_sections.append(section)

        if missing_sections:
            with logfire.span('create embeddings'):
                result = await embedder.embed_documents(
                    [section.embedding_content() for section in missing_sections]
                )
            embeddings = result.embeddings
            # strict=True: fail loudly if the embedder returns a different count.
            for section, embedding_vector in zip(
                missing_sections, embeddings, strict=True
            ):
                await insert_doc_section(db, section, embedding_vector)
        else:
            logfire.info('All documents already exist; skipping embedding generation')
async def insert_doc_section(
    db: SurrealConn,
    section: DocsSection,
    embedding_vector: Sequence[float],
) -> None:
    """Create the database record for one documentation section.

    The record stores the section's URL, title, content and embedding vector.

    Raises:
        ValueError: If ``db.create`` returns anything other than a dict.
    """
    section_url = section.url()
    rec_id = build_doc_rec_id(section_url)
    payload = {
        'url': section_url,
        'title': section.title,
        'content': section.content,
        'embedding': list(embedding_vector),
    }
    created = await db.create(rec_id, payload)
    if isinstance(created, dict):
        return
    raise ValueError(f'Unexpected response from database: {created}')
@dataclass
class DocsSection:
id: int
parent: int | None
path: str
level: int
title: str
content: str
def url(self) -> str:
url_path = re.sub(r'\.md$', '', self.path)
return (
f'https://logfire.pydantic.dev/docs/{url_path}/#{slugify(self.title, "-")}'
)
def embedding_content(self) -> str:
return '\n\n'.join((f'path: {self.path}', f'title: {self.title}', self.content))
# Validator that parses the downloaded docs JSON into DocsSection objects.
sections_ta: TypeAdapter[list[DocsSection]] = TypeAdapter(list[DocsSection])
@asynccontextmanager
async def database_connect(
    create_db: bool = False,
) -> AsyncGenerator[SurrealConn, None]:
    """Open the file-backed SurrealDB database used by this example.

    Args:
        create_db: If true, run ``DB_SCHEMA`` before yielding the connection.

    Yields:
        A connection already switched to ``SURREALDB_NS`` / ``SURREALDB_DB``.
    """
    # The database lives at a dot-prefixed path next to this file.
    storage_path = THIS_DIR / f'.{SURREALDB_DB}'
    # Toggle to sign in with SURREALDB_USER / SURREALDB_PASS; the file-backed
    # connection below is opened without signing in.
    requires_auth = False
    async with AsyncSurreal(f'file://{storage_path}') as conn:
        if requires_auth:
            credentials = {'username': SURREALDB_USER, 'password': SURREALDB_PASS}
            await conn.signin(credentials)
        await conn.use(SURREALDB_NS, SURREALDB_DB)
        if create_db:
            with logfire.span('create schema'):
                await conn.query(DB_SCHEMA)
        yield conn
# Schema for the vector store. It is applied on every `build` run, and since
# SurrealDB 2.0 a plain `DEFINE` of an existing resource is an error, so every
# statement uses `IF NOT EXISTS` to stay idempotent. The HNSW dimension must
# match the embedding model's output size (1536 for text-embedding-3-small).
DB_SCHEMA = """
DEFINE TABLE IF NOT EXISTS doc_sections SCHEMALESS;
DEFINE FIELD IF NOT EXISTS embedding ON doc_sections TYPE array<float>;
DEFINE INDEX IF NOT EXISTS hnsw_idx_doc_sections ON doc_sections
    FIELDS embedding
    HNSW DIMENSION 1536
    DIST COSINE
    TYPE F32;
"""
def slugify(value: str, separator: str, unicode: bool = False) -> str:
    """Slugify a string, to make it URL friendly.

    Removes every character that is not a word character, whitespace or
    ``-``, strips and lower-cases the result, then collapses each run of
    whitespace and/or ``separator`` characters into one ``separator``.

    Args:
        value: Text to slugify.
        separator: Inserted between words (this module uses ``'-'`` for URL
            anchors and ``'_'`` for record IDs).
        unicode: When false (the default), ``value`` is NFKD-normalized and
            reduced to ASCII first, e.g. ``'žlutý'`` becomes ``'zluty'``.

    Returns:
        The slug.
    """
    if not unicode:
        value = unicodedata.normalize('NFKD', value)
        value = value.encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value).strip().lower()
    # Escape the separator: characters such as `^`, `\` or `]` are special
    # inside a regex character class. A callable replacement avoids
    # backslash-escape processing of the separator in the replacement text.
    return re.sub(rf'[{re.escape(separator)}\s]+', lambda _match: separator, value)
def question_from_argv(
    argv: list[str],
    default: str = 'How do I configure logfire to work with FastAPI?',
) -> str:
    """Return the question for the ``search`` action from a command line.

    Everything after ``argv[1]`` (the action) is joined with single spaces,
    so an unquoted multi-word question still works. ``default`` is returned
    when no non-blank question was given.
    """
    question = ' '.join(argv[2:]).strip()
    return question or default


if __name__ == '__main__':
    action = sys.argv[1] if len(sys.argv) > 1 else None
    if action == 'build':
        asyncio.run(build_search_db())
    elif action == 'search':
        asyncio.run(run_agent(question_from_argv(sys.argv)))
    elif action == 'web':
        uvicorn.run(app, host='127.0.0.1', port=7932)
    else:
        print(
            'uv run --extra examples -m pydantic_ai_examples.rag_surrealdb build|search|web',
            file=sys.stderr,
        )
        sys.exit(1)