From 38dea156d4e5ecde652ecb055ae69553921dd4b3 Mon Sep 17 00:00:00 2001 From: Connor Johnstone Date: Thu, 5 Mar 2026 13:26:23 -0500 Subject: [PATCH] Added cli multi-artist --- README.md | 5 ++- src/main.rs | 98 ++++++++++++++++++++++++++++++------------------- src/playlist.rs | 9 +++-- 3 files changed, 70 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 47a0fe1..3210edf 100644 --- a/README.md +++ b/README.md @@ -42,12 +42,15 @@ Flags: drift build ``` -Opens an interactive picker to choose a seed artist. Or pass an artist name directly: +Opens an interactive picker to choose a seed artist. Or pass one or more artist names directly: ``` drift build "Radiohead" +drift build "Radiohead" "Portishead" ``` +With multiple seeds, artists similar to several seeds rank higher — the playlist blends their neighborhoods naturally. + Flags: - `-n 30` — number of tracks (default 20) - `-p 8` — popularity bias, 0–10 (default 5, higher = prefer popular tracks) diff --git a/src/main.rs b/src/main.rs index 97c8423..43ffd80 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,7 +7,7 @@ mod mpd; mod playlist; mod tui; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::env; use std::path::PathBuf; @@ -59,8 +59,8 @@ enum Command { /// Popularity bias (0=no preference, 10=heavy popular bias) #[arg(short, default_value_t = 5, value_parser = clap::value_parser!(u8).range(0..=10))] popularity: u8, - /// Artist name to seed (or pick interactively) - artist: Option, + /// Artist name(s) to seed (or pick interactively) + artists: Vec, }, } @@ -101,11 +101,11 @@ fn main() { Command::Index { verbose, force, directory } => { cmd_index(verbose, force, &directory); } - Command::Build { verbose, mpd, airsonic, shuffle, random, count, popularity, artist } => { + Command::Build { verbose, mpd, airsonic, shuffle, random, count, popularity, artists } => { let opts = BuildOptions { verbose, mpd, airsonic, shuffle, random, count, popularity_bias: popularity, }; - cmd_build(opts, artist.as_deref()); + cmd_build(opts, artists); } } } @@ -280,61 +280,81 @@ fn resolve_artist(artists: &[(String, String)], query: &str) -> Option<(String, None } -fn cmd_build(opts: BuildOptions, artist: Option<&str>) { +fn cmd_build(opts: BuildOptions, artist_args: Vec) { dotenvy::dotenv().ok(); let conn = db::open(&db_path()).expect("failed to open database"); - let artists = match db::get_all_artists(&conn) { + let all_artists = match db::get_all_artists(&conn) { Ok(a) => a, Err(e) => { eprintln!("DB error: {e}"); std::process::exit(1); } }; - if artists.is_empty() { + if all_artists.is_empty() { eprintln!("No artists in database. Run 'index' first."); std::process::exit(1); } - let (artist_mbid, seed_name) = if let Some(query) = artist { - match resolve_artist(&artists, query) { - Some((mbid, name)) => { - eprintln!("Matched: {name}"); - (mbid, name) - } - None => { - eprintln!("No artist matching \"{query}\""); - std::process::exit(1); - } - } - } else { - match tui::run_artist_picker(&artists) { - Some(selection) => selection, + let seeds: Vec<(String, String)> = if artist_args.is_empty() { + match tui::run_artist_picker(&all_artists) { + Some(selection) => vec![selection], None => std::process::exit(0), } + } else { + artist_args + .iter() + .map(|query| { + match resolve_artist(&all_artists, query) { + Some((mbid, name)) => { + eprintln!("Matched: {name}"); + (mbid, name) + } + None => { + eprintln!("No artist matching \"{query}\""); + std::process::exit(1); + } + } + }) + .collect() }; - build_playlist(&conn, &artist_mbid, &seed_name, &opts); + build_playlist(&conn, &seeds, &opts); } fn build_playlist( conn: &rusqlite::Connection, - artist_mbid: &str, - seed_name: &str, + seeds: &[(String, String)], opts: &BuildOptions, ) { - let similar = match db::get_available_similar_artists(conn, artist_mbid) { - Ok(a) => a, - Err(e) => { - eprintln!("DB error: {e}"); - std::process::exit(1); - } - }; + // Merge similar artists from all seeds: mbid → (name, total_score, count) + let mut merged: HashMap = HashMap::new(); - let mut artists: Vec<(String, String, f64)> = vec![ - (artist_mbid.to_string(), seed_name.to_string(), 1.0), - ]; - artists.extend(similar); + for (seed_mbid, seed_name) in seeds { + // Insert the seed itself with score 1.0 + let entry = merged.entry(seed_mbid.clone()).or_insert_with(|| (seed_name.clone(), 0.0, 0)); + entry.1 += 1.0; + entry.2 += 1; + + let similar = match db::get_available_similar_artists(conn, seed_mbid) { + Ok(a) => a, + Err(e) => { + eprintln!("DB error: {e}"); + std::process::exit(1); + } + }; + + for (mbid, name, score) in similar { + let entry = merged.entry(mbid).or_insert_with(|| (name, 0.0, 0)); + entry.1 += score; + entry.2 += 1; + } + } + + let artists: Vec<(String, String, f64)> = merged + .into_iter() + .map(|(mbid, (name, total, count))| (mbid, name, total / count as f64)) + .collect(); let scored = playlist::score_tracks(conn, &artists, opts.popularity_bias); @@ -355,7 +375,8 @@ fn build_playlist( }) .collect(); - let mut selected = playlist::generate_playlist(&candidates, opts.count, seed_name); + let seed_names: HashSet = seeds.iter().map(|(_, name)| name.clone()).collect(); + let mut selected = playlist::generate_playlist(&candidates, opts.count, &seed_names); if opts.random { selected.shuffle(&mut rand::rng()); @@ -365,7 +386,8 @@ fn build_playlist( let tracks: Vec = selected.into_iter().map(|c| c.path).collect(); - output_tracks(&tracks, opts, seed_name, conn); + let display_name = seeds.iter().map(|(_, name)| name.as_str()).collect::>().join(" + "); + output_tracks(&tracks, opts, &display_name, conn); } fn output_tracks( diff --git a/src/playlist.rs b/src/playlist.rs index da5bf67..b46d029 100644 --- a/src/playlist.rs +++ b/src/playlist.rs @@ -100,7 +100,7 @@ pub fn score_tracks( pub fn generate_playlist( candidates: &[Candidate], n: usize, - seed_name: &str, + seed_names: &HashSet, ) -> Vec { if candidates.is_empty() { return Vec::new(); @@ -132,7 +132,10 @@ pub fn generate_playlist( let artist_cap = n.div_ceil(divisor).max(1); while result.len() < n && !pool.is_empty() { - let seed_count = *artist_counts.get(seed_name).unwrap_or(&0); + let seed_count: usize = seed_names + .iter() + .map(|name| *artist_counts.get(name).unwrap_or(&0)) + .sum(); let remaining = n - result.len(); let seed_deficit = seed_min.saturating_sub(seed_count); let force_seed = seed_deficit > 0 && remaining <= seed_deficit; @@ -142,7 +145,7 @@ pub fn generate_playlist( .enumerate() .filter(|(_, c)| { if force_seed { - c.artist == seed_name + seed_names.contains(&c.artist) } else { *artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap }