Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,12 @@ model LiteLLM_SearchToolsTable {
search_tool_info Json?
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}

// SSO configuration table
model LiteLLM_SSOConfig {
id String @id @default("sso_config")
sso_settings Json
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}
23 changes: 23 additions & 0 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,7 @@ def swagger_monkey_patch(*args, **kwargs):
scheduler = None
last_model_cost_map_reload = None


### DB WRITER ###
db_writer_client: Optional[AsyncHTTPHandler] = None
### logger ###
Expand Down Expand Up @@ -3331,6 +3332,28 @@ async def _init_non_llm_objects_in_db(self, prisma_client: PrismaClient):

if self._should_load_db_object(object_type="model_cost_map"):
await self._check_and_reload_model_cost_map(prisma_client=prisma_client)
if self._should_load_db_object(object_type="sso_settings"):
await self._init_sso_settings_in_db(prisma_client=prisma_client)

async def _init_sso_settings_in_db(self, prisma_client: PrismaClient):
"""
Initialize SSO settings from database into the router on startup.
"""

try:
sso_settings = await prisma_client.db.litellm_ssoconfig.find_unique(
where={"id": "sso_config"}
)
if sso_settings is not None:
# Capitalize all keys in sso_settings dictionary
uppercase_sso_settings = {key.upper(): value for key, value in sso_settings.sso_settings.items()}
self._decrypt_and_set_db_env_variables(environment_variables=uppercase_sso_settings)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.proxy_server.py::ProxyConfig:_init_sso_settings_in_db - {}".format(
str(e)
)
)

