398 lines
15 KiB
Python
398 lines
15 KiB
Python
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field
|
|
from typing import Dict, List, Optional
|
|
import uuid
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from openai import AsyncOpenAI
|
|
import asyncio
|
|
from datetime import datetime
|
|
import httpx
|
|
|
|
# Configuration such as API keys may live in a local .env file; pull it into
# os.environ before any of the os.getenv() lookups further down run.
load_dotenv()

# The ASGI application that every middleware and route in this module attaches to.
app = FastAPI(title="Storyteller RPG API")
|
|
|
|
# CORS middleware.
# Allowed origins are read from CORS_ALLOW_ORIGINS (comma-separated list, e.g.
# "http://localhost:3000,https://rpg.example.com"). When unset or empty the
# previous behaviour is kept: every origin is allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended and set CORS_ALLOW_ORIGINS to the real
# frontend origin(s) outside local development.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Initialize LLM provider credentials / clients.
# AsyncOpenAI refuses to construct without an API key, so building it
# unconditionally would crash the server at import time in an OpenRouter-only
# setup. Only create it when OPENAI_API_KEY is set; call_llm() checks
# OPENAI_API_KEY before it touches `client`, so None is never dereferenced there.
openai_api_key = os.getenv("OPENAI_API_KEY")
client: Optional[AsyncOpenAI] = AsyncOpenAI(api_key=openai_api_key) if openai_api_key else None

openrouter_api_key = os.getenv("OPENROUTER_API_KEY")

if not openai_api_key and not openrouter_api_key:
    print("Warning: Neither OPENAI_API_KEY nor OPENROUTER_API_KEY set. AI features will not work.")
|
|
|
|
# Models
class Message(BaseModel):
    """One line of a private character <-> storyteller conversation."""

    # Random UUID4 string, generated when the message is constructed.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    # Author of the line: "character" or "storyteller".
    sender: str
    # The message text itself.
    content: str
    # Creation time from datetime.now().isoformat(): server-local, no UTC offset.
    timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())
|
|
|
|
class Character(BaseModel):
    """A player character in a session, with its own private storyteller thread."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    description: str
    # Extra free-text personality traits; may be empty.
    personality: str = ""
    # Model id used when generating AI suggestions for this character.
    llm_model: str = "gpt-3.5-turbo"
    # Private conversation with the storyteller, oldest message first.
    conversation_history: List[Message] = Field(default_factory=list)
    # Set when the character sends a message; cleared when the storyteller replies.
    pending_response: bool = False
|
|
|
|
class StorytellerResponse(BaseModel):
    """A storyteller reply addressed to a single character.

    NOTE(review): not referenced anywhere else in this file; confirm it is still needed.
    """

    # Id of the character being answered.
    character_id: str
    # Reply text.
    content: str
|
|
|
|
class GameSession(BaseModel):
    """A single game: its characters, the scene now in play, and past scenes."""

    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    name: str
    # Characters in this session, keyed by character id.
    characters: Dict[str, Character] = Field(default_factory=dict)
    # Text of the most recently narrated scene ("" until the first narration).
    current_scene: str = ""
    # Every scene narrated so far, oldest first.
    scene_history: List[str] = Field(default_factory=list)
|
|
|
|
# Process-local, in-memory store of every session, keyed by session id.
# Contents are lost on restart and are not shared between worker processes;
# replace with a database in production.
sessions: Dict[str, GameSession] = dict()
|
|
|
|
# WebSocket connection manager
class ConnectionManager:
    """Registry of live WebSockets, one per client id.

    Client ids are ``"{session_id}_{character_id}"`` for characters and
    ``"{session_id}_storyteller"`` for a session's storyteller.
    """

    def __init__(self):
        # client id -> currently registered socket
        self.active_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, client_id: str):
        """Accept the handshake and register *websocket* under *client_id*.

        A socket previously registered under the same id is replaced (not closed).
        """
        await websocket.accept()
        self.active_connections[client_id] = websocket

    def disconnect(self, client_id: str, websocket: Optional[WebSocket] = None):
        """Unregister the socket stored under *client_id*, if there is one.

        When *websocket* is given, the entry is removed only if it still refers to
        that exact socket, so cleanup for an old socket cannot unregister a newer
        connection that reused the same id.
        """
        current = self.active_connections.get(client_id)
        if current is None:
            return
        if websocket is not None and current is not websocket:
            return
        del self.active_connections[client_id]

    async def send_to_client(self, client_id: str, message: dict) -> bool:
        """Best-effort JSON send to *client_id*.

        Returns True if the message was handed to the socket, False if nothing is
        registered under *client_id* or the send failed. A failed send unregisters
        the dead socket instead of raising: the caller is usually a different
        client's handler, which must not be torn down because its peer vanished.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return False
        try:
            await websocket.send_json(message)
        except (WebSocketDisconnect, RuntimeError, OSError):
            # Starlette raises WebSocketDisconnect/RuntimeError (or the server an
            # OSError) when the peer is already gone.
            self.disconnect(client_id, websocket)
            return False
        return True


