storyteller/main.py
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import uuid
import os
from dotenv import load_dotenv
from openai import AsyncOpenAI
import asyncio
from datetime import datetime
import httpx
# Load environment variables
load_dotenv()
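# A minimal .env next to this file might look like the following (values are
# placeholders; only the variable names below are actually read by this module):
#
#   OPENAI_API_KEY=sk-...
#   OPENROUTER_API_KEY=sk-or-...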
# Initialize FastAPI
app = FastAPI(title="Storyteller RPG API")
# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Initialize LLM credentials. The OpenAI client is only constructed when a key is
# present, since AsyncOpenAI raises at init time if no API key can be resolved.
openai_api_key = os.getenv("OPENAI_API_KEY")
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
client = AsyncOpenAI(api_key=openai_api_key) if openai_api_key else None
if not openai_api_key and not openrouter_api_key:
    print("Warning: Neither OPENAI_API_KEY nor OPENROUTER_API_KEY is set. AI features will not work.")
# Models
class Message(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    sender: str  # "character" or "storyteller"
    content: str
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())


class Character(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: str
    personality: str = ""  # Additional personality traits
    llm_model: str = "gpt-3.5-turbo"  # LLM model for this character
    conversation_history: List[Message] = []  # Private conversation with storyteller
    pending_response: bool = False  # Waiting for storyteller response


class StorytellerResponse(BaseModel):
    character_id: str
    content: str


class GameSession(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    characters: Dict[str, Character] = {}
    current_scene: str = ""
    scene_history: List[str] = []  # All scenes narrated
# In-memory storage (replace with database in production)
sessions: Dict[str, GameSession] = {}
# WebSocket connection manager
class ConnectionManager:
    def __init__(self):
        self.active_connections: Dict[str, WebSocket] = {}  # key: "session_character" or "session_storyteller"

    async def connect(self, websocket: WebSocket, client_id: str):
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str):
        if client_id in self.active_connections:
            del self.active_connections[client_id]

    async def send_to_client(self, client_id: str, message: dict):
        if client_id in self.active_connections:
            await self.active_connections[client_id].send_json(message)
manager = ConnectionManager()
# API Endpoints
@app.post("/sessions/")
async def create_session(name: str):
    session = GameSession(name=name)
    sessions[session.id] = session
    return session


@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    return sessions[session_id]
@app.post("/sessions/{session_id}/characters/")
async def add_character(
    session_id: str,
    name: str,
    description: str,
    personality: str = "",
    llm_model: str = "gpt-3.5-turbo"
):
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    character = Character(
        name=name,
        description=description,
        personality=personality,
        llm_model=llm_model
    )
    session = sessions[session_id]
    session.characters[character.id] = character
    # Notify storyteller of new character
    storyteller_key = f"{session_id}_storyteller"
    if storyteller_key in manager.active_connections:
        await manager.send_to_client(storyteller_key, {
            "type": "character_joined",
            "character": {
                "id": character.id,
                "name": character.name,
                "description": character.description,
                "llm_model": character.llm_model
            }
        })
    return character
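# Example (illustrative sketch, not part of this module): both endpoints above take
# plain query parameters, so a client could bootstrap a session with httpx, assuming
# the server is running locally on port 8000:
#
#   import httpx
#
#   session = httpx.post("http://localhost:8000/sessions/", params={"name": "Demo"}).json()
#   character = httpx.post(
#       f"http://localhost:8000/sessions/{session['id']}/characters/",
#       params={"name": "Aria", "description": "A wandering bard", "llm_model": "gpt-4o"},
#   ).json()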
# WebSocket endpoint for character interactions (character view)
@app.websocket("/ws/character/{session_id}/{character_id}")
async def character_websocket(websocket: WebSocket, session_id: str, character_id: str):
    if session_id not in sessions or character_id not in sessions[session_id].characters:
        await websocket.close(code=1008, reason="Session or character not found")
        return
    client_key = f"{session_id}_{character_id}"
    await manager.connect(websocket, client_key)
    try:
        # Send conversation history
        session = sessions[session_id]
        character = session.characters[character_id]
        await websocket.send_json({
            "type": "history",
            "messages": [msg.dict() for msg in character.conversation_history]
        })
        while True:
            data = await websocket.receive_json()
            if data.get("type") == "message":
                # Character sends message to storyteller
                message = Message(sender="character", content=data["content"])
                character.conversation_history.append(message)
                character.pending_response = True
                # Forward to storyteller
                storyteller_key = f"{session_id}_storyteller"
                if storyteller_key in manager.active_connections:
                    await manager.send_to_client(storyteller_key, {
                        "type": "character_message",
                        "character_id": character_id,
                        "character_name": character.name,
                        "message": message.dict()
                    })
    except WebSocketDisconnect:
        manager.disconnect(client_key)
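# Example (illustrative sketch only): a character client could talk to this socket
# with the `websockets` package, assuming a local server on port 8000:
#
#   import asyncio, json, websockets
#
#   async def play(session_id: str, character_id: str):
#       uri = f"ws://localhost:8000/ws/character/{session_id}/{character_id}"
#       async with websockets.connect(uri) as ws:
#           print(json.loads(await ws.recv()))  # initial "history" payload
#           await ws.send(json.dumps({"type": "message", "content": "I open the door."}))
#           print(json.loads(await ws.recv()))  # storyteller response or scene narration
#
#   asyncio.run(play("<session-id>", "<character-id>"))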
# WebSocket endpoint for storyteller
@app.websocket("/ws/storyteller/{session_id}")
async def storyteller_websocket(websocket: WebSocket, session_id: str):
    if session_id not in sessions:
        await websocket.close(code=1008, reason="Session not found")
        return
    client_key = f"{session_id}_storyteller"
    await manager.connect(websocket, client_key)
    try:
        # Send all characters and their conversation states
        session = sessions[session_id]
        await websocket.send_json({
            "type": "session_state",
            "characters": {
                char_id: {
                    "id": char.id,
                    "name": char.name,
                    "description": char.description,
                    "personality": char.personality,
                    "conversation_history": [msg.dict() for msg in char.conversation_history],
                    "pending_response": char.pending_response
                }
                for char_id, char in session.characters.items()
            },
            "current_scene": session.current_scene
        })
        while True:
            data = await websocket.receive_json()
            if data.get("type") == "respond_to_character":
                # Storyteller responds to a specific character
                character_id = data["character_id"]
                content = data["content"]
                if character_id in session.characters:
                    character = session.characters[character_id]
                    message = Message(sender="storyteller", content=content)
                    character.conversation_history.append(message)
                    character.pending_response = False
                    # Send to character
                    char_key = f"{session_id}_{character_id}"
                    if char_key in manager.active_connections:
                        await manager.send_to_client(char_key, {
                            "type": "storyteller_response",
                            "message": message.dict()
                        })
            elif data.get("type") == "narrate_scene":
                # Broadcast scene to all characters
                scene = data["content"]
                session.current_scene = scene
                session.scene_history.append(scene)
                # Send to all connected characters
                for char_id in session.characters:
                    char_key = f"{session_id}_{char_id}"
                    if char_key in manager.active_connections:
                        await manager.send_to_client(char_key, {
                            "type": "scene_narration",
                            "content": scene
                        })
    except WebSocketDisconnect:
        manager.disconnect(client_key)
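# Example (illustrative sketch only): a storyteller client could drive the session
# over this socket with the `websockets` package, assuming a local server:
#
#   import asyncio, json, websockets
#
#   async def run_storyteller(session_id: str):
#       uri = f"ws://localhost:8000/ws/storyteller/{session_id}"
#       async with websockets.connect(uri) as ws:
#           state = json.loads(await ws.recv())  # "session_state" snapshot
#           await ws.send(json.dumps({"type": "narrate_scene",
#                                     "content": "Rain hammers the tavern roof."}))
#           # Reply privately to one character (character_id keys come from the snapshot)
#           await ws.send(json.dumps({"type": "respond_to_character",
#                                     "character_id": next(iter(state["characters"])),
#                                     "content": "The innkeeper eyes you warily."}))
#
#   asyncio.run(run_storyteller("<session-id>"))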
# AI-assisted response generation using character's specific LLM
async def call_llm(model: str, messages: List[dict], temperature: float = 0.8, max_tokens: int = 200) -> str:
    """Call LLM via OpenRouter or OpenAI depending on model"""
    # OpenAI models
    if model.startswith("gpt-") or model.startswith("o1-"):
        if client is None:
            return "OpenAI API key not set."
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"OpenAI error: {str(e)}"
    # OpenRouter models (Claude, Llama, Gemini, etc.)
    else:
        if not openrouter_api_key:
            return "OpenRouter API key not set."
        try:
            async with httpx.AsyncClient() as http_client:
                response = await http_client.post(
                    "https://openrouter.ai/api/v1/chat/completions",
                    headers={
                        "Authorization": f"Bearer {openrouter_api_key}",
                        "HTTP-Referer": "http://localhost:3000",
                        "X-Title": "Storyteller RPG"
                    },
                    json={
                        "model": model,
                        "messages": messages,
                        "temperature": temperature,
                        "max_tokens": max_tokens
                    },
                    timeout=30.0
                )
                response.raise_for_status()
                data = response.json()
                return data["choices"][0]["message"]["content"]
        except Exception as e:
            return f"OpenRouter error: {str(e)}"
@app.post("/sessions/{session_id}/generate_suggestion")
async def generate_suggestion(session_id: str, character_id: str, context: str = ""):
    """Generate AI suggestion for storyteller response to a character using the character's LLM"""
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    session = sessions[session_id]
    if character_id not in session.characters:
        raise HTTPException(status_code=404, detail="Character not found")
    character = session.characters[character_id]
    # Prepare context for AI suggestion
    messages = [
        {
            "role": "system",
            "content": f"You are {character.name} in an RPG. Respond in character. Character description: {character.description}. Personality: {character.personality}. Current scene: {session.current_scene}"
        }
    ]
    # Add recent conversation history
    for msg in character.conversation_history[-6:]:
        role = "assistant" if msg.sender == "character" else "user"
        messages.append({"role": role, "content": msg.content})
    if context:
        messages.append({"role": "user", "content": f"Additional context: {context}"})
    try:
        suggestion = await call_llm(character.llm_model, messages, temperature=0.8, max_tokens=200)
        return {"suggestion": suggestion, "model_used": character.llm_model}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating suggestion: {str(e)}")
# Get available LLM models
@app.get("/models")
async def get_available_models():
    """Get list of available LLM models"""
    models = {
        "openai": [],
        "openrouter": []
    }
    if os.getenv("OPENAI_API_KEY"):
        models["openai"] = [
            {"id": "gpt-4o", "name": "GPT-4o (Latest)", "provider": "OpenAI"},
            {"id": "gpt-4-turbo", "name": "GPT-4 Turbo", "provider": "OpenAI"},
            {"id": "gpt-4", "name": "GPT-4", "provider": "OpenAI"},
            {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo (Fast & Cheap)", "provider": "OpenAI"},
        ]
    if openrouter_api_key:
        models["openrouter"] = [
            {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
            {"id": "anthropic/claude-3-opus", "name": "Claude 3 Opus", "provider": "Anthropic"},
            {"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku (Fast)", "provider": "Anthropic"},
            {"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
            {"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
            {"id": "meta-llama/llama-3.1-8b-instruct", "name": "Llama 3.1 8B (Fast)", "provider": "Meta"},
            {"id": "mistralai/mistral-large", "name": "Mistral Large", "provider": "Mistral"},
            {"id": "cohere/command-r-plus", "name": "Command R+", "provider": "Cohere"},
        ]
    return models
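# Example (sketch): querying this endpoint from a shell, assuming the default port:
#
#   curl http://localhost:8000/models
#
# returns {"openai": [...], "openrouter": [...]}, listing only the providers whose
# API keys are configured.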
# Get all pending character messages
@app.get("/sessions/{session_id}/pending_messages")
async def get_pending_messages(session_id: str):
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    session = sessions[session_id]
    pending = {}
    for char_id, char in session.characters.items():
        if char.pending_response:
            last_message = char.conversation_history[-1] if char.conversation_history else None
            if last_message and last_message.sender == "character":
                pending[char_id] = {
                    "character_name": char.name,
                    "message": last_message.dict()
                }
    return pending
# Get character conversation history (for storyteller)
@app.get("/sessions/{session_id}/characters/{character_id}/conversation")
async def get_character_conversation(session_id: str, character_id: str):
    if session_id not in sessions:
        raise HTTPException(status_code=404, detail="Session not found")
    session = sessions[session_id]
    if character_id not in session.characters:
        raise HTTPException(status_code=404, detail="Character not found")
    character = session.characters[character_id]
    return {
        "character": {
            "id": character.id,
            "name": character.name,
            "description": character.description,
            "personality": character.personality
        },
        "conversation": [msg.dict() for msg in character.conversation_history],
        "pending_response": character.pending_response
    }
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
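# During development the app can also be started with auto-reload from the project
# root (assuming this file lives at storyteller/main.py):
#
#   uvicorn storyteller.main:app --reload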