async def _check_and_reload_model_cost_map(self, prisma_client: PrismaClient):
"""
Expand Down
8 changes: 8 additions & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,12 @@ model LiteLLM_SearchToolsTable {
search_tool_info Json?
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}

// SSO configuration table
model LiteLLM_SSOConfig {
id String @id @default("sso_config")
sso_settings Json
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}
130 changes: 71 additions & 59 deletions litellm/proxy/ui_crud_endpoints/proxy_setting_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,40 +390,47 @@ async def update_default_team_settings(settings: DefaultTeamSSOParams):
)
async def get_sso_settings():
"""
Get all SSO configuration settings from the environment variables.
Get all SSO configuration settings from the dedicated SSO table.
Returns a structured object with values and descriptions for UI display.
"""
import os

from litellm.proxy.proxy_server import proxy_config
from litellm.proxy.proxy_server import prisma_client, proxy_config

# Load existing config to get both environment variables and general settings
config = await proxy_config.get_config()
general_settings = config.get("general_settings", {}) or {}
environment_variables = config.get("environment_variables", {}) or {}
if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": "Database not connected. Please connect a database."},
)

# Get user_email from general_settings
proxy_admin_email = general_settings.get("proxy_admin_email", None)
# Get SSO config from dedicated table
sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique(
where={"id": "sso_config"}
)

# Helper function to get env var value (first from config, then from environment)
def get_env_value(env_var_name: str):
return environment_variables.get(env_var_name) or os.getenv(env_var_name)
# Initialize with defaults
sso_settings_dict = {}

if sso_db_record and sso_db_record.sso_settings:
# Load settings from database
sso_settings_dict = dict(sso_db_record.sso_settings)

decrypted_sso_settings_dict = proxy_config._decrypt_and_set_db_env_variables(environment_variables=sso_settings_dict)

# Get current environment variables for SSO
# Build SSO config with database values or environment fallback
sso_config = SSOConfig(
google_client_id=get_env_value("GOOGLE_CLIENT_ID"),
google_client_secret=get_env_value("GOOGLE_CLIENT_SECRET"),
microsoft_client_id=get_env_value("MICROSOFT_CLIENT_ID"),
microsoft_client_secret=get_env_value("MICROSOFT_CLIENT_SECRET"),
microsoft_tenant=get_env_value("MICROSOFT_TENANT"),
generic_client_id=get_env_value("GENERIC_CLIENT_ID"),
generic_client_secret=get_env_value("GENERIC_CLIENT_SECRET"),
generic_authorization_endpoint=get_env_value("GENERIC_AUTHORIZATION_ENDPOINT"),
generic_token_endpoint=get_env_value("GENERIC_TOKEN_ENDPOINT"),
generic_userinfo_endpoint=get_env_value("GENERIC_USERINFO_ENDPOINT"),
proxy_base_url=get_env_value("PROXY_BASE_URL"),
user_email=proxy_admin_email, # Get from config instead of environment
ui_access_mode=general_settings.get("ui_access_mode", None),
google_client_id=decrypted_sso_settings_dict.get("google_client_id", None),
google_client_secret=decrypted_sso_settings_dict.get("google_client_secret", None),
microsoft_client_id=decrypted_sso_settings_dict.get("microsoft_client_id", None),
microsoft_client_secret=decrypted_sso_settings_dict.get("microsoft_client_secret", None),
microsoft_tenant=decrypted_sso_settings_dict.get("microsoft_tenant", None),
generic_client_id=decrypted_sso_settings_dict.get("generic_client_id", None),
generic_client_secret=decrypted_sso_settings_dict.get("generic_client_secret", None),
generic_authorization_endpoint=decrypted_sso_settings_dict.get("generic_authorization_endpoint", None),
generic_token_endpoint=decrypted_sso_settings_dict.get("generic_token_endpoint", None),
generic_userinfo_endpoint=decrypted_sso_settings_dict.get("generic_userinfo_endpoint", None),
proxy_base_url=decrypted_sso_settings_dict.get("proxy_base_url", None),
user_email=decrypted_sso_settings_dict.get("user_email"),
ui_access_mode=decrypted_sso_settings_dict.get("ui_access_mode"),
)

# Get the schema for UI display
Expand Down Expand Up @@ -460,11 +467,26 @@ def get_env_value(env_var_name: str):
)
async def update_sso_settings(sso_config: SSOConfig):
"""
Update SSO configuration by saving to both environment variables and config file.
Update SSO configuration by saving to the dedicated SSO table.
"""
import os
import json

from litellm.proxy.proxy_server import proxy_config
from litellm.proxy.proxy_server import prisma_client, store_model_in_db, proxy_config

if prisma_client is None:
raise HTTPException(
status_code=500,
detail={"error": "Database not connected. Please connect a database."},
)

if store_model_in_db is not True:
raise HTTPException(
status_code=500,
detail={
"error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
},
)

# Update environment variables
env_var_mapping = {
Expand Down Expand Up @@ -495,39 +517,29 @@ async def update_sso_settings(sso_config: SSOConfig):
# Update environment variables in config and in memory
sso_data = sso_config.model_dump()
for field_name, value in sso_data.items():
if field_name == "user_email":
if value:
# Store user_email in general_settings instead of environment variables
config["general_settings"]["proxy_admin_email"] = value
else:
# Clear user_email if null/empty
config["general_settings"].pop("proxy_admin_email", None)
elif field_name == "ui_access_mode":
if field_name in env_var_mapping:
env_var_name = env_var_mapping[field_name]
if value:
config["general_settings"]["ui_access_mode"] = value
os.environ[env_var_name] = value
else:
# Clear ui_access_mode if null/empty
config["general_settings"].pop("ui_access_mode", None)
elif field_name in env_var_mapping and value:
env_var_name = env_var_mapping[field_name]
# Update in config
config["environment_variables"][env_var_name] = value
# Update in runtime environment
os.environ[env_var_name] = value
elif field_name in env_var_mapping:
# Clear environment variable if value is null/empty
env_var_name = env_var_mapping[field_name]
config["environment_variables"].pop(env_var_name, None)
os.environ.pop(env_var_name, None)

stored_config = config
if len(config["environment_variables"]) > 0:

stored_config["environment_variables"] = proxy_config._encrypt_env_variables(
environment_variables=config["environment_variables"]
)
# Save the updated config
await proxy_config.save_config(new_config=stored_config)
# Clear environment variable if value is null/empty
os.environ.pop(env_var_name, None)

encrypted_sso_data = proxy_config._encrypt_env_variables(environment_variables=sso_data)

# Save to dedicated SSO table
await prisma_client.db.litellm_ssoconfig.upsert(
where={"id": "sso_config"},
data={
"create": {
"id": "sso_config",
"sso_settings": json.dumps(encrypted_sso_data),
},
"update": {
"sso_settings": json.dumps(encrypted_sso_data),
},
},
)

return {
"message": "SSO settings updated successfully",
Expand Down
8 changes: 8 additions & 0 deletions schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,12 @@ model LiteLLM_SearchToolsTable {
search_tool_info Json?
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}

// SSO configuration table
model LiteLLM_SSOConfig {
id String @id @default("sso_config")
sso_settings Json
created_at DateTime @default(now())
updated_at DateTime @updatedAt
}
Loading
Loading