manager = ConnectionManager()
|
|
|
|
# API Endpoints
@app.post("/sessions/")
async def create_session(name: str):
    """Create an empty game session called *name*, store it, and return it."""
    new_session = GameSession(name=name)
    sessions[new_session.id] = new_session
    return new_session
|
|
|
|
@app.get("/sessions/{session_id}")
async def get_session(session_id: str):
    """Return the stored session object; respond 404 if *session_id* is unknown."""
    found = sessions.get(session_id)
    if found is None:
        raise HTTPException(status_code=404, detail="Session not found")
    return found
|
|
|
|
@app.post("/sessions/{session_id}/characters/")
async def add_character(
    session_id: str,
    name: str,
    description: str,
    personality: str = "",
    llm_model: str = "gpt-3.5-turbo"
):
    """Create a character in *session_id* and announce it to the storyteller.

    Every field other than *session_id* arrives as a query parameter. If the
    session's storyteller is connected, it is sent a ``character_joined`` event
    with the new character's id, name, description and model. Returns the
    created character; responds 404 if the session does not exist.
    """
    target = sessions.get(session_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Session not found")

    new_character = Character(
        name=name,
        description=description,
        personality=personality,
        llm_model=llm_model,
    )
    target.characters[new_character.id] = new_character

    joined_event = {
        "type": "character_joined",
        "character": {
            "id": new_character.id,
            "name": new_character.name,
            "description": new_character.description,
            "llm_model": new_character.llm_model,
        },
    }
    # send_to_client does nothing when no storyteller socket is registered.
    await manager.send_to_client(f"{session_id}_storyteller", joined_event)

    return new_character
|
|
|
|
# WebSocket endpoint for character interactions (character view)
@app.websocket("/ws/character/{session_id}/{character_id}")
async def character_websocket(websocket: WebSocket, session_id: str, character_id: str):
    """Private channel between one character and the storyteller.

    Protocol (JSON text frames):
      * server -> client, once on connect:
        ``{"type": "history", "messages": [...]}``
      * client -> server: ``{"type": "message", "content": "<text>"}``. The message
        is recorded in the character's history, the character is marked as awaiting
        a reply, and the message is forwarded to the storyteller as
        ``character_message`` if the storyteller is connected.
      * server -> client on a malformed frame:
        ``{"type": "error", "detail": "..."}``

    Frames that are not JSON objects, or have any other ``type``, are ignored.
    If the session or character does not exist, the socket is closed with code 1008.
    """
    session = sessions.get(session_id)
    if session is None or character_id not in session.characters:
        await websocket.close(code=1008, reason="Session or character not found")
        return

    character = session.characters[character_id]
    client_key = f"{session_id}_{character_id}"
    storyteller_key = f"{session_id}_storyteller"
    await manager.connect(websocket, client_key)

    try:
        # Replay the conversation so a (re)connecting client can render it.
        await websocket.send_json({
            "type": "history",
            "messages": [msg.dict() for msg in character.conversation_history],
        })

        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # Invalid JSON: report it rather than tearing down the connection.
                await websocket.send_json({"type": "error", "detail": "Invalid JSON"})
                continue

            if not isinstance(data, dict) or data.get("type") != "message":
                continue

            content = data.get("content")
            if not isinstance(content, str):
                await websocket.send_json(
                    {"type": "error", "detail": "'content' must be a string"}
                )
                continue

            message = Message(sender="character", content=content)
            character.conversation_history.append(message)
            character.pending_response = True

            # Forwarding is best-effort. The message is already stored, and the
            # storyteller receives it in session_state when it (re)connects, so a
            # dead storyteller socket must not end this character's connection.
            if storyteller_key in manager.active_connections:
                try:
                    await manager.send_to_client(storyteller_key, {
                        "type": "character_message",
                        "character_id": character_id,
                        "character_name": character.name,
                        "message": message.dict(),
                    })
                except (WebSocketDisconnect, RuntimeError, OSError):
                    pass
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path, not only on a clean disconnect.
        manager.disconnect(client_key)
|
|
|
|
# WebSocket endpoint for storyteller
@app.websocket("/ws/storyteller/{session_id}")
async def storyteller_websocket(websocket: WebSocket, session_id: str):
    """Storyteller's control channel for one session.

    Protocol (JSON text frames):
      * server -> client, once on connect: ``session_state`` with every character
        (full conversation history and pending flag) and the current scene.
      * client -> server
        ``{"type": "respond_to_character", "character_id": ..., "content": ...}``:
        appends a storyteller message to that character's history, clears its
        pending flag, and pushes it to the character as ``storyteller_response``.
      * client -> server ``{"type": "narrate_scene", "content": ...}``: sets the
        current scene, appends it to the scene history, and broadcasts
        ``scene_narration`` to every connected character.
      * server -> client on a malformed frame or unknown character id:
        ``{"type": "error", "detail": "..."}``

    Frames that are not JSON objects, or have any other ``type``, are ignored.
    If the session does not exist, the socket is closed with code 1008.
    """
    session = sessions.get(session_id)
    if session is None:
        await websocket.close(code=1008, reason="Session not found")
        return

    client_key = f"{session_id}_storyteller"
    await manager.connect(websocket, client_key)

    async def _push_to_character(char_id: str, payload: dict) -> None:
        # Best-effort delivery to one character. A dead character socket must not
        # end the storyteller's connection; that character's own handler is
        # responsible for unregistering it.
        char_key = f"{session_id}_{char_id}"
        if char_key not in manager.active_connections:
            return
        try:
            await manager.send_to_client(char_key, payload)
        except (WebSocketDisconnect, RuntimeError, OSError):
            pass

    async def _send_error(detail: str) -> None:
        await websocket.send_json({"type": "error", "detail": detail})

    try:
        await websocket.send_json({
            "type": "session_state",
            "characters": {
                char_id: {
                    "id": char.id,
                    "name": char.name,
                    "description": char.description,
                    "personality": char.personality,
                    "conversation_history": [msg.dict() for msg in char.conversation_history],
                    "pending_response": char.pending_response,
                }
                for char_id, char in session.characters.items()
            },
            "current_scene": session.current_scene,
        })

        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                await _send_error("Invalid JSON")
                continue

            if not isinstance(data, dict):
                continue
            msg_type = data.get("type")
            if msg_type not in ("respond_to_character", "narrate_scene"):
                continue

            content = data.get("content")
            if not isinstance(content, str):
                await _send_error("'content' must be a string")
                continue

            if msg_type == "respond_to_character":
                # Storyteller responds to a specific character.
                character_id = data.get("character_id")
                character = (
                    session.characters.get(character_id)
                    if isinstance(character_id, str)
                    else None
                )
                if character is None:
                    await _send_error("Unknown character_id")
                    continue

                message = Message(sender="storyteller", content=content)
                character.conversation_history.append(message)
                character.pending_response = False
                await _push_to_character(character_id, {
                    "type": "storyteller_response",
                    "message": message.dict(),
                })
            else:
                # narrate_scene: record the scene and broadcast it to all characters.
                session.current_scene = content
                session.scene_history.append(content)
                # Snapshot the ids: each send awaits, and add_character may insert
                # into session.characters meanwhile, which would break iteration.
                for char_id in list(session.characters):
                    await _push_to_character(char_id, {
                        "type": "scene_narration",
                        "content": content,
                    })
    except WebSocketDisconnect:
        pass
    finally:
        # Unregister on every exit path, not only on a clean disconnect.
        manager.disconnect(client_key)
|
|
|
|
# AI-assisted response generation using character's specific LLM
async def call_llm(model: str, messages: List[dict], temperature: float = 0.8, max_tokens: int = 200) -> str:
    """Call LLM via OpenRouter or OpenAI depending on model.

    Model ids starting with ``gpt-`` or ``o1-`` are sent to OpenAI through the
    module-level ``client``. Any other id is posted to OpenRouter's
    chat-completions endpoint.

    Args:
        model: Provider model id, e.g. ``"gpt-4o"`` or ``"anthropic/claude-3.5-sonnet"``.
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        temperature: Sampling temperature forwarded to the provider.
        max_tokens: Completion length cap forwarded to the provider.

    Returns:
        The reply text, or ``""`` if the provider returned no text content.
        This function does not raise for provider problems. A missing key, or any
        error while making or parsing the request, is returned as a human-readable
        string (``"... API key not set."``, ``"OpenAI error: ..."``,
        ``"OpenRouter error: ..."``) in place of the completion.
    """
    # OpenAI models
    if model.startswith(("gpt-", "o1-")):
        # Also guard against `client` not having been constructed.
        if not os.getenv("OPENAI_API_KEY") or client is None:
            return "OpenAI API key not set."
        try:
            response = await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # message.content is Optional in the SDK (e.g. refusals); keep the str contract.
            return response.choices[0].message.content or ""
        except Exception as e:
            return f"OpenAI error: {str(e)}"

    # OpenRouter models (Claude, Llama, Gemini, etc.)
    if not openrouter_api_key:
        return "OpenRouter API key not set."
    try:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.post(
                "https://openrouter.ai/api/v1/chat/completions",
                headers={
                    "Authorization": f"Bearer {openrouter_api_key}",
                    "HTTP-Referer": "http://localhost:3000",
                    "X-Title": "Storyteller RPG",
                },
                json={
                    "model": model,
                    "messages": messages,
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                },
                timeout=30.0,
            )
            response.raise_for_status()
            data = response.json()
            choices = data.get("choices")
            if not choices:
                # Some failures come back as a JSON body with an "error" key instead
                # of choices; surface that instead of an opaque KeyError message.
                return f"OpenRouter error: {data.get('error', 'response contained no choices')}"
            return choices[0]["message"]["content"] or ""
    except Exception as e:
        return f"OpenRouter error: {str(e)}"
|
|
|
|
@app.post("/sessions/{session_id}/generate_suggestion")
async def generate_suggestion(session_id: str, character_id: str, context: str = ""):
    """Generate AI suggestion for storyteller response to a character using the character's LLM.

    The prompt has three parts:
      * a system message describing the character and the current scene;
      * the last 6 messages of the character's conversation;
      * if *context* is non-empty, a final user message carrying it.

    The request is sent to ``character.llm_model`` via call_llm.

    NOTE(review): the system prompt tells the model to speak *as the character*
    ("Respond in character"), and the character's own messages are mapped to the
    ``assistant`` role. The output therefore reads as a character line, not a
    storyteller reply as the summary above says — confirm which is intended.

    Returns:
        ``{"suggestion": <text>, "model_used": <model id>}``. call_llm reports
        provider failures as text, so a failed call usually still arrives here
        as a normal response with the error text as the suggestion.

    Raises:
        HTTPException: 404 if the session or character is unknown; 500 if
            call_llm itself raises.
    """
    session = sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    character = session.characters.get(character_id)
    if character is None:
        raise HTTPException(status_code=404, detail="Character not found")

    # Prepare context for AI suggestion
    messages = [
        {
            "role": "system",
            "content": f"You are {character.name} in an RPG. Respond in character. Character description: {character.description}. Personality: {character.personality}. Current scene: {session.current_scene}"
        }
    ]

    # Add recent conversation history (character lines as the model's own turns).
    for msg in character.conversation_history[-6:]:
        role = "assistant" if msg.sender == "character" else "user"
        messages.append({"role": role, "content": msg.content})

    if context:
        messages.append({"role": "user", "content": f"Additional context: {context}"})

    try:
        suggestion = await call_llm(character.llm_model, messages, temperature=0.8, max_tokens=200)
    except Exception as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=500, detail=f"Error generating suggestion: {str(e)}") from e
    return {"suggestion": suggestion, "model_used": character.llm_model}
