# Luxx/luxx/api/providers.py
# Source listing metadata: 221 lines, 6.7 KiB, Python

"""LLM Provider routes"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from luxx.core.database import SessionLocal
from luxx.models.user import User, LLMProvider
from luxx.api.auth import get_current_user
from luxx.utils.helpers import success_response
import httpx
import asyncio
import traceback
# Router shared by every provider endpoint in this module: all routes are
# mounted under the "/providers" prefix and tagged "LLM Providers" in the
# generated API docs.
router = APIRouter(
    prefix="/providers",
    tags=["LLM Providers"],
)
class ProviderCreate(BaseModel):
    """Request body for registering a new LLM provider.

    ``name``, ``base_url`` and ``api_key`` are required. The remaining fields
    default to provider type ``"openai"``, model ``"gpt-4"`` and
    ``is_default`` False.
    """

    # Required identity / connection details, with an optional type override.
    name: str
    provider_type: str = "openai"
    base_url: str
    api_key: str
    # Optional behaviour settings.
    default_model: str = "gpt-4"
    is_default: bool = False
class ProviderUpdate(BaseModel):
    """Partial-update body for an existing LLM provider.

    Every field is optional and defaults to ``None``. The update route reads
    this model with ``exclude_unset=True``, so only fields the client actually
    sent are applied.
    """

    # Connection details.
    name: Optional[str] = None
    provider_type: Optional[str] = None
    base_url: Optional[str] = None
    # The update route treats an empty string as "keep the stored key".
    api_key: Optional[str] = None
    # Model / behaviour settings.
    default_model: Optional[str] = None
    max_tokens: Optional[int] = None
    is_default: Optional[bool] = None
    enabled: Optional[bool] = None
@router.get("/", response_model=dict)
def list_providers(
    current_user: User = Depends(get_current_user)
):
    """List every LLM provider owned by the current user.

    Rows are ordered by ``is_default`` descending, then ``created_at``
    descending. API keys are not included (``to_dict()`` is called without
    ``include_key``).

    Returns:
        The ``success_response`` envelope whose data holds ``providers`` (the
        serialized rows) and ``total`` (how many rows were found).
    """
    session = SessionLocal()
    try:
        ordering = (LLMProvider.is_default.desc(), LLMProvider.created_at.desc())
        owned = (
            session.query(LLMProvider)
            .filter(LLMProvider.user_id == current_user.id)
            .order_by(*ordering)
            .all()
        )
        serialized = [item.to_dict() for item in owned]
        return success_response(data={"providers": serialized, "total": len(owned)})
    finally:
        # Always hand the connection back, even if the query or serialization fails.
        session.close()
@router.post("/", response_model=dict)
def create_provider(
    provider: ProviderCreate,
    current_user: User = Depends(get_current_user)
):
    """Create a new LLM provider for the current user.

    When the new provider is marked ``is_default``, the flag is first cleared
    on every other provider the user owns, in the same transaction. This
    matches the update endpoint, so a user never ends up with more than one
    default provider.

    Returns:
        The ``success_response`` envelope wrapping the created provider,
        serialized with ``include_key=True``.

    Raises:
        HTTPException: 400 if anything fails while creating the row. The
            transaction is rolled back and the error text becomes the detail.
    """
    db = SessionLocal()
    try:
        if provider.is_default:
            # The new row is not flushed yet, so this only touches the user's
            # existing providers.
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id
            ).update({"is_default": False})
        db_provider = LLMProvider(
            user_id=current_user.id,
            name=provider.name,
            provider_type=provider.provider_type,
            base_url=provider.base_url,
            api_key=provider.api_key,
            default_model=provider.default_model,
            is_default=provider.is_default
        )
        db.add(db_provider)
        db.commit()
        db.refresh(db_provider)
        return success_response(data=db_provider.to_dict(include_key=True))
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e
    finally:
        db.close()
@router.get("/{provider_id}", response_model=dict)
def get_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Return one provider owned by the current user, including its API key.

    Raises:
        HTTPException: 404 when no provider with ``provider_id`` belongs to the
            current user.
    """
    session = SessionLocal()
    try:
        found = (
            session.query(LLMProvider)
            .filter(
                LLMProvider.id == provider_id,
                LLMProvider.user_id == current_user.id,
            )
            .first()
        )
        if not found:
            raise HTTPException(status_code=404, detail="Provider not found")
        return success_response(data=found.to_dict(include_key=True))
    finally:
        session.close()
