Initial commit
This commit is contained in:
397
main.py
Normal file
397
main.py
Normal file
@@ -0,0 +1,397 @@
|
||||
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
import httpx
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Initialize FastAPI
|
||||
app = FastAPI(title="Storyteller RPG API")
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Initialize OpenAI
|
||||
client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
openrouter_api_key = os.getenv("OPENROUTER_API_KEY")
|
||||
|
||||
if not os.getenv("OPENAI_API_KEY") and not openrouter_api_key:
|
||||
print("Warning: Neither OPENAI_API_KEY nor OPENROUTER_API_KEY set. AI features will not work.")
|
||||
|
||||
# Models
|
||||
class Message(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
sender: str # "character" or "storyteller"
|
||||
content: str
|
||||
timestamp: str = Field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
class Character(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str
|
||||
description: str
|
||||
personality: str = "" # Additional personality traits
|
||||
llm_model: str = "gpt-3.5-turbo" # LLM model for this character
|
||||
conversation_history: List[Message] = [] # Private conversation with storyteller
|
||||
pending_response: bool = False # Waiting for storyteller response
|
||||
|
||||
class StorytellerResponse(BaseModel):
|
||||
character_id: str
|
||||
content: str
|
||||
|
||||
class GameSession(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
name: str
|
||||
characters: Dict[str, Character] = {}
|
||||
current_scene: str = ""
|
||||
scene_history: List[str] = [] # All scenes narrated
|
||||
|
||||
# In-memory storage (replace with database in production)
|
||||
sessions: Dict[str, GameSession] = {}
|
||||
|
||||
# WebSocket connection manager
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {} # key: "session_character" or "session_storyteller"
|
||||
|
||||
async def connect(self, websocket: WebSocket, client_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[client_id] = websocket
|
||||
|
||||
def disconnect(self, client_id: str):
|
||||
if client_id in self.active_connections:
|
||||
del self.active_connections[client_id]
|
||||
|
||||
async def send_to_client(self, client_id: str, message: dict):
|
||||
if client_id in self.active_connections:
|
||||
await self.active_connections[client_id].send_json(message)
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
# API Endpoints
|
||||
@app.post("/sessions/")
|
||||
async def create_session(name: str):
|
||||
session = GameSession(name=name)
|
||||
sessions[session.id] = session
|
||||
return session
|
||||
|
||||
@app.get("/sessions/{session_id}")
|
||||
async def get_session(session_id: str):
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return sessions[session_id]
|
||||
|
||||
@app.post("/sessions/{session_id}/characters/")
|
||||
async def add_character(
|
||||
session_id: str,
|
||||
name: str,
|
||||
description: str,
|
||||
personality: str = "",
|
||||
llm_model: str = "gpt-3.5-turbo"
|
||||
):
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
character = Character(
|
||||
name=name,
|
||||
description=description,
|
||||
personality=personality,
|
||||
llm_model=llm_model
|
||||
)
|
||||
session = sessions[session_id]
|
||||
session.characters[character.id] = character
|
||||
|
||||
# Notify storyteller of new character
|
||||
storyteller_key = f"{session_id}_storyteller"
|
||||
if storyteller_key in manager.active_connections:
|
||||
await manager.send_to_client(storyteller_key, {
|
||||
"type": "character_joined",
|
||||
"character": {
|
||||
"id": character.id,
|
||||
"name": character.name,
|
||||
"description": character.description,
|
||||
"llm_model": character.llm_model
|
||||
}
|
||||
})
|
||||
|
||||
return character
|
||||
|
||||
# WebSocket endpoint for character interactions (character view)
|
||||
@app.websocket("/ws/character/{session_id}/{character_id}")
|
||||
async def character_websocket(websocket: WebSocket, session_id: str, character_id: str):
|
||||
if session_id not in sessions or character_id not in sessions[session_id].characters:
|
||||
await websocket.close(code=1008, reason="Session or character not found")
|
||||
return
|
||||
|
||||
client_key = f"{session_id}_{character_id}"
|
||||
await manager.connect(websocket, client_key)
|
||||
|
||||
try:
|
||||
# Send conversation history
|
||||
session = sessions[session_id]
|
||||
character = session.characters[character_id]
|
||||
await websocket.send_json({
|
||||
"type": "history",
|
||||
"messages": [msg.dict() for msg in character.conversation_history]
|
||||
})
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
# Character sends message to storyteller
|
||||
message = Message(sender="character", content=data["content"])
|
||||
character.conversation_history.append(message)
|
||||
character.pending_response = True
|
||||
|
||||
# Forward to storyteller
|
||||
storyteller_key = f"{session_id}_storyteller"
|
||||
if storyteller_key in manager.active_connections:
|
||||
await manager.send_to_client(storyteller_key, {
|
||||
"type": "character_message",
|
||||
"character_id": character_id,
|
||||
"character_name": character.name,
|
||||
"message": message.dict()
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(client_key)
|
||||
|
||||
# WebSocket endpoint for storyteller
|
||||
@app.websocket("/ws/storyteller/{session_id}")
|
||||
async def storyteller_websocket(websocket: WebSocket, session_id: str):
|
||||
if session_id not in sessions:
|
||||
await websocket.close(code=1008, reason="Session not found")
|
||||
return
|
||||
|
||||
client_key = f"{session_id}_storyteller"
|
||||
await manager.connect(websocket, client_key)
|
||||
|
||||
try:
|
||||
# Send all characters and their conversation states
|
||||
session = sessions[session_id]
|
||||
await websocket.send_json({
|
||||
"type": "session_state",
|
||||
"characters": {
|
||||
char_id: {
|
||||
"id": char.id,
|
||||
"name": char.name,
|
||||
"description": char.description,
|
||||
"personality": char.personality,
|
||||
"conversation_history": [msg.dict() for msg in char.conversation_history],
|
||||
"pending_response": char.pending_response
|
||||
}
|
||||
for char_id, char in session.characters.items()
|
||||
},
|
||||
"current_scene": session.current_scene
|
||||
})
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "respond_to_character":
|
||||
# Storyteller responds to a specific character
|
||||
character_id = data["character_id"]
|
||||
content = data["content"]
|
||||
|
||||
if character_id in session.characters:
|
||||
character = session.characters[character_id]
|
||||
message = Message(sender="storyteller", content=content)
|
||||
character.conversation_history.append(message)
|
||||
character.pending_response = False
|
||||
|
||||
# Send to character
|
||||
char_key = f"{session_id}_{character_id}"
|
||||
if char_key in manager.active_connections:
|
||||
await manager.send_to_client(char_key, {
|
||||
"type": "storyteller_response",
|
||||
"message": message.dict()
|
||||
})
|
||||
|
||||
elif data.get("type") == "narrate_scene":
|
||||
# Broadcast scene to all characters
|
||||
scene = data["content"]
|
||||
session.current_scene = scene
|
||||
session.scene_history.append(scene)
|
||||
|
||||
# Send to all connected characters
|
||||
for char_id in session.characters:
|
||||
char_key = f"{session_id}_{char_id}"
|
||||
if char_key in manager.active_connections:
|
||||
await manager.send_to_client(char_key, {
|
||||
"type": "scene_narration",
|
||||
"content": scene
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(client_key)
|
||||
|
||||
# AI-assisted response generation using character's specific LLM
|
||||
async def call_llm(model: str, messages: List[dict], temperature: float = 0.8, max_tokens: int = 200) -> str:
|
||||
"""Call LLM via OpenRouter or OpenAI depending on model"""
|
||||
|
||||
# OpenAI models
|
||||
if model.startswith("gpt-") or model.startswith("o1-"):
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
return "OpenAI API key not set."
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
return f"OpenAI error: {str(e)}"
|
||||
|
||||
# OpenRouter models (Claude, Llama, Gemini, etc.)
|
||||
else:
|
||||
if not openrouter_api_key:
|
||||
return "OpenRouter API key not set."
|
||||
try:
|
||||
async with httpx.AsyncClient() as http_client:
|
||||
response = await http_client.post(
|
||||
"https://openrouter.ai/api/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {openrouter_api_key}",
|
||||
"HTTP-Referer": "http://localhost:3000",
|
||||
"X-Title": "Storyteller RPG"
|
||||
},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens
|
||||
},
|
||||
timeout=30.0
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
return f"OpenRouter error: {str(e)}"
|
||||
|
||||
@app.post("/sessions/{session_id}/generate_suggestion")
|
||||
async def generate_suggestion(session_id: str, character_id: str, context: str = ""):
|
||||
"""Generate AI suggestion for storyteller response to a character using the character's LLM"""
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
session = sessions[session_id]
|
||||
if character_id not in session.characters:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
character = session.characters[character_id]
|
||||
|
||||
# Prepare context for AI suggestion
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": f"You are {character.name} in an RPG. Respond in character. Character description: {character.description}. Personality: {character.personality}. Current scene: {session.current_scene}"
|
||||
}
|
||||
]
|
||||
|
||||
# Add recent conversation history
|
||||
for msg in character.conversation_history[-6:]:
|
||||
role = "assistant" if msg.sender == "character" else "user"
|
||||
messages.append({"role": role, "content": msg.content})
|
||||
|
||||
if context:
|
||||
messages.append({"role": "user", "content": f"Additional context: {context}"})
|
||||
|
||||
try:
|
||||
suggestion = await call_llm(character.llm_model, messages, temperature=0.8, max_tokens=200)
|
||||
return {"suggestion": suggestion, "model_used": character.llm_model}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error generating suggestion: {str(e)}")
|
||||
|
||||
# Get available LLM models
|
||||
@app.get("/models")
|
||||
async def get_available_models():
|
||||
"""Get list of available LLM models"""
|
||||
models = {
|
||||
"openai": [],
|
||||
"openrouter": []
|
||||
}
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
models["openai"] = [
|
||||
{"id": "gpt-4o", "name": "GPT-4o (Latest)", "provider": "OpenAI"},
|
||||
{"id": "gpt-4-turbo", "name": "GPT-4 Turbo", "provider": "OpenAI"},
|
||||
{"id": "gpt-4", "name": "GPT-4", "provider": "OpenAI"},
|
||||
{"id": "gpt-3.5-turbo", "name": "GPT-3.5 Turbo (Fast & Cheap)", "provider": "OpenAI"},
|
||||
]
|
||||
|
||||
if openrouter_api_key:
|
||||
models["openrouter"] = [
|
||||
{"id": "anthropic/claude-3.5-sonnet", "name": "Claude 3.5 Sonnet", "provider": "Anthropic"},
|
||||
{"id": "anthropic/claude-3-opus", "name": "Claude 3 Opus", "provider": "Anthropic"},
|
||||
{"id": "anthropic/claude-3-haiku", "name": "Claude 3 Haiku (Fast)", "provider": "Anthropic"},
|
||||
{"id": "google/gemini-pro-1.5", "name": "Gemini Pro 1.5", "provider": "Google"},
|
||||
{"id": "meta-llama/llama-3.1-70b-instruct", "name": "Llama 3.1 70B", "provider": "Meta"},
|
||||
{"id": "meta-llama/llama-3.1-8b-instruct", "name": "Llama 3.1 8B (Fast)", "provider": "Meta"},
|
||||
{"id": "mistralai/mistral-large", "name": "Mistral Large", "provider": "Mistral"},
|
||||
{"id": "cohere/command-r-plus", "name": "Command R+", "provider": "Cohere"},
|
||||
]
|
||||
|
||||
return models
|
||||
|
||||
# Get all pending character messages
|
||||
@app.get("/sessions/{session_id}/pending_messages")
|
||||
async def get_pending_messages(session_id: str):
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
session = sessions[session_id]
|
||||
pending = {}
|
||||
|
||||
for char_id, char in session.characters.items():
|
||||
if char.pending_response:
|
||||
last_message = char.conversation_history[-1] if char.conversation_history else None
|
||||
if last_message and last_message.sender == "character":
|
||||
pending[char_id] = {
|
||||
"character_name": char.name,
|
||||
"message": last_message.dict()
|
||||
}
|
||||
|
||||
return pending
|
||||
|
||||
# Get character conversation history (for storyteller)
|
||||
@app.get("/sessions/{session_id}/characters/{character_id}/conversation")
|
||||
async def get_character_conversation(session_id: str, character_id: str):
|
||||
if session_id not in sessions:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
|
||||
session = sessions[session_id]
|
||||
if character_id not in session.characters:
|
||||
raise HTTPException(status_code=404, detail="Character not found")
|
||||
|
||||
character = session.characters[character_id]
|
||||
return {
|
||||
"character": {
|
||||
"id": character.id,
|
||||
"name": character.name,
|
||||
"description": character.description,
|
||||
"personality": character.personality
|
||||
},
|
||||
"conversation": [msg.dict() for msg in character.conversation_history],
|
||||
"pending_response": character.pending_response
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user