|
|
|
|
# Get available LLM models
@app.get("/models")
async def get_available_models():
    """List selectable model ids, grouped by provider.

    A provider's list is filled in only when its key is configured
    (OPENAI_API_KEY for "openai"; ``openrouter_api_key`` for "openrouter");
    otherwise that provider maps to an empty list.
    """
    openai_catalog = [
        {"id": "gpt-4o", "name": "GPT-4o (Latest)", "provider": "OpenAI"},
        {"id": "gpt-4-turbo", "name": "GPT-4 Turbo", "provider": "OpenAI"},
        {"id": "gpt-4", "name": "GPT-4", "provider": "OpenAI"},
        {"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo (Fast & Cheap)", "provider": "OpenAI"},
    ]
    openrouter_catalog = [
        {"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
        {"id": "anthropic/claude-3-opus", "name": "Claude 3 Opus", "provider": "Anthropic"},
        {"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku (Fast)", "provider": "Anthropic"},
        {"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
        {"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
        {"id": "meta-llama/llama-3.1-8b-instruct", "name": "Llama 3.1 8B (Fast)", "provider": "Meta"},
        {"id": "mistralai/mistral-large", "name": "Mistral Large", "provider": "Mistral"},
        {"id": "cohere/command-r-plus", "name": "Command R+", "provider": "Cohere"},
    ]

    return {
        "openai": openai_catalog if os.getenv("OPENAI_API_KEY") else [],
        "openrouter": openrouter_catalog if openrouter_api_key else [],
    }
|
|
|
|
# Get all pending character messages
@app.get("/sessions/{session_id}/pending_messages")
async def get_pending_messages(session_id: str):
    """Map character id -> latest unanswered message, for the given session.

    A character is included only when its ``pending_response`` flag is set and
    the last entry in its history was written by the character. Responds 404 if
    the session does not exist.
    """
    session = sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    awaiting = {}
    for char_id, char in session.characters.items():
        if not char.pending_response or not char.conversation_history:
            continue
        latest = char.conversation_history[-1]
        if latest.sender != "character":
            continue
        awaiting[char_id] = {
            "character_name": char.name,
            "message": latest.dict(),
        }

    return awaiting
|
|
|
|
# Get character conversation history (for storyteller)
@app.get("/sessions/{session_id}/characters/{character_id}/conversation")
async def get_character_conversation(session_id: str, character_id: str):
    """Return a character's profile, full private conversation, and pending flag.

    Responds 404 if either the session or the character does not exist.
    """
    session = sessions.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    character = session.characters.get(character_id)
    if character is None:
        raise HTTPException(status_code=404, detail="Character not found")

    profile = {
        "id": character.id,
        "name": character.name,
        "description": character.description,
        "personality": character.personality,
    }
    history = [entry.dict() for entry in character.conversation_history]

    return {
        "character": profile,
        "conversation": history,
        "pending_response": character.pending_response,
    }
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: serve the app on every interface, port 8000.
    # uvicorn is imported lazily so importing this module does not require it.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|