Added config support

This commit is contained in:
Connor Johnstone
2026-03-18 15:14:32 -04:00
parent ff41233a96
commit 32b4b533c0
11 changed files with 381 additions and 259 deletions

View File

@@ -1,210 +1 @@
use std::path::PathBuf;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
#[serde(default = "default_library_path")]
pub library_path: PathBuf,
#[serde(default = "default_database_url")]
pub database_url: String,
#[serde(default = "default_download_path")]
pub download_path: PathBuf,
#[serde(default = "default_organization_format")]
pub organization_format: String,
/// Which secondary release group types to include. Empty = studio releases only.
/// Options: "Compilation", "Live", "Soundtrack", "Remix", "DJ-mix", "Demo", etc.
#[serde(default)]
pub allowed_secondary_types: Vec<String>,
#[serde(default)]
pub web: WebConfig,
#[serde(default)]
pub tagging: TaggingConfig,
#[serde(default)]
pub download: DownloadConfig,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebConfig {
#[serde(default = "default_port")]
pub port: u16,
#[serde(default = "default_bind")]
pub bind: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaggingConfig {
#[serde(default)]
pub auto_tag: bool,
#[serde(default = "default_true")]
pub write_tags: bool,
#[serde(default = "default_confidence")]
pub confidence: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DownloadConfig {
#[serde(default = "default_format")]
pub format: String,
#[serde(default = "default_search_source")]
pub search_source: String,
#[serde(default)]
pub cookies_path: Option<PathBuf>,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
library_path: default_library_path(),
database_url: default_database_url(),
download_path: default_download_path(),
organization_format: default_organization_format(),
allowed_secondary_types: vec![], // empty = studio only
web: WebConfig::default(),
tagging: TaggingConfig::default(),
download: DownloadConfig::default(),
}
}
}
impl Default for WebConfig {
fn default() -> Self {
Self {
port: default_port(),
bind: default_bind(),
}
}
}
impl Default for TaggingConfig {
fn default() -> Self {
Self {
auto_tag: false,
write_tags: true,
confidence: default_confidence(),
}
}
}
impl Default for DownloadConfig {
fn default() -> Self {
Self {
format: default_format(),
search_source: default_search_source(),
cookies_path: None,
}
}
}
fn default_library_path() -> PathBuf {
dirs::audio_dir().unwrap_or_else(|| PathBuf::from("~/Music"))
}
fn default_database_url() -> String {
let data_dir = dirs::data_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("shanty");
std::fs::create_dir_all(&data_dir).ok();
format!("sqlite://{}?mode=rwc", data_dir.join("shanty.db").display())
}
fn default_download_path() -> PathBuf {
let dir = dirs::data_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("shanty")
.join("downloads");
std::fs::create_dir_all(&dir).ok();
dir
}
fn default_organization_format() -> String {
shanty_org::DEFAULT_FORMAT.to_string()
}
fn default_port() -> u16 {
8085
}
fn default_bind() -> String {
"0.0.0.0".to_string()
}
fn default_confidence() -> f64 {
0.8
}
fn default_format() -> String {
"opus".to_string()
}
fn default_search_source() -> String {
"ytmusic".to_string()
}
fn default_true() -> bool {
true
}
impl AppConfig {
/// Load config from file, falling back to defaults.
pub fn load(path: Option<&str>) -> Self {
let config_path = path
.map(PathBuf::from)
.or_else(|| std::env::var("SHANTY_CONFIG").ok().map(PathBuf::from))
.unwrap_or_else(|| {
dirs::config_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("shanty")
.join("config.yaml")
});
if config_path.exists() {
match std::fs::read_to_string(&config_path) {
Ok(contents) => match serde_yaml::from_str(&contents) {
Ok(config) => {
tracing::info!(path = %config_path.display(), "loaded config");
return Self::apply_env_overrides(config);
}
Err(e) => {
tracing::warn!(path = %config_path.display(), error = %e, "failed to parse config, using defaults");
}
},
Err(e) => {
tracing::warn!(path = %config_path.display(), error = %e, "failed to read config, using defaults");
}
}
} else {
tracing::info!(path = %config_path.display(), "no config file found, using defaults");
}
Self::apply_env_overrides(AppConfig::default())
}
fn apply_env_overrides(mut config: Self) -> Self {
if let Ok(v) = std::env::var("SHANTY_DATABASE_URL") {
config.database_url = v;
}
if let Ok(v) = std::env::var("SHANTY_LIBRARY_PATH") {
config.library_path = PathBuf::from(v);
}
if let Ok(v) = std::env::var("SHANTY_DOWNLOAD_PATH") {
config.download_path = PathBuf::from(v);
}
if let Ok(v) = std::env::var("SHANTY_WEB_PORT") {
if let Ok(port) = v.parse() {
config.web.port = port;
}
}
if let Ok(v) = std::env::var("SHANTY_WEB_BIND") {
config.web.bind = v;
}
config
}
}
pub use shanty_config::*;

