reafactor: 优化代码设置

This commit is contained in:
ViperEkura 2026-04-24 13:11:09 +08:00
parent 1bdccb437b
commit 232a86e11f
1 changed files with 101 additions and 119 deletions

View File

@ -2,8 +2,9 @@
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from luxx.database import get_db, SessionLocal from luxx.database import get_db
from luxx.models import User, LLMProvider from luxx.models import User, LLMProvider
from luxx.routes.auth import get_current_user from luxx.routes.auth import get_current_user
from luxx.utils.helpers import success_response from luxx.utils.helpers import success_response
@ -37,32 +38,28 @@ class ProviderUpdate(BaseModel):
@router.get("/", response_model=dict) @router.get("/", response_model=dict)
def list_providers( def list_providers(
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
"""Get user's LLM providers""" """Get user's LLM providers"""
db = SessionLocal() providers = db.query(LLMProvider).filter(
try: LLMProvider.user_id == current_user.id
providers = db.query(LLMProvider).filter( ).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all()
LLMProvider.user_id == current_user.id
).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all() return success_response(data={
"providers": [p.to_dict() for p in providers],
return success_response(data={ "total": len(providers)
"providers": [p.to_dict() for p in providers], })
"total": len(providers)
})
finally:
db.close()
@router.post("/", response_model=dict) @router.post("/", response_model=dict)
def create_provider( def create_provider(
provider: ProviderCreate, provider: ProviderCreate,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
"""Create a new LLM provider""" """Create a new LLM provider"""
db = SessionLocal()
try: try:
db_provider = LLMProvider( db_provider = LLMProvider(
user_id=current_user.id, user_id=current_user.id,
name=provider.name, name=provider.name,
@ -80,83 +77,96 @@ def create_provider(
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
finally:
db.close()
@router.get("/{provider_id}", response_model=dict) @router.get("/{provider_id}", response_model=dict)
def get_provider( def get_provider(
provider_id: int, provider_id: int,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
"""Get provider details""" """Get provider details"""
db = SessionLocal() provider = db.query(LLMProvider).filter(
try: LLMProvider.id == provider_id,
provider = db.query(LLMProvider).filter( LLMProvider.user_id == current_user.id
LLMProvider.id == provider_id, ).first()
LLMProvider.user_id == current_user.id
).first() if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
if not provider:
raise HTTPException(status_code=404, detail="Provider not found") return success_response(data=provider.to_dict(include_key=True))
return success_response(data=provider.to_dict(include_key=True))
finally:
db.close()
@router.put("/{provider_id}", response_model=dict) @router.put("/{provider_id}", response_model=dict)
def update_provider( def update_provider(
provider_id: int, provider_id: int,
update: ProviderUpdate, update: ProviderUpdate,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
"""Update provider""" """Update provider"""
db = SessionLocal() provider = db.query(LLMProvider).filter(
LLMProvider.id == provider_id,
LLMProvider.user_id == current_user.id
).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# If setting as default, unset others
if update.is_default:
db.query(LLMProvider).filter(
LLMProvider.user_id == current_user.id,
LLMProvider.id != provider_id
).update({"is_default": False})
# Update fields
update_data = update.dict(exclude_unset=True)
# Keep existing API key if the new one is empty
if update_data.get('api_key') == '':
update_data.pop('api_key')
for key, value in update_data.items():
setattr(provider, key, value)
try: try:
provider = db.query(LLMProvider).filter(
LLMProvider.id == provider_id,
LLMProvider.user_id == current_user.id
).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
# If setting as default, unset others
if update.is_default:
db.query(LLMProvider).filter(
LLMProvider.user_id == current_user.id,
LLMProvider.id != provider_id
).update({"is_default": False})
# Update fields
update_data = update.dict(exclude_unset=True)
# Keep existing API key if the new one is empty
if update_data.get('api_key') == '':
update_data.pop('api_key')
for key, value in update_data.items():
setattr(provider, key, value)
db.commit() db.commit()
db.refresh(provider) db.refresh(provider)
return success_response(data=provider.to_dict(include_key=True)) return success_response(data=provider.to_dict(include_key=True))
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
finally:
db.close()
@router.delete("/{provider_id}", response_model=dict) @router.delete("/{provider_id}", response_model=dict)
def delete_provider( def delete_provider(
provider_id: int, provider_id: int,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
): ):
"""Delete provider""" """Delete provider"""
db = SessionLocal() provider = db.query(LLMProvider).filter(
LLMProvider.id == provider_id,
LLMProvider.user_id == current_user.id
).first()
if not provider:
raise HTTPException(status_code=404, detail="Provider not found")
db.delete(provider)
db.commit()
return success_response(message="Provider deleted")
@router.post("/{provider_id}/test")
def test_provider(
provider_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Test provider connection"""
try: try:
provider = db.query(LLMProvider).filter( provider = db.query(LLMProvider).filter(
LLMProvider.id == provider_id, LLMProvider.id == provider_id,
@ -164,63 +174,35 @@ def delete_provider(
).first() ).first()
if not provider: if not provider:
raise HTTPException(status_code=404, detail="Provider not found") return {"success": False, "message": "Provider not found"}
db.delete(provider) # Test the connection
db.commit() async def test():
async with httpx.AsyncClient(timeout=10.0) as client:
return success_response(message="Provider deleted") response = await client.post(
finally: provider.base_url,
db.close() headers={
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json"
@router.post("/{provider_id}/test") },
def test_provider( json={
provider_id: int, "model": provider.default_model,
current_user: User = Depends(get_current_user) "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}],
): "stream": False
"""Test provider connection"""
try:
db = SessionLocal()
try:
provider = db.query(LLMProvider).filter(
LLMProvider.id == provider_id,
LLMProvider.user_id == current_user.id
).first()
if not provider:
return {"success": False, "message": "Provider not found"}
# Test the connection
async def test():
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.post(
provider.base_url,
headers={
"Authorization": f"Bearer {provider.api_key}",
"Content-Type": "application/json"
},
json={
"model": provider.default_model,
"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}],
"stream": False
}
)
response.raise_for_status()
return {
"status_code": response.status_code,
"success": True,
"response_body": response.text[:500] if response.text else None
} }
)
result = asyncio.run(test()) response.raise_for_status()
return { return {
"success": result.get("success", False), "status_code": response.status_code,
"message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"), "success": True,
"data": result "response_body": response.text[:500] if response.text else None
} }
finally:
db.close() result = asyncio.run(test())
return {
"success": result.get("success", False),
"message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"),
"data": result
}
except Exception as e: except Exception as e:
return {"success": False, "message": f"Exception {str(e)}", "error_type": type(e).__name__, "traceback": traceback.format_exc()[:500]} return {"success": False, "message": f"Exception {str(e)}", "error_type": type(e).__name__, "traceback": traceback.format_exc()[:500]}