reafactor: 优化代码设置
This commit is contained in:
parent
1bdccb437b
commit
232a86e11f
|
|
@ -2,8 +2,9 @@
|
|||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from luxx.database import get_db, SessionLocal
|
||||
from luxx.database import get_db
|
||||
from luxx.models import User, LLMProvider
|
||||
from luxx.routes.auth import get_current_user
|
||||
from luxx.utils.helpers import success_response
|
||||
|
|
@ -37,32 +38,28 @@ class ProviderUpdate(BaseModel):
|
|||
|
||||
@router.get("/", response_model=dict)
|
||||
def list_providers(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get user's LLM providers"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
providers = db.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == current_user.id
|
||||
).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all()
|
||||
|
||||
return success_response(data={
|
||||
"providers": [p.to_dict() for p in providers],
|
||||
"total": len(providers)
|
||||
})
|
||||
finally:
|
||||
db.close()
|
||||
providers = db.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == current_user.id
|
||||
).order_by(LLMProvider.is_default.desc(), LLMProvider.created_at.desc()).all()
|
||||
|
||||
return success_response(data={
|
||||
"providers": [p.to_dict() for p in providers],
|
||||
"total": len(providers)
|
||||
})
|
||||
|
||||
|
||||
@router.post("/", response_model=dict)
|
||||
def create_provider(
|
||||
provider: ProviderCreate,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Create a new LLM provider"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
db_provider = LLMProvider(
|
||||
user_id=current_user.id,
|
||||
name=provider.name,
|
||||
|
|
@ -80,83 +77,96 @@ def create_provider(
|
|||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.get("/{provider_id}", response_model=dict)
|
||||
def get_provider(
|
||||
provider_id: int,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Get provider details"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
return success_response(data=provider.to_dict(include_key=True))
|
||||
finally:
|
||||
db.close()
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
return success_response(data=provider.to_dict(include_key=True))
|
||||
|
||||
|
||||
@router.put("/{provider_id}", response_model=dict)
|
||||
def update_provider(
|
||||
provider_id: int,
|
||||
update: ProviderUpdate,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update provider"""
|
||||
db = SessionLocal()
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# If setting as default, unset others
|
||||
if update.is_default:
|
||||
db.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == current_user.id,
|
||||
LLMProvider.id != provider_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# Update fields
|
||||
update_data = update.dict(exclude_unset=True)
|
||||
# Keep existing API key if the new one is empty
|
||||
if update_data.get('api_key') == '':
|
||||
update_data.pop('api_key')
|
||||
for key, value in update_data.items():
|
||||
setattr(provider, key, value)
|
||||
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
# If setting as default, unset others
|
||||
if update.is_default:
|
||||
db.query(LLMProvider).filter(
|
||||
LLMProvider.user_id == current_user.id,
|
||||
LLMProvider.id != provider_id
|
||||
).update({"is_default": False})
|
||||
|
||||
# Update fields
|
||||
update_data = update.dict(exclude_unset=True)
|
||||
# Keep existing API key if the new one is empty
|
||||
if update_data.get('api_key') == '':
|
||||
update_data.pop('api_key')
|
||||
for key, value in update_data.items():
|
||||
setattr(provider, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(provider)
|
||||
|
||||
return success_response(data=provider.to_dict(include_key=True))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", response_model=dict)
|
||||
def delete_provider(
|
||||
provider_id: int,
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Delete provider"""
|
||||
db = SessionLocal()
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
|
||||
db.delete(provider)
|
||||
db.commit()
|
||||
|
||||
return success_response(message="Provider deleted")
|
||||
|
||||
|
||||
@router.post("/{provider_id}/test")
|
||||
def test_provider(
|
||||
provider_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Test provider connection"""
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
|
|
@ -164,63 +174,35 @@ def delete_provider(
|
|||
).first()
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail="Provider not found")
|
||||
return {"success": False, "message": "Provider not found"}
|
||||
|
||||
db.delete(provider)
|
||||
db.commit()
|
||||
|
||||
return success_response(message="Provider deleted")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
@router.post("/{provider_id}/test")
|
||||
def test_provider(
|
||||
provider_id: int,
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""Test provider connection"""
|
||||
|
||||
try:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
provider = db.query(LLMProvider).filter(
|
||||
LLMProvider.id == provider_id,
|
||||
LLMProvider.user_id == current_user.id
|
||||
).first()
|
||||
|
||||
if not provider:
|
||||
return {"success": False, "message": "Provider not found"}
|
||||
|
||||
# Test the connection
|
||||
async def test():
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
provider.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": provider.default_model,
|
||||
"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}],
|
||||
"stream": False
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"success": True,
|
||||
"response_body": response.text[:500] if response.text else None
|
||||
# Test the connection
|
||||
async def test():
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
provider.base_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {provider.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": provider.default_model,
|
||||
"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hi"}],
|
||||
"stream": False
|
||||
}
|
||||
|
||||
result = asyncio.run(test())
|
||||
return {
|
||||
"success": result.get("success", False),
|
||||
"message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"),
|
||||
"data": result
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
)
|
||||
response.raise_for_status()
|
||||
return {
|
||||
"status_code": response.status_code,
|
||||
"success": True,
|
||||
"response_body": response.text[:500] if response.text else None
|
||||
}
|
||||
|
||||
result = asyncio.run(test())
|
||||
return {
|
||||
"success": result.get("success", False),
|
||||
"message": result.get("message") or (f"HTTP {result.get('status_code', '?')}: {result.get('response_body') or 'Unknown error'}"),
|
||||
"data": result
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"Exception {str(e)}", "error_type": type(e).__name__, "traceback": traceback.format_exc()[:500]}
|
||||
|
|
|
|||
Loading…
Reference in New Issue