Custom LLM Provider Guide¶
This guide shows how to integrate custom language model providers with bruno-core.
Overview¶
Bruno-core uses the LLMInterface to abstract language model interactions. You can implement this interface to integrate any LLM provider:
- Cloud APIs: OpenAI, Anthropic Claude, Google PaLM, Cohere
- Self-hosted: Ollama, LM Studio, text-generation-webui
- Custom models: Your own fine-tuned models
- Hybrid: Multiple providers with fallback logic
LLMInterface Contract¶
from bruno_core.interfaces import LLMInterface
from bruno_core.models import Message
from typing import AsyncIterator
class CustomLLM(LLMInterface):
async def generate(self, messages: list[Message], **kwargs) -> str:
"""Generate a complete response."""
raise NotImplementedError
async def stream(self, messages: list[Message], **kwargs) -> AsyncIterator[str]:
"""Stream response tokens."""
raise NotImplementedError
def get_token_count(self, text: str) -> int:
"""Estimate token count."""
raise NotImplementedError
def list_models(self) -> list[str]:
"""List available models."""
raise NotImplementedError
Basic Implementation¶
Simple LLM Provider¶
import asyncio
from bruno_core.interfaces import LLMInterface
from bruno_core.models import Message, MessageRole
class SimpleLLM(LLMInterface):
"""Minimal LLM implementation."""
def __init__(self, api_key: str, model: str = "gpt-4"):
self.api_key = api_key
self.model = model
async def generate(self, messages: list[Message], **kwargs) -> str:
# Format messages for your API
formatted = self._format_messages(messages)
# Make API call
response = await self._call_api(formatted, **kwargs)
return response
async def stream(self, messages: list[Message], **kwargs):
formatted = self._format_messages(messages)
async for token in self._stream_api(formatted, **kwargs):
yield token
def get_token_count(self, text: str) -> int:
# Simple approximation
return len(text.split())
def list_models(self) -> list[str]:
return ["gpt-4", "gpt-3.5-turbo"]
def _format_messages(self, messages: list[Message]) -> list[dict]:
return [
{"role": msg.role.value, "content": msg.content}
for msg in messages
]
    async def _call_api(self, messages: list[dict], **kwargs) -> str:
        # Your API call logic here
        raise NotImplementedError
async def _stream_api(self, messages: list[dict], **kwargs):
# Your streaming API call logic here
yield "token"
Real-World Implementations¶
OpenAI Provider¶
from openai import AsyncOpenAI
from bruno_core.interfaces import LLMInterface
from bruno_core.models import Message
class OpenAILLM(LLMInterface):
"""OpenAI GPT integration."""
def __init__(self, api_key: str, model: str = "gpt-4"):
self.client = AsyncOpenAI(api_key=api_key)
self.model = model
async def generate(self, messages: list[Message], **kwargs) -> str:
response = await self.client.chat.completions.create(
model=self.model,
messages=[
{"role": msg.role.value, "content": msg.content}
for msg in messages
],
**kwargs
)
return response.choices[0].message.content
async def stream(self, messages: list[Message], **kwargs):
stream = await self.client.chat.completions.create(
model=self.model,
messages=[
{"role": msg.role.value, "content": msg.content}
for msg in messages
],
stream=True,
**kwargs
)
        async for chunk in stream:
            # Guard against chunks with an empty choices list
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    def get_token_count(self, text: str) -> int:
        import tiktoken
        try:
            encoding = tiktoken.encoding_for_model(self.model)
        except KeyError:
            # tiktoken raises KeyError for model names it does not know
            encoding = tiktoken.get_encoding("cl100k_base")
        return len(encoding.encode(text))
def list_models(self) -> list[str]:
return ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"]
Anthropic Claude Provider¶
from anthropic import AsyncAnthropic
from bruno_core.interfaces import LLMInterface
from bruno_core.models import Message, MessageRole
class ClaudeLLM(LLMInterface):
"""Anthropic Claude integration."""
def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
self.client = AsyncAnthropic(api_key=api_key)
self.model = model
async def generate(self, messages: list[Message], **kwargs) -> str:
# Separate system message if present
system = None
claude_messages = []
for msg in messages:
if msg.role == MessageRole.SYSTEM:
system = msg.content
else:
claude_messages.append({
"role": msg.role.value,
"content": msg.content
})
        # Pop max_tokens so it is not passed twice via **kwargs
        request = {
            "model": self.model,
            "messages": claude_messages,
            "max_tokens": kwargs.pop("max_tokens", 1024),
            **kwargs,
        }
        if system is not None:
            # Include the system prompt only when one was provided
            request["system"] = system
        response = await self.client.messages.create(**request)
return response.content[0].text
async def stream(self, messages: list[Message], **kwargs):
system = None
claude_messages = []
for msg in messages:
if msg.role == MessageRole.SYSTEM:
system = msg.content
else:
claude_messages.append({
"role": msg.role.value,
"content": msg.content
})
        # Same handling as generate(): avoid a duplicate max_tokens keyword
        # and only pass the system prompt when one was provided
        request = {
            "model": self.model,
            "messages": claude_messages,
            "max_tokens": kwargs.pop("max_tokens", 1024),
            **kwargs,
        }
        if system is not None:
            request["system"] = system
        async with self.client.messages.stream(**request) as stream:
async for text in stream.text_stream:
yield text
def get_token_count(self, text: str) -> int:
# Claude doesn't have a public tokenizer yet
# Use approximation: ~1 token per 4 characters
return len(text) // 4
def list_models(self) -> list[str]:
return [
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307"
]
Ollama Local Provider¶
import aiohttp
from bruno_core.interfaces import LLMInterface
from bruno_core.models import Message
class OllamaLLM(LLMInterface):
"""Ollama local LLM integration."""
def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama2"):
self.base_url = base_url
self.model = model
async def generate(self, messages: list[Message], **kwargs) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/api/chat",
json={
"model": self.model,
"messages": [
{"role": msg.role.value, "content": msg.content}
for msg in messages
],
"stream": False,
**kwargs
}
) as response:
data = await response.json()
return data["message"]["content"]
async def stream(self, messages: list[Message], **kwargs):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/api/chat",
json={
"model": self.model,
"messages": [
{"role": msg.role.value, "content": msg.content}
for msg in messages
],
"stream": True,
**kwargs
}
) as response:
                import json
                # Ollama streams newline-delimited JSON objects
                async for line in response.content:
                    if not line.strip():
                        continue
                    data = json.loads(line)
                    if data.get("message", {}).get("content"):
                        yield data["message"]["content"]
def get_token_count(self, text: str) -> int:
return len(text.split())
    def list_models(self) -> list[str]:
        # Synchronous call; list_models is not async in the interface
        import requests
        response = requests.get(f"{self.base_url}/api/tags")
        response.raise_for_status()
        return [model["name"] for model in response.json()["models"]]
Advanced Features¶
Rate Limiting¶
import time
import asyncio
from bruno_core.interfaces import LLMInterface
class RateLimitedLLM(LLMInterface):
"""LLM with rate limiting."""
    def __init__(self, base_llm: LLMInterface, requests_per_minute: int = 60):
        self.base_llm = base_llm
        self.rpm = requests_per_minute
        self.last_request = 0.0
        self._lock = asyncio.Lock()
    async def _wait_for_rate_limit(self):
        # Serialize checks so concurrent callers cannot bypass the interval
        async with self._lock:
            min_interval = 60.0 / self.rpm
            elapsed = time.time() - self.last_request
            if elapsed < min_interval:
                await asyncio.sleep(min_interval - elapsed)
            self.last_request = time.time()
async def generate(self, messages, **kwargs):
await self._wait_for_rate_limit()
return await self.base_llm.generate(messages, **kwargs)
async def stream(self, messages, **kwargs):
await self._wait_for_rate_limit()
async for token in self.base_llm.stream(messages, **kwargs):
yield token
def get_token_count(self, text: str) -> int:
return self.base_llm.get_token_count(text)
def list_models(self) -> list[str]:
return self.base_llm.list_models()
Retry Logic¶
import asyncio
from bruno_core.interfaces import LLMInterface
class RetryLLM(LLMInterface):
"""LLM with automatic retries."""
def __init__(self, base_llm: LLMInterface, max_retries: int = 3):
self.base_llm = base_llm
self.max_retries = max_retries
async def generate(self, messages, **kwargs):
last_error = None
for attempt in range(self.max_retries):
try:
return await self.base_llm.generate(messages, **kwargs)
except Exception as e:
last_error = e
if attempt < self.max_retries - 1:
await asyncio.sleep(2 ** attempt) # Exponential backoff
continue
raise Exception(f"Failed after {self.max_retries} attempts: {last_error}")
async def stream(self, messages, **kwargs):
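        # Streaming is not retried; a mid-stream failure surfaces to the caller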
async for token in self.base_llm.stream(messages, **kwargs):
yield token
def get_token_count(self, text: str) -> int:
return self.base_llm.get_token_count(text)
def list_models(self) -> list[str]:
return self.base_llm.list_models()
Multi-Provider Fallback¶
from bruno_core.interfaces import LLMInterface
from typing import List
class FallbackLLM(LLMInterface):
"""Try multiple providers in order."""
def __init__(self, providers: List[LLMInterface]):
self.providers = providers
async def generate(self, messages, **kwargs):
last_error = None
for provider in self.providers:
try:
return await provider.generate(messages, **kwargs)
except Exception as e:
last_error = e
continue
raise Exception(f"All providers failed. Last error: {last_error}")
    async def stream(self, messages, **kwargs):
        last_error = None
        for provider in self.providers:
            try:
                # Note: tokens already yielded are not rolled back if a
                # provider fails mid-stream before the next one is tried
                async for token in provider.stream(messages, **kwargs):
                    yield token
                return
            except Exception as e:
                last_error = e
                continue
        raise Exception(f"All providers failed. Last error: {last_error}")
def get_token_count(self, text: str) -> int:
return self.providers[0].get_token_count(text)
def list_models(self) -> list[str]:
models = []
for provider in self.providers:
models.extend(provider.list_models())
return list(set(models))
Usage¶
Basic Usage¶
from bruno_core.base import BaseAssistant
# Create LLM
llm = OpenAILLM(api_key="your-key")
# Create assistant (memory is an existing memory implementation, not shown here)
assistant = BaseAssistant(llm=llm, memory=memory)
await assistant.initialize()
With Rate Limiting¶
base_llm = OpenAILLM(api_key="your-key")
llm = RateLimitedLLM(base_llm, requests_per_minute=30)
assistant = BaseAssistant(llm=llm, memory=memory)
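Because each wrapper above implements LLMInterface itself, wrappers can be stacked. The snippet below is an illustrative sketch; the ordering shown is just one reasonable choice, not a bruno-core requirement:
base_llm = OpenAILLM(api_key="your-key")
# Inner wrapper throttles requests, outer wrapper retries failures
llm = RetryLLM(RateLimitedLLM(base_llm, requests_per_minute=30), max_retries=3)
assistant = BaseAssistant(llm=llm, memory=memory)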
With Fallback¶
primary = OpenAILLM(api_key="openai-key")
fallback = ClaudeLLM(api_key="claude-key")
local = OllamaLLM()
llm = FallbackLLM([primary, fallback, local])
assistant = BaseAssistant(llm=llm, memory=memory)
Testing¶
import pytest
from bruno_core.models import Message, MessageRole
@pytest.mark.asyncio
async def test_custom_llm():
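    # Substitute your concrete provider class (e.g. OpenAILLM) for CustomLLM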
llm = CustomLLM(api_key="test-key")
messages = [
Message(role=MessageRole.USER, content="Hello")
]
response = await llm.generate(messages)
assert isinstance(response, str)
assert len(response) > 0
# Test streaming
tokens = []
async for token in llm.stream(messages):
tokens.append(token)
assert len(tokens) > 0
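To run tests without network access, a small in-memory fake that satisfies LLMInterface can stand in for a real provider. The FakeLLM class below is a test-helper sketch (not part of bruno-core), reusing the pytest and Message imports from the test above:
from bruno_core.interfaces import LLMInterface

class FakeLLM(LLMInterface):
    """In-memory stand-in that returns a canned response."""
    def __init__(self, canned_response: str = "Hello from FakeLLM"):
        self.canned_response = canned_response
    async def generate(self, messages, **kwargs) -> str:
        return self.canned_response
    async def stream(self, messages, **kwargs):
        for word in self.canned_response.split():
            yield word + " "
    def get_token_count(self, text: str) -> int:
        return len(text.split())
    def list_models(self) -> list[str]:
        return ["fake-model"]

@pytest.mark.asyncio
async def test_with_fake_llm():
    llm = FakeLLM()
    messages = [Message(role=MessageRole.USER, content="Hello")]
    assert await llm.generate(messages) == "Hello from FakeLLM"
    tokens = [t async for t in llm.stream(messages)]
    assert "".join(tokens).strip() == "Hello from FakeLLM"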
Best Practices¶
- Error Handling: Always handle API errors gracefully
- Rate Limiting: Implement rate limiting for cloud APIs
- Retries: Add retry logic with exponential backoff
- Timeouts: Set reasonable timeouts for API calls (see the wrapper sketch after this list)
- Token Counting: Use provider-specific tokenizers when available
- Streaming: Implement streaming for better UX
- Context Management: Handle context window limits
- Cost Tracking: Log token usage for cost monitoring (also shown in the sketch below)
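Several of these practices fit naturally into one thin wrapper around any LLMInterface implementation. The sketch below is illustrative only (InstrumentedLLM and its timeout_seconds parameter are not part of bruno-core): it applies a per-request timeout with asyncio.wait_for and logs approximate token usage per call for cost tracking, following the same wrapper pattern as RateLimitedLLM and RetryLLM above.
import asyncio
import logging
from bruno_core.interfaces import LLMInterface

logger = logging.getLogger(__name__)

class InstrumentedLLM(LLMInterface):
    """Wrapper adding a per-request timeout and rough usage logging."""
    def __init__(self, base_llm: LLMInterface, timeout_seconds: float = 30.0):
        self.base_llm = base_llm
        self.timeout_seconds = timeout_seconds
    async def generate(self, messages, **kwargs) -> str:
        prompt_tokens = sum(self.get_token_count(m.content) for m in messages)
        # Cancel the request if the provider takes too long
        response = await asyncio.wait_for(
            self.base_llm.generate(messages, **kwargs),
            timeout=self.timeout_seconds,
        )
        logger.info(
            "LLM call: ~%d prompt tokens, ~%d completion tokens",
            prompt_tokens, self.get_token_count(response),
        )
        return response
    async def stream(self, messages, **kwargs):
        # Streaming is passed through untimed; wrapping an async generator
        # in a timeout needs per-chunk handling and is omitted here
        async for token in self.base_llm.stream(messages, **kwargs):
            yield token
    def get_token_count(self, text: str) -> int:
        return self.base_llm.get_token_count(text)
    def list_models(self) -> list[str]:
        return self.base_llm.list_models()
Usage mirrors the other wrappers, for example InstrumentedLLM(OpenAILLM(api_key="your-key"), timeout_seconds=30).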