Initial commit

This commit is contained in:
Aodhan Collins
2025-10-11 21:21:36 +01:00
commit eccd456c59
29 changed files with 5375 additions and 0 deletions

397
main.py Normal file
View File

@@ -0,0 +1,397 @@
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Optional
import uuid
import os
from dotenv import load_dotenv
from openai import AsyncOpenAI
import asyncio
from datetime import datetime
import httpx
# Load environment variables
load_dotenv()
# Initialize FastAPI
app = FastAPI(title="Storyteller RPG API")
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Initialize OpenAI
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
if not os.getenv("OPENAI_API_KEY") and not openrouter_api_key:
print("Warning: Neither OPENAI_API_KEY nor OPENROUTER_API_KEY set. AI features will not work.")
# Models
class Message(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
sender: str # "character" or "storyteller"
content: str
timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())
class Character(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
description: str
personality: str = "" # Additional personality traits
llm_model: str = "gpt-3.5-turbo" # LLM model for this character
conversation_history: List[Message] = [] # Private conversation with storyteller
pending_response: bool = False # Waiting for storyteller response
class StorytellerResponse(BaseModel):
character_id: str
content: str
class GameSession(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
name: str
characters: Dict[str, Character] = {}
current_scene: str = ""
scene_history: List[str] = [] # All scenes narrated
# In-memory storage (replace with database in production)
sessions: Dict[str, GameSession] = {}
# WebSocket connection manager
class ConnectionManager:
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {} # key: "session_character" or "session_storyteller"
async def connect(self, websocket: WebSocket, client_id: str):
await websocket.accept()
self.active_connections[client_id] = websocket
def disconnect(self, client_id: str):
if client_id in self.active_connections:
del self.active_connections[client_id]
async def send_to_client(self, client_id: str, message: dict):
if client_id in self.active_connections:
await self.active_connections[client_id].send_json(message)
manager = ConnectionManager()
# API Endpoints
@app.post("/sessions/")
async def create_session(name: str):
session = GameSession(name=name)
sessions[session.id] = session
return session
@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
return sessions[session_id]
@app.post("/sessions/{session_id}/characters/")
async def add_character(
session_id: str,
name: str,
description: str,
personality: str = "",
llm_model: str = "gpt-3.5-turbo"
):
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
character = Character(
name=name,
description=description,
personality=personality,
llm_model=llm_model
)
session = sessions[session_id]
session.characters[character.id] = character
# Notify storyteller of new character
storyteller_key = f"{session_id}_storyteller"
if storyteller_key in manager.active_connections:
await manager.send_to_client(storyteller_key, {
"type": "character_joined",
"character": {
"id": character.id,
"name": character.name,
"description": character.description,
"llm_model": character.llm_model
}
})
return character
# WebSocket endpoint for character interactions (character view)
@app.websocket("/ws/character/{session_id}/{character_id}")
async def character_websocket(websocket: WebSocket, session_id: str, character_id: str):
if session_id not in sessions or character_id not in sessions[session_id].characters:
await websocket.close(code=1008, reason="Session or character not found")
return
client_key = f"{session_id}_{character_id}"
await manager.connect(websocket, client_key)
try:
# Send conversation history
session = sessions[session_id]
character = session.characters[character_id]
await websocket.send_json({
"type": "history",
"messages": [msg.dict() for msg in character.conversation_history]
})
while True:
data = await websocket.receive_json()
if data.get("type") == "message":
# Character sends message to storyteller
message = Message(sender="character", content=data["content"])
character.conversation_history.append(message)
character.pending_response = True
# Forward to storyteller
storyteller_key = f"{session_id}_storyteller"
if storyteller_key in manager.active_connections:
await manager.send_to_client(storyteller_key, {
"type": "character_message",
"character_id": character_id,
"character_name": character.name,
"message": message.dict()
})
except WebSocketDisconnect:
manager.disconnect(client_key)
# WebSocket endpoint for storyteller
@app.websocket("/ws/storyteller/{session_id}")
async def storyteller_websocket(websocket: WebSocket, session_id: str):
if session_id not in sessions:
await websocket.close(code=1008, reason="Session not found")
return
client_key = f"{session_id}_storyteller"
await manager.connect(websocket, client_key)
try:
# Send all characters and their conversation states
session = sessions[session_id]
await websocket.send_json({
"type": "session_state",
"characters": {
char_id: {
"id": char.id,
"name": char.name,
"description": char.description,
"personality": char.personality,
"conversation_history": [msg.dict() for msg in char.conversation_history],
"pending_response": char.pending_response
}
for char_id, char in session.characters.items()
},
"current_scene": session.current_scene
})
while True:
data = await websocket.receive_json()
if data.get("type") == "respond_to_character":
# Storyteller responds to a specific character
character_id = data["character_id"]
content = data["content"]
if character_id in session.characters:
character = session.characters[character_id]
message = Message(sender="storyteller", content=content)
character.conversation_history.append(message)
character.pending_response = False
# Send to character
char_key = f"{session_id}_{character_id}"
if char_key in manager.active_connections:
await manager.send_to_client(char_key, {
"type": "storyteller_response",
"message": message.dict()
})
elif data.get("type") == "narrate_scene":
# Broadcast scene to all characters
scene = data["content"]
session.current_scene = scene
session.scene_history.append(scene)
# Send to all connected characters
for char_id in session.characters:
char_key = f"{session_id}_{char_id}"
if char_key in manager.active_connections:
await manager.send_to_client(char_key, {
"type": "scene_narration",
"content": scene
})
except WebSocketDisconnect:
manager.disconnect(client_key)
# AI-assisted response generation using character's specific LLM
async def call_llm(model: str, messages: List[dict], temperature: float = 0.8, max_tokens: int = 200) -> str:
"""Call LLM via OpenRouter or OpenAI depending on model"""
# OpenAI models
if model.startswith("gpt-") or model.startswith("o1-"):
if not os.getenv("OPENAI_API_KEY"):
return "OpenAI API key not set."
try:
response = await client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
return response.choices[0].message.content
except Exception as e:
return f"OpenAI error: {str(e)}"
# OpenRouter models (Claude, Llama, Gemini, etc.)
else:
if not openrouter_api_key:
return "OpenRouter API key not set."
try:
async with httpx.AsyncClient() as http_client:
response = await http_client.post(
"https://openrouter.ai/api/v1/chat/completions",
headers={
"Authorization": f"Bearer {openrouter_api_key}",
"HTTP-Referer": "http://localhost:3000",
"X-Title": "Storyteller RPG"
},
json={
"model": model,
"messages": messages,
"temperature": temperature,
"max_tokens": max_tokens
},
timeout=30.0
)
response.raise_for_status()
data = response.json()
return data["choices"][0]["message"]["content"]
except Exception as e:
return f"OpenRouter error: {str(e)}"
@app.post("/sessions/{session_id}/generate_suggestion")
async def generate_suggestion(session_id: str, character_id: str, context: str = ""):
"""Generate AI suggestion for storyteller response to a character using the character's LLM"""
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session = sessions[session_id]
if character_id not in session.characters:
raise HTTPException(status_code=404, detail="Character not found")
character = session.characters[character_id]
# Prepare context for AI suggestion
messages = [
{
"role": "system",
"content": f"You are {character.name} in an RPG. Respond in character. Character description: {character.description}. Personality: {character.personality}. Current scene: {session.current_scene}"
}
]
# Add recent conversation history
for msg in character.conversation_history[-6:]:
role = "assistant" if msg.sender == "character" else "user"
messages.append({"role": role, "content": msg.content})
if context:
messages.append({"role": "user", "content": f"Additional context: {context}"})
try:
suggestion = await call_llm(character.llm_model, messages, temperature=0.8, max_tokens=200)
return {"suggestion": suggestion, "model_used": character.llm_model}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error generating suggestion: {str(e)}")
# Get available LLM models
@app.get("/models")
async def get_available_models():
"""Get list of available LLM models"""
models = {
"openai": [],
"openrouter": []
}
if os.getenv("OPENAI_API_KEY"):
models["openai"] = [
{"id": "gpt-4o", "name": "GPT-4o (Latest)", "provider": "OpenAI"},
{"id": "gpt-4-turbo", "name": "GPT-4 Turbo", "provider": "OpenAI"},
{"id": "gpt-4", "name": "GPT-4", "provider": "OpenAI"},
{"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo (Fast & Cheap)", "provider": "OpenAI"},
]
if openrouter_api_key:
models["openrouter"] = [
{"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
{"id": "anthropic/claude-3-opus", "name": "Claude 3 Opus", "provider": "Anthropic"},
{"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku (Fast)", "provider": "Anthropic"},
{"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
{"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
{"id": "meta-llama/llama-3.1-8b-instruct", "name": "Llama 3.1 8B (Fast)", "provider": "Meta"},
{"id": "mistralai/mistral-large", "name": "Mistral Large", "provider": "Mistral"},
{"id": "cohere/command-r-plus", "name": "Command R+", "provider": "Cohere"},
]
return models
# Get all pending character messages
@app.get("/sessions/{session_id}/pending_messages")
async def get_pending_messages(session_id: str):
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session = sessions[session_id]
pending = {}
for char_id, char in session.characters.items():
if char.pending_response:
last_message = char.conversation_history[-1] if char.conversation_history else None
if last_message and last_message.sender == "character":
pending[char_id] = {
"character_name": char.name,
"message": last_message.dict()
}
return pending
# Get character conversation history (for storyteller)
@app.get("/sessions/{session_id}/characters/{character_id}/conversation")
async def get_character_conversation(session_id: str, character_id: str):
if session_id not in sessions:
raise HTTPException(status_code=404, detail="Session not found")
session = sessions[session_id]
if character_id not in session.characters:
raise HTTPException(status_code=404, detail="Character not found")
character = session.characters[character_id]
return {
"character": {
"id": character.id,
"name": character.name,
"description": character.description,
"personality": character.personality
},
"conversation": [msg.dict() for msg in character.conversation_history],
"pending_response": character.pending_response
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)