@router.put("/{provider_id}", response_model=dict)
def update_provider(
    provider_id: int,
    update: ProviderUpdate,
    current_user: User = Depends(get_current_user)
):
    """Partially update one of the current user's LLM providers.

    Only fields present in the request body are applied (``exclude_unset``).
    A blank ``api_key``, either an empty string or an explicit ``null``, keeps
    the stored key instead of overwriting it. Setting ``is_default`` to true
    clears the flag on every other provider the user owns.

    Returns:
        The ``success_response`` envelope wrapping the updated provider,
        serialized with ``include_key=True``.

    Raises:
        HTTPException: 404 if the provider does not exist for this user; 400 if
            applying the changes fails (the transaction is rolled back and the
            error text becomes the detail).
    """
    db = SessionLocal()
    try:
        provider = db.query(LLMProvider).filter(
            LLMProvider.id == provider_id,
            LLMProvider.user_id == current_user.id
        ).first()
        if not provider:
            raise HTTPException(status_code=404, detail="Provider not found")
        if update.is_default:
            # Clear the flag on the user's other providers so this one is the only default.
            db.query(LLMProvider).filter(
                LLMProvider.user_id == current_user.id,
                LLMProvider.id != provider_id
            ).update({"is_default": False})
        update_data = update.dict(exclude_unset=True)
        # A blank key means "keep the current key". An explicit null must not
        # wipe the stored credential either.
        if update_data.get('api_key') in ('', None):
            update_data.pop('api_key', None)
        for key, value in update_data.items():
            setattr(provider, key, value)
        db.commit()
        db.refresh(provider)
        return success_response(data=provider.to_dict(include_key=True))
    except HTTPException:
        # Let the 404 above propagate unchanged.
        raise
    except Exception as e:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(e)) from e
    finally:
        db.close()
@router.delete("/{provider_id}", response_model=dict)
def delete_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Delete one of the current user's LLM providers.

    Raises:
        HTTPException: 404 when no provider with ``provider_id`` belongs to the
            current user.
    """
    session = SessionLocal()
    try:
        target = (
            session.query(LLMProvider)
            .filter(
                LLMProvider.id == provider_id,
                LLMProvider.user_id == current_user.id,
            )
            .first()
        )
        if not target:
            raise HTTPException(status_code=404, detail="Provider not found")
        session.delete(target)
        session.commit()
        return success_response(message="Provider deleted")
    finally:
        session.close()
@router.post("/{provider_id}/test")
def test_provider(
    provider_id: int,
    current_user: User = Depends(get_current_user)
):
    """Test provider connection by sending a minimal chat request.

    Looks up the provider owned by the current user. It then POSTs a tiny
    system+user chat payload to ``base_url`` with the stored API key (as a
    Bearer token) and the provider's default model, using a 10 s timeout.

    Returns:
        dict: ``{"success": False, "message": "Provider not found"}`` when the
        provider does not exist for this user.

        On a 2xx reply, ``{"success": True, "message": "HTTP <code>: <body>",
        "data": {...}}``. The body is truncated to 500 characters, or
        "Unknown error" when it is empty.

        On any exception, including non-2xx replies raised by
        ``raise_for_status``, a dict with ``success`` False, ``message``,
        ``error_type`` and a truncated ``traceback``.
    """
    try:
        db = SessionLocal()
        try:
            provider = db.query(LLMProvider).filter(
                LLMProvider.id == provider_id,
                LLMProvider.user_id == current_user.id
            ).first()
            if not provider:
                return {"success": False, "message": "Provider not found"}
            # Snapshot what the request needs so the session (and its pooled
            # connection) is released before the network round-trip.
            base_url = provider.base_url
            api_key = provider.api_key
            model = provider.default_model
        finally:
            db.close()

        # This route is a plain `def`, so a blocking client is used directly.
        # No per-request event loop is needed.
        with httpx.Client(timeout=10.0) as client:
            response = client.post(
                base_url,
                headers={
                    "Authorization": f"Bearer {api_key}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": model,
                    "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}],
                    "stream": False
                }
            )
            # Non-2xx replies raise here and are reported by the handler below.
            response.raise_for_status()
            result = {
                "status_code": response.status_code,
                "success": True,
                "response_body": response.text[:500] if response.text else None
            }
        return {
            "success": result["success"],
            "message": f"HTTP {result['status_code']}: {result['response_body'] or 'Unknown error'}",
            "data": result
        }
    except Exception as e:
        # NOTE(review): the traceback is sent back to the client. Consider
        # logging it server-side instead of exposing internals.
        return {"success": False, "message": f"Exception {str(e)}", "error_type": type(e).__name__, "traceback": traceback.format_exc()[:500]}