View File

@@ -58,11 +58,13 @@ async fn main() -> anyhow::Result<()> {
let bind = format!("{}:{}", config.web.bind, config.web.port);
tracing::info!(bind = %bind, "starting server");
let config_path = cli.config.clone();
let state = web::Data::new(AppState {
db,
mb_client,
search,
config,
config: std::sync::Arc::new(tokio::sync::RwLock::new(config)),
config_path: config_path,
tasks: TaskManager::new(),
});

View File

@@ -301,7 +301,7 @@ pub async fn enrich_artist(
// Fetch release groups and filter by allowed secondary types
let all_release_groups = state.search.get_release_groups(&mbid).await
.map_err(|e| ApiError::Internal(e.to_string()))?;
let allowed = &state.config.allowed_secondary_types;
let allowed = state.config.read().await.allowed_secondary_types.clone();
let release_groups: Vec<_> = all_release_groups
.into_iter()
.filter(|rg| {

View File

@@ -88,13 +88,14 @@ async fn trigger_process(
let tid = task_id.clone();
tokio::spawn(async move {
let cookies = state.config.download.cookies_path.clone();
let format: shanty_dl::AudioFormat = state.config.download.format.parse().unwrap_or(shanty_dl::AudioFormat::Opus);
let source: shanty_dl::SearchSource = state.config.download.search_source.parse().unwrap_or(shanty_dl::SearchSource::YouTubeMusic);
let rate = if cookies.is_some() { 1800 } else { 450 };
let cfg = state.config.read().await.clone();
let cookies = cfg.download.cookies_path.clone();
let format: shanty_dl::AudioFormat = cfg.download.format.parse().unwrap_or(shanty_dl::AudioFormat::Opus);
let source: shanty_dl::SearchSource = cfg.download.search_source.parse().unwrap_or(shanty_dl::SearchSource::YouTubeMusic);
let rate = if cookies.is_some() { cfg.download.rate_limit_auth } else { cfg.download.rate_limit };
let backend = shanty_dl::YtDlpBackend::new(rate, source, cookies.clone());
let backend_config = shanty_dl::BackendConfig {
output_dir: state.config.download_path.clone(),
output_dir: cfg.download_path.clone(),
format,
cookies_path: cookies,
};

View File

@@ -1,8 +1,10 @@
use actix_web::{web, HttpResponse};
use serde::Deserialize;
use shanty_db::entities::download_queue::DownloadStatus;
use shanty_db::queries;
use crate::config::AppConfig;
use crate::error::ApiError;
use crate::routes::artists::enrich_all_watched_artists;
use crate::state::AppState;
@@ -16,7 +18,11 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
.service(web::resource("/tasks/{id}").route(web::get().to(get_task)))
.service(web::resource("/watchlist").route(web::get().to(list_watchlist)))
.service(web::resource("/watchlist/{id}").route(web::delete().to(remove_watchlist)))
.service(web::resource("/config").route(web::get().to(get_config)));
.service(
web::resource("/config")
.route(web::get().to(get_config))
.route(web::put().to(save_config)),
);
}
async fn get_status(
@@ -28,13 +34,11 @@ async fn get_status(
let failed_items = queries::downloads::list(state.db.conn(), Some(DownloadStatus::Failed)).await?;
let tasks = state.tasks.list();
// Combine active/recent download items for the dashboard
let mut queue_items = Vec::new();
queue_items.extend(downloading_items.iter().cloned());
queue_items.extend(pending_items.iter().cloned());
queue_items.extend(failed_items.iter().take(5).cloned());
// Tracks needing metadata (tagging queue)
let needs_tagging = queries::tracks::get_needing_metadata(state.db.conn()).await?;
Ok(HttpResponse::Ok().json(serde_json::json!({
@@ -61,11 +65,12 @@ async fn trigger_index(
let tid = task_id.clone();
tokio::spawn(async move {
let cfg = state.config.read().await.clone();
state.tasks.update_progress(&tid, 0, 0, "Scanning library...");
let scan_config = shanty_index::ScanConfig {
root: state.config.library_path.clone(),
root: cfg.library_path.clone(),
dry_run: false,
concurrency: 4,
concurrency: cfg.indexing.concurrency,
};
match shanty_index::run_scan(state.db.conn(), &scan_config).await {
Ok(stats) => state.tasks.complete(&tid, format!("{stats}")),
@@ -84,6 +89,7 @@ async fn trigger_tag(
let tid = task_id.clone();
tokio::spawn(async move {
let cfg = state.config.read().await.clone();
state.tasks.update_progress(&tid, 0, 0, "Preparing tagger...");
let mb = match shanty_tag::MusicBrainzClient::new() {
Ok(c) => c,
@@ -94,8 +100,8 @@ async fn trigger_tag(
};
let tag_config = shanty_tag::TagConfig {
dry_run: false,
write_tags: state.config.tagging.write_tags,
confidence: state.config.tagging.confidence,
write_tags: cfg.tagging.write_tags,
confidence: cfg.tagging.confidence,
};
state.tasks.update_progress(&tid, 0, 0, "Tagging tracks...");
match shanty_tag::run_tagging(state.db.conn(), &mb, &tag_config, None).await {
@@ -115,16 +121,16 @@ async fn trigger_organize(
let tid = task_id.clone();
tokio::spawn(async move {
let cfg = state.config.read().await.clone();
state.tasks.update_progress(&tid, 0, 0, "Organizing files...");
let org_config = shanty_org::OrgConfig {
target_dir: state.config.library_path.clone(),
format: state.config.organization_format.clone(),
target_dir: cfg.library_path.clone(),
format: cfg.organization_format.clone(),
dry_run: false,
copy: false,
};
match shanty_org::organize_from_db(state.db.conn(), &org_config).await {
Ok(stats) => {
// Promote all Downloaded wanted items to Owned
let promoted = queries::wanted::promote_downloaded_to_owned(state.db.conn())
.await
.unwrap_or(0);
@@ -134,7 +140,6 @@ async fn trigger_organize(
format!("{stats}")
};
state.tasks.complete(&tid, msg);
// Refresh artist data in background
let _ = enrich_all_watched_artists(&state).await;
}
Err(e) => state.tasks.fail(&tid, e.to_string()),
@@ -147,7 +152,6 @@ async fn trigger_organize(
async fn trigger_pipeline(
state: web::Data<AppState>,
) -> Result<HttpResponse, ApiError> {
// Register all 6 pipeline tasks as Pending
let sync_id = state.tasks.register_pending("sync");
let download_id = state.tasks.register_pending("download");
let index_id = state.tasks.register_pending("index");
@@ -167,6 +171,8 @@ async fn trigger_pipeline(
let state = state.clone();
tokio::spawn(async move {
let cfg = state.config.read().await.clone();
// Step 1: Sync
state.tasks.start(&sync_id);
state.tasks.update_progress(&sync_id, 0, 0, "Syncing watchlist to download queue...");
@@ -177,13 +183,13 @@ async fn trigger_pipeline(
// Step 2: Download
state.tasks.start(&download_id);
let cookies = state.config.download.cookies_path.clone();
let format: shanty_dl::AudioFormat = state.config.download.format.parse().unwrap_or(shanty_dl::AudioFormat::Opus);
let source: shanty_dl::SearchSource = state.config.download.search_source.parse().unwrap_or(shanty_dl::SearchSource::YouTubeMusic);
let rate = if cookies.is_some() { 1800 } else { 450 };
let cookies = cfg.download.cookies_path.clone();
let format: shanty_dl::AudioFormat = cfg.download.format.parse().unwrap_or(shanty_dl::AudioFormat::Opus);
let source: shanty_dl::SearchSource = cfg.download.search_source.parse().unwrap_or(shanty_dl::SearchSource::YouTubeMusic);
let rate = if cookies.is_some() { cfg.download.rate_limit_auth } else { cfg.download.rate_limit };
let backend = shanty_dl::YtDlpBackend::new(rate, source, cookies.clone());
let backend_config = shanty_dl::BackendConfig {
output_dir: state.config.download_path.clone(),
output_dir: cfg.download_path.clone(),
format,
cookies_path: cookies,
};
@@ -204,9 +210,9 @@ async fn trigger_pipeline(
state.tasks.start(&index_id);
state.tasks.update_progress(&index_id, 0, 0, "Scanning library...");
let scan_config = shanty_index::ScanConfig {
root: state.config.library_path.clone(),
root: cfg.library_path.clone(),
dry_run: false,
concurrency: 4,
concurrency: cfg.indexing.concurrency,
};
match shanty_index::run_scan(state.db.conn(), &scan_config).await {
Ok(stats) => state.tasks.complete(&index_id, format!("{stats}")),
@@ -220,8 +226,8 @@ async fn trigger_pipeline(
Ok(mb) => {
let tag_config = shanty_tag::TagConfig {
dry_run: false,
write_tags: state.config.tagging.write_tags,
confidence: state.config.tagging.confidence,
write_tags: cfg.tagging.write_tags,
confidence: cfg.tagging.confidence,
};
match shanty_tag::run_tagging(state.db.conn(), &mb, &tag_config, None).await {
Ok(stats) => state.tasks.complete(&tag_id, format!("{stats}")),
@@ -235,8 +241,8 @@ async fn trigger_pipeline(
state.tasks.start(&organize_id);
state.tasks.update_progress(&organize_id, 0, 0, "Organizing files...");
let org_config = shanty_org::OrgConfig {
target_dir: state.config.library_path.clone(),
format: state.config.organization_format.clone(),
target_dir: cfg.library_path.clone(),
format: cfg.organization_format.clone(),
dry_run: false,
copy: false,
};
@@ -254,7 +260,7 @@ async fn trigger_pipeline(
Err(e) => state.tasks.fail(&organize_id, e.to_string()),
}
// Step 6: Enrich — refresh cached artist totals for the library page
// Step 6: Enrich
state.tasks.start(&enrich_id);
state.tasks.update_progress(&enrich_id, 0, 0, "Refreshing artist data...");
match enrich_all_watched_artists(&state).await {
@@ -296,5 +302,30 @@ async fn remove_watchlist(
async fn get_config(
state: web::Data<AppState>,
) -> Result<HttpResponse, ApiError> {
Ok(HttpResponse::Ok().json(&state.config))
let config = state.config.read().await;
Ok(HttpResponse::Ok().json(&*config))
}
#[derive(Deserialize)]
struct SaveConfigRequest {
#[serde(flatten)]
config: AppConfig,
}
async fn save_config(
state: web::Data<AppState>,
body: web::Json<SaveConfigRequest>,
) -> Result<HttpResponse, ApiError> {
let new_config = body.into_inner().config;
// Persist to YAML
new_config.save(state.config_path.as_deref())
.map_err(|e| ApiError::Internal(e))?;
// Update in-memory config
let mut config = state.config.write().await;
*config = new_config.clone();
tracing::info!("config updated via API");
Ok(HttpResponse::Ok().json(&new_config))
}

View File

@@ -1,3 +1,6 @@
use std::sync::Arc;
use tokio::sync::RwLock;
use shanty_db::Database;
use shanty_search::MusicBrainzSearch;
use shanty_tag::MusicBrainzClient;
@@ -9,6 +12,7 @@ pub struct AppState {
pub db: Database,
pub mb_client: MusicBrainzClient,
pub search: MusicBrainzSearch,
pub config: AppConfig,
pub config: Arc<RwLock<AppConfig>>,
pub config_path: Option<String>,
pub tasks: TaskManager,
}