Added the playlist generator
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -3229,8 +3229,12 @@ dependencies = [
|
||||
name = "shanty-playlist"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"clap",
|
||||
"chrono",
|
||||
"rand 0.9.2",
|
||||
"sea-orm",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"shanty-data",
|
||||
"shanty-db",
|
||||
"thiserror",
|
||||
"tracing",
|
||||
@@ -3341,6 +3345,7 @@ dependencies = [
|
||||
"shanty-dl",
|
||||
"shanty-index",
|
||||
"shanty-org",
|
||||
"shanty-playlist",
|
||||
"shanty-search",
|
||||
"shanty-tag",
|
||||
"shanty-watch",
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::time::Instant;
|
||||
|
||||
/// A simple rate limiter that enforces a minimum interval between requests.
|
||||
/// Can be cloned (via Arc) to share across multiple clients.
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimiter {
|
||||
last_request: Mutex<Instant>,
|
||||
last_request: Arc<Mutex<Instant>>,
|
||||
interval: Duration,
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(interval: Duration) -> Self {
|
||||
Self {
|
||||
last_request: Mutex::new(Instant::now() - interval),
|
||||
last_request: Arc::new(Mutex::new(Instant::now() - interval)),
|
||||
interval,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::error::DataResult;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::{DataError, DataResult};
|
||||
use crate::http::{build_client, urlencoded};
|
||||
use crate::traits::ArtistBioFetcher;
|
||||
use crate::types::ArtistInfo;
|
||||
use crate::traits::{ArtistBioFetcher, SimilarArtistFetcher};
|
||||
use crate::types::{ArtistInfo, PopularTrack, SimilarArtist};
|
||||
|
||||
const USER_AGENT: &str = "Shanty/0.1.0 (shanty-music-app)";
|
||||
|
||||
@@ -51,6 +53,211 @@ impl ArtistBioFetcher for LastFmBioFetcher {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Similar artist fetcher (ported from drift) ---
|
||||
|
||||
const BASE_URL: &str = "https://ws.audioscrobbler.com/2.0/";
|
||||
|
||||
/// Fetches similar artists and top tracks from Last.fm.
|
||||
pub struct LastFmSimilarFetcher {
|
||||
api_key: String,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl LastFmSimilarFetcher {
|
||||
pub fn new(api_key: String) -> DataResult<Self> {
|
||||
let client = build_client(USER_AGENT, 30)?;
|
||||
Ok(Self { api_key, client })
|
||||
}
|
||||
|
||||
/// Normalize Unicode hyphens/quotes to ASCII, then URL-encode.
|
||||
fn normalize_name(name: &str) -> String {
|
||||
name.replace(
|
||||
[
|
||||
'\u{2010}', '\u{2011}', '\u{2012}', '\u{2013}', '\u{2014}', '\u{2015}',
|
||||
],
|
||||
"-",
|
||||
)
|
||||
.replace(['\u{2018}', '\u{2019}'], "'")
|
||||
}
|
||||
|
||||
/// Fetch a URL and return the body, or None if Last.fm returns an API error.
|
||||
async fn fetch_or_none(&self, url: &str) -> DataResult<Option<String>> {
|
||||
let resp = match self.client.get(url).send().await {
|
||||
Ok(r) if r.status().is_success() => r,
|
||||
Ok(r) => {
|
||||
tracing::debug!(status = %r.status(), url = url, "Last.fm non-success");
|
||||
return Ok(None);
|
||||
}
|
||||
Err(e) => return Err(DataError::Http(e)),
|
||||
};
|
||||
let body = resp.text().await?;
|
||||
// Check for Last.fm error response
|
||||
if serde_json::from_str::<LfmApiError>(&body).is_ok() {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(body))
|
||||
}
|
||||
|
||||
/// Fetch by artist name.
|
||||
async fn fetch_by_name(
|
||||
&self,
|
||||
method: &str,
|
||||
artist_name: &str,
|
||||
extra_params: &str,
|
||||
) -> DataResult<Option<String>> {
|
||||
let name = Self::normalize_name(artist_name);
|
||||
let encoded = urlencoded(&name);
|
||||
let url = format!(
|
||||
"{}?method={}&artist={}&api_key={}{}&format=json",
|
||||
BASE_URL, method, encoded, self.api_key, extra_params
|
||||
);
|
||||
self.fetch_or_none(&url).await
|
||||
}
|
||||
|
||||
/// Try MBID lookup then name lookup, returning whichever yields more results.
|
||||
async fn dual_lookup<T>(
|
||||
&self,
|
||||
method: &str,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
extra_params: &str,
|
||||
parse: fn(&str) -> DataResult<Vec<T>>,
|
||||
) -> DataResult<Vec<T>> {
|
||||
let mbid_results = if let Some(mbid) = mbid {
|
||||
let url = format!(
|
||||
"{}?method={}&mbid={}&api_key={}{}&format=json",
|
||||
BASE_URL, method, mbid, self.api_key, extra_params
|
||||
);
|
||||
match self.fetch_or_none(&url).await? {
|
||||
Some(body) => parse(&body).unwrap_or_default(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
let name_results = match self
|
||||
.fetch_by_name(method, artist_name, extra_params)
|
||||
.await?
|
||||
{
|
||||
Some(body) => parse(&body).unwrap_or_default(),
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
if name_results.len() > mbid_results.len() {
|
||||
Ok(name_results)
|
||||
} else {
|
||||
Ok(mbid_results)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_similar_artists(body: &str) -> DataResult<Vec<SimilarArtist>> {
|
||||
let resp: LfmSimilarArtistsResponse = serde_json::from_str(body)?;
|
||||
Ok(resp
|
||||
.similarartists
|
||||
.artist
|
||||
.into_iter()
|
||||
.map(|a| {
|
||||
let mbid = a.mbid.filter(|s| !s.is_empty());
|
||||
SimilarArtist {
|
||||
name: a.name,
|
||||
mbid,
|
||||
match_score: a.match_score.parse().unwrap_or(0.0),
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
|
||||
fn parse_top_tracks(body: &str) -> DataResult<Vec<PopularTrack>> {
|
||||
let resp: LfmTopTracksResponse = serde_json::from_str(body)?;
|
||||
Ok(resp
|
||||
.toptracks
|
||||
.track
|
||||
.into_iter()
|
||||
.map(|t| PopularTrack {
|
||||
name: t.name,
|
||||
mbid: t.mbid.filter(|s| !s.is_empty()),
|
||||
playcount: t.playcount.parse().unwrap_or(0),
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl SimilarArtistFetcher for LastFmSimilarFetcher {
|
||||
async fn get_similar_artists(
|
||||
&self,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> DataResult<Vec<SimilarArtist>> {
|
||||
self.dual_lookup(
|
||||
"artist.getSimilar",
|
||||
artist_name,
|
||||
mbid,
|
||||
"&limit=500",
|
||||
Self::parse_similar_artists,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn get_top_tracks(
|
||||
&self,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> DataResult<Vec<PopularTrack>> {
|
||||
self.dual_lookup(
|
||||
"artist.getTopTracks",
|
||||
artist_name,
|
||||
mbid,
|
||||
"&limit=1000",
|
||||
Self::parse_top_tracks,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
// Last.fm JSON response structs
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmApiError {
|
||||
#[allow(dead_code)]
|
||||
error: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmSimilarArtistsResponse {
|
||||
similarartists: LfmSimilarArtistsWrapper,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmSimilarArtistsWrapper {
|
||||
artist: Vec<LfmArtistEntry>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmArtistEntry {
|
||||
name: String,
|
||||
mbid: Option<String>,
|
||||
#[serde(rename = "match")]
|
||||
match_score: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmTopTracksResponse {
|
||||
toptracks: LfmTopTracksWrapper,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmTopTracksWrapper {
|
||||
track: Vec<LfmTrackEntry>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LfmTrackEntry {
|
||||
name: String,
|
||||
mbid: Option<String>,
|
||||
playcount: String,
|
||||
}
|
||||
|
||||
/// Strip HTML tags from a string with a simple approach.
|
||||
fn strip_html_tags(s: &str) -> String {
|
||||
let mut result = String::with_capacity(s.len());
|
||||
|
||||
@@ -12,7 +12,7 @@ pub mod wikipedia;
|
||||
pub use coverart::CoverArtArchiveFetcher;
|
||||
pub use error::{DataError, DataResult};
|
||||
pub use fanarttv::FanartTvFetcher;
|
||||
pub use lastfm::LastFmBioFetcher;
|
||||
pub use lastfm::{LastFmBioFetcher, LastFmSimilarFetcher};
|
||||
pub use lrclib::LrclibFetcher;
|
||||
pub use musicbrainz::MusicBrainzFetcher;
|
||||
pub use traits::*;
|
||||
|
||||
@@ -21,14 +21,21 @@ pub struct MusicBrainzFetcher {
|
||||
|
||||
impl MusicBrainzFetcher {
|
||||
pub fn new() -> DataResult<Self> {
|
||||
Self::with_limiter(RateLimiter::new(RATE_LIMIT))
|
||||
}
|
||||
|
||||
/// Create a fetcher that shares a rate limiter with other MB clients.
|
||||
pub fn with_limiter(limiter: RateLimiter) -> DataResult<Self> {
|
||||
let client = reqwest::Client::builder()
|
||||
.user_agent(USER_AGENT)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()?;
|
||||
Ok(Self {
|
||||
client,
|
||||
limiter: RateLimiter::new(RATE_LIMIT),
|
||||
})
|
||||
Ok(Self { client, limiter })
|
||||
}
|
||||
|
||||
/// Get a clone of the rate limiter for sharing with other MB clients.
|
||||
pub fn limiter(&self) -> RateLimiter {
|
||||
self.limiter.clone()
|
||||
}
|
||||
|
||||
async fn get_json<T: serde::de::DeserializeOwned>(&self, url: &str) -> DataResult<T> {
|
||||
@@ -84,6 +91,24 @@ impl MusicBrainzFetcher {
|
||||
urls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve a release-group MBID to a release MBID (first release in the group).
|
||||
pub async fn resolve_release_from_group(&self, release_group_mbid: &str) -> DataResult<String> {
|
||||
let url = format!("{BASE_URL}/release?release-group={release_group_mbid}&fmt=json&limit=1");
|
||||
let resp: serde_json::Value = self.get_json(&url).await?;
|
||||
|
||||
resp.get("releases")
|
||||
.and_then(|r| r.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|r| r.get("id"))
|
||||
.and_then(|id| id.as_str())
|
||||
.map(String::from)
|
||||
.ok_or_else(|| {
|
||||
DataError::Other(format!(
|
||||
"no releases for release-group {release_group_mbid}"
|
||||
))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl MetadataFetcher for MusicBrainzFetcher {
|
||||
|
||||
@@ -2,8 +2,8 @@ use std::future::Future;
|
||||
|
||||
use crate::error::DataResult;
|
||||
use crate::types::{
|
||||
ArtistInfo, ArtistSearchResult, DiscographyEntry, LyricsResult, RecordingDetails,
|
||||
RecordingMatch, ReleaseGroupEntry, ReleaseMatch, ReleaseTrack,
|
||||
ArtistInfo, ArtistSearchResult, DiscographyEntry, LyricsResult, PopularTrack, RecordingDetails,
|
||||
RecordingMatch, ReleaseGroupEntry, ReleaseMatch, ReleaseTrack, SimilarArtist,
|
||||
};
|
||||
|
||||
/// Trait for metadata lookup backends. MusicBrainz is the default implementation;
|
||||
@@ -86,6 +86,21 @@ pub trait LyricsFetcher: Send + Sync {
|
||||
) -> impl Future<Output = DataResult<LyricsResult>> + Send;
|
||||
}
|
||||
|
||||
/// Fetches similar artists and top tracks from an external source (e.g. Last.fm).
|
||||
pub trait SimilarArtistFetcher: Send + Sync {
|
||||
fn get_similar_artists(
|
||||
&self,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> impl Future<Output = DataResult<Vec<SimilarArtist>>> + Send;
|
||||
|
||||
fn get_top_tracks(
|
||||
&self,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> impl Future<Output = DataResult<Vec<PopularTrack>>> + Send;
|
||||
}
|
||||
|
||||
/// Fetches cover art URLs for releases.
|
||||
pub trait CoverArtFetcher: Send + Sync {
|
||||
fn get_cover_art_url(&self, release_id: &str) -> Option<String>;
|
||||
|
||||
@@ -111,6 +111,22 @@ pub struct ReleaseTrack {
|
||||
pub duration_ms: Option<u64>,
|
||||
}
|
||||
|
||||
/// A similar artist returned by Last.fm or another provider.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SimilarArtist {
|
||||
pub name: String,
|
||||
pub mbid: Option<String>,
|
||||
pub match_score: f64,
|
||||
}
|
||||
|
||||
/// A popular/top track for an artist from Last.fm.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PopularTrack {
|
||||
pub name: String,
|
||||
pub mbid: Option<String>,
|
||||
pub playcount: u64,
|
||||
}
|
||||
|
||||
/// Result from a lyrics lookup.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LyricsResult {
|
||||
|
||||
Submodule shanty-db updated: 8a1435d9e9...f03f8f0362
@@ -7,7 +7,11 @@ description = "Playlist generation for Shanty"
|
||||
|
||||
[dependencies]
|
||||
shanty-db = { path = "../shanty-db" }
|
||||
clap = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
shanty-data = { path = "../shanty-data" }
|
||||
sea-orm = { version = "1", features = ["sqlx-sqlite", "runtime-tokio-native-tls"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
rand = "0.9"
|
||||
tracing = "0.1"
|
||||
thiserror = "2"
|
||||
chrono = "0.4"
|
||||
|
||||
@@ -2,3 +2,12 @@
|
||||
//!
|
||||
//! Generates playlists based on the indexed music library using strategies like
|
||||
//! similar artists, genre matching, smart rules, and weighted random selection.
|
||||
|
||||
pub mod ordering;
|
||||
pub mod scoring;
|
||||
pub mod selection;
|
||||
pub mod strategies;
|
||||
pub mod types;
|
||||
|
||||
pub use strategies::{PlaylistError, genre_based, random, similar_artists, smart, to_m3u};
|
||||
pub use types::{Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SmartRules};
|
||||
|
||||
57
shanty-playlist/src/ordering.rs
Normal file
57
shanty-playlist/src/ordering.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use rand::prelude::*;
|
||||
|
||||
use crate::types::Candidate;
|
||||
|
||||
/// Reorder tracks so that artists are evenly spread out.
|
||||
/// Greedily picks from the artist with the most remaining tracks,
|
||||
/// avoiding back-to-back repeats when possible.
|
||||
/// Ported faithfully from drift's interleave_artists().
|
||||
pub fn interleave_artists(tracks: Vec<Candidate>) -> Vec<Candidate> {
|
||||
let mut rng = rand::rng();
|
||||
|
||||
let mut by_artist: BTreeMap<String, Vec<Candidate>> = BTreeMap::new();
|
||||
for track in tracks {
|
||||
by_artist
|
||||
.entry(track.artist.clone())
|
||||
.or_default()
|
||||
.push(track);
|
||||
}
|
||||
for group in by_artist.values_mut() {
|
||||
group.shuffle(&mut rng);
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
let mut last_artist: Option<String> = None;
|
||||
|
||||
while !by_artist.is_empty() {
|
||||
let mut artists: Vec<String> = by_artist.keys().cloned().collect();
|
||||
artists.sort_by(|a, b| by_artist[b].len().cmp(&by_artist[a].len()));
|
||||
|
||||
let pick = artists
|
||||
.iter()
|
||||
.find(|a| last_artist.as_ref() != Some(a))
|
||||
.or(artists.first())
|
||||
.cloned()
|
||||
.unwrap();
|
||||
|
||||
let group = by_artist.get_mut(&pick).unwrap();
|
||||
let track = group.pop().unwrap();
|
||||
if group.is_empty() {
|
||||
by_artist.remove(&pick);
|
||||
}
|
||||
|
||||
last_artist = Some(pick);
|
||||
result.push(track);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Full random shuffle.
|
||||
pub fn shuffle(mut tracks: Vec<Candidate>) -> Vec<Candidate> {
|
||||
let mut rng = rand::rng();
|
||||
tracks.shuffle(&mut rng);
|
||||
tracks
|
||||
}
|
||||
151
shanty-playlist/src/scoring.rs
Normal file
151
shanty-playlist/src/scoring.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use shanty_data::PopularTrack;
|
||||
use shanty_db::entities::track::Model as Track;
|
||||
|
||||
use crate::types::ScoredTrack;
|
||||
|
||||
/// Popularity exponent curve (0-10 scale).
|
||||
/// 0 = no preference, 10 = heavy popular bias.
|
||||
const POPULARITY_EXPONENTS: [f64; 11] = [
|
||||
0.0, 0.06, 0.17, 0.33, 0.67, 1.30, 1.50, 1.70, 1.94, 2.22, 2.50,
|
||||
];
|
||||
|
||||
/// Score all tracks for the given artists, returning scored tracks for ranking.
|
||||
///
|
||||
/// `artists` is a list of (mbid_or_name, display_name, similarity_score) tuples.
|
||||
/// `tracks_by_artist` maps artist identifier -> their local tracks.
|
||||
/// `top_tracks_by_artist` maps artist identifier -> their Last.fm top tracks.
|
||||
pub fn score_tracks(
|
||||
artists: &[(String, String, f64)],
|
||||
tracks_by_artist: &HashMap<String, Vec<Track>>,
|
||||
top_tracks_by_artist: &HashMap<String, Vec<PopularTrack>>,
|
||||
popularity_bias: u8,
|
||||
) -> Vec<ScoredTrack> {
|
||||
let bias = popularity_bias.min(10) as usize;
|
||||
let mut scored = Vec::new();
|
||||
|
||||
for (artist_key, name, match_score) in artists {
|
||||
let local_tracks = match tracks_by_artist.get(artist_key) {
|
||||
Some(t) if !t.is_empty() => t,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let top_tracks = top_tracks_by_artist
|
||||
.get(artist_key)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Build playcount lookup by lowercase name
|
||||
let playcount_by_name: HashMap<String, u64> = top_tracks
|
||||
.iter()
|
||||
.map(|t| (t.name.to_lowercase(), t.playcount))
|
||||
.collect();
|
||||
|
||||
let max_playcount = playcount_by_name
|
||||
.values()
|
||||
.copied()
|
||||
.max()
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
|
||||
for track in local_tracks {
|
||||
let title_lower = track.title.as_ref().map(|t| t.to_lowercase());
|
||||
|
||||
let playcount = title_lower
|
||||
.as_ref()
|
||||
.and_then(|t| playcount_by_name.get(t).copied())
|
||||
.or_else(|| {
|
||||
track
|
||||
.musicbrainz_id
|
||||
.as_ref()
|
||||
.and_then(|id| playcount_by_name.get(id).copied())
|
||||
});
|
||||
|
||||
// If we have popularity data, require a match; otherwise assign uniform score
|
||||
let (popularity, similarity, score) = if !playcount_by_name.is_empty() {
|
||||
let Some(playcount) = playcount else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let popularity = if playcount > 0 {
|
||||
(playcount as f64 / max_playcount as f64).powf(POPULARITY_EXPONENTS[bias])
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
let similarity = (match_score.exp()) / std::f64::consts::E;
|
||||
let score = similarity * popularity;
|
||||
(popularity, similarity, score)
|
||||
} else {
|
||||
// No top tracks data — use uniform scoring based on similarity only
|
||||
let similarity = (match_score.exp()) / std::f64::consts::E;
|
||||
(1.0, similarity, similarity)
|
||||
};
|
||||
|
||||
scored.push(ScoredTrack {
|
||||
track_id: track.id,
|
||||
file_path: track.file_path.clone(),
|
||||
title: track.title.clone(),
|
||||
artist: name.clone(),
|
||||
artist_mbid: track
|
||||
.artist_id
|
||||
.map(|_| artist_key.clone())
|
||||
.or_else(|| Some(artist_key.clone())),
|
||||
album: track.album.clone(),
|
||||
duration: track.duration,
|
||||
score,
|
||||
popularity,
|
||||
similarity,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Step 1: Cap tracks per artist based on popularity bias
|
||||
let mut by_artist: HashMap<String, Vec<ScoredTrack>> = HashMap::new();
|
||||
for t in scored {
|
||||
let key = t.artist_mbid.clone().unwrap_or_else(|| t.artist.clone());
|
||||
by_artist.entry(key).or_default().push(t);
|
||||
}
|
||||
|
||||
let cap = if popularity_bias == 0 {
|
||||
None
|
||||
} else {
|
||||
let b = popularity_bias as f64;
|
||||
let c = if b <= 5.0 {
|
||||
90.0 - 12.8 * b
|
||||
} else {
|
||||
26.0 - 3.2 * (b - 5.0)
|
||||
};
|
||||
Some((c.round() as usize).max(1))
|
||||
};
|
||||
|
||||
for group in by_artist.values_mut() {
|
||||
group.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
if let Some(cap) = cap {
|
||||
group.truncate(cap);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Normalize so each artist's total weight = their similarity
|
||||
let similarity_map: HashMap<&str, f64> = artists
|
||||
.iter()
|
||||
.map(|(key, _, sim)| (key.as_str(), *sim))
|
||||
.collect();
|
||||
|
||||
for (key, group) in &mut by_artist {
|
||||
let total: f64 = group.iter().map(|t| t.score).sum();
|
||||
if total > 0.0 {
|
||||
let sim = similarity_map.get(key.as_str()).copied().unwrap_or(1.0);
|
||||
for t in group.iter_mut() {
|
||||
t.score *= sim / total;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
by_artist.into_values().flatten().collect()
|
||||
}
|
||||
95
shanty-playlist/src/selection.rs
Normal file
95
shanty-playlist/src/selection.rs
Normal file
@@ -0,0 +1,95 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use rand::distr::weighted::WeightedIndex;
|
||||
use rand::prelude::*;
|
||||
|
||||
use crate::types::Candidate;
|
||||
|
||||
/// Weighted random sampling with per-artist caps and seed enforcement.
|
||||
/// Ported faithfully from drift's generate_playlist().
|
||||
pub fn generate_playlist(
|
||||
candidates: &[Candidate],
|
||||
n: usize,
|
||||
seed_names: &HashSet<String>,
|
||||
) -> Vec<Candidate> {
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut rng = rand::rng();
|
||||
let mut pool: Vec<&Candidate> = candidates.iter().collect();
|
||||
let mut result: Vec<Candidate> = Vec::new();
|
||||
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
let seed_min = (n / 10).max(1);
|
||||
|
||||
let distinct_artists: usize = {
|
||||
let mut seen = HashSet::new();
|
||||
for c in &pool {
|
||||
seen.insert(&c.artist);
|
||||
}
|
||||
seen.len()
|
||||
};
|
||||
|
||||
let divisor = match distinct_artists {
|
||||
1 => 1,
|
||||
2 => 2,
|
||||
3 => 3,
|
||||
4 => 3,
|
||||
5 => 4,
|
||||
_ => 5,
|
||||
};
|
||||
let artist_cap = n.div_ceil(divisor).max(1);
|
||||
|
||||
while result.len() < n && !pool.is_empty() {
|
||||
let seed_count: usize = seed_names
|
||||
.iter()
|
||||
.map(|name| *artist_counts.get(name).unwrap_or(&0))
|
||||
.sum();
|
||||
let remaining = n - result.len();
|
||||
let seed_deficit = seed_min.saturating_sub(seed_count);
|
||||
let force_seed = seed_deficit > 0 && remaining <= seed_deficit;
|
||||
|
||||
let eligible: Vec<usize> = pool
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, c)| {
|
||||
if force_seed {
|
||||
seed_names.contains(&c.artist)
|
||||
} else {
|
||||
*artist_counts.get(&c.artist).unwrap_or(&0) < artist_cap
|
||||
}
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
let fallback_indices: Vec<usize> = (0..pool.len()).collect();
|
||||
let indices: &[usize] = if eligible.is_empty() {
|
||||
&fallback_indices
|
||||
} else {
|
||||
&eligible
|
||||
};
|
||||
|
||||
let weights: Vec<f64> = indices.iter().map(|&i| pool[i].score.max(0.001)).collect();
|
||||
let dist = match WeightedIndex::new(&weights) {
|
||||
Ok(d) => d,
|
||||
Err(_) => break,
|
||||
};
|
||||
|
||||
let picked = indices[dist.sample(&mut rng)];
|
||||
let track = pool.remove(picked);
|
||||
*artist_counts.entry(track.artist.clone()).or_insert(0) += 1;
|
||||
result.push(Candidate {
|
||||
score: track.score,
|
||||
artist: track.artist.clone(),
|
||||
artist_mbid: track.artist_mbid.clone(),
|
||||
track_id: track.track_id,
|
||||
file_path: track.file_path.clone(),
|
||||
title: track.title.clone(),
|
||||
album: track.album.clone(),
|
||||
duration: track.duration,
|
||||
});
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
490
shanty-playlist/src/strategies.rs
Normal file
490
shanty-playlist/src/strategies.rs
Normal file
@@ -0,0 +1,490 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use sea_orm::DatabaseConnection;
|
||||
use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher};
|
||||
use shanty_db::queries;
|
||||
|
||||
use crate::ordering;
|
||||
use crate::scoring;
|
||||
use crate::selection;
|
||||
use crate::types::*;
|
||||
|
||||
/// Cache TTL: 7 days in seconds.
|
||||
const CACHE_TTL: i64 = 7 * 24 * 3600;
|
||||
|
||||
/// Generate a playlist based on similar artists (the primary strategy).
|
||||
///
|
||||
/// Flow:
|
||||
/// 1. For each seed artist: resolve MBID from DB
|
||||
/// 2. Fetch similar artists (check cache first, else call Last.fm, cache result)
|
||||
/// 3. Merge multi-seed: accumulate scores, normalize by seed count
|
||||
/// 4. Filter: only keep artists that have tracks in the local library
|
||||
/// 5. Score all tracks
|
||||
/// 6. Select via weighted sampling
|
||||
/// 7. Order (interleave or shuffle)
|
||||
pub async fn similar_artists(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
seed_artists: Vec<String>,
|
||||
count: usize,
|
||||
popularity_bias: u8,
|
||||
ordering: &str,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if seed_artists.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
"at least one seed artist is required".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let num_seeds = seed_artists.len() as f64;
|
||||
|
||||
// Merge similar artists from all seeds: key -> (name, total_score)
|
||||
let mut merged: HashMap<String, (String, f64)> = HashMap::new();
|
||||
// Track resolved seed names for enforcement (use DB names, not raw input)
|
||||
let mut resolved_seed_names: HashSet<String> = HashSet::new();
|
||||
|
||||
for seed in &seed_artists {
|
||||
// Resolve the seed artist: try name lookup in DB
|
||||
let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?;
|
||||
resolved_seed_names.insert(artist_name.clone());
|
||||
|
||||
// Insert the seed itself with score 1.0
|
||||
let key = artist_mbid
|
||||
.clone()
|
||||
.unwrap_or_else(|| artist_name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(key)
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += 1.0;
|
||||
|
||||
// Fetch similar artists (cached or fresh)
|
||||
let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref())
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
for sa in similar {
|
||||
let sa_key = sa.mbid.clone().unwrap_or_else(|| sa.name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(sa_key)
|
||||
.or_insert_with(|| (sa.name.clone(), 0.0));
|
||||
entry.1 += sa.match_score;
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize scores by seed count
|
||||
let artists: Vec<(String, String, f64)> = merged
|
||||
.into_iter()
|
||||
.map(|(key, (name, total))| (key, name, total / num_seeds))
|
||||
.collect();
|
||||
|
||||
// Build track and top-track maps for scoring
|
||||
let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> =
|
||||
HashMap::new();
|
||||
let mut top_tracks_by_artist: HashMap<String, Vec<PopularTrack>> = HashMap::new();
|
||||
|
||||
for (key, name, _) in &artists {
|
||||
// Get local tracks for this artist
|
||||
let local = get_artist_tracks(conn, key, name).await;
|
||||
if local.is_empty() {
|
||||
continue;
|
||||
}
|
||||
tracks_by_artist.insert(key.clone(), local);
|
||||
|
||||
// Get top tracks (cached or fresh)
|
||||
let top = fetch_cached_top_tracks(conn, fetcher, name, Some(key.as_str()))
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
if !top.is_empty() {
|
||||
top_tracks_by_artist.insert(key.clone(), top);
|
||||
}
|
||||
}
|
||||
|
||||
// Score
|
||||
let scored = scoring::score_tracks(
|
||||
&artists,
|
||||
&tracks_by_artist,
|
||||
&top_tracks_by_artist,
|
||||
popularity_bias,
|
||||
);
|
||||
|
||||
// Convert to candidates
|
||||
let candidates: Vec<Candidate> = scored
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: t.score,
|
||||
artist: t.artist,
|
||||
artist_mbid: t.artist_mbid,
|
||||
track_id: t.track_id,
|
||||
file_path: t.file_path,
|
||||
title: t.title,
|
||||
album: t.album,
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Select (use resolved DB names for seed enforcement, not raw input)
|
||||
let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names);
|
||||
|
||||
// Order
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "similar".to_string(),
|
||||
resolved_seeds: resolved_seed_names.into_iter().collect(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a genre-based playlist.
|
||||
pub async fn genre_based(
|
||||
conn: &DatabaseConnection,
|
||||
genres: Vec<String>,
|
||||
count: usize,
|
||||
ordering: &str,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if genres.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
"at least one genre is required".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut all_tracks = Vec::new();
|
||||
for genre in &genres {
|
||||
let tracks = queries::tracks::get_by_genre(conn, genre)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
|
||||
// Deduplicate by track ID
|
||||
let mut seen = HashSet::new();
|
||||
all_tracks.retain(|t| seen.insert(t.id));
|
||||
|
||||
// Convert to candidates with uniform scoring
|
||||
let candidates: Vec<Candidate> = all_tracks
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: 1.0,
|
||||
artist: t.artist.clone().unwrap_or_default(),
|
||||
artist_mbid: t.musicbrainz_id.clone(),
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
album: t.album.clone(),
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "genre".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a random playlist.
|
||||
pub async fn random(
|
||||
conn: &DatabaseConnection,
|
||||
count: usize,
|
||||
no_repeat_artist: bool,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
let tracks = queries::tracks::get_random(conn, count as u64 * 2)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
let mut seen_artists: HashSet<String> = HashSet::new();
|
||||
|
||||
for t in tracks {
|
||||
if result.len() >= count {
|
||||
break;
|
||||
}
|
||||
let artist = t.artist.clone().unwrap_or_default();
|
||||
if no_repeat_artist && !artist.is_empty() && !seen_artists.insert(artist.clone()) {
|
||||
continue;
|
||||
}
|
||||
result.push(PlaylistTrack {
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
artist: t.artist.clone(),
|
||||
album: t.album.clone(),
|
||||
score: 0.0,
|
||||
duration: t.duration,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: result,
|
||||
strategy: "random".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate a smart playlist based on rules.
|
||||
pub async fn smart(
|
||||
conn: &DatabaseConnection,
|
||||
rules: SmartRules,
|
||||
count: usize,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
let mut all_tracks: Vec<shanty_db::entities::track::Model> = Vec::new();
|
||||
|
||||
// Genre filter
|
||||
if !rules.genres.is_empty() {
|
||||
for genre in &rules.genres {
|
||||
let tracks = queries::tracks::get_by_genre(conn, genre)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
}
|
||||
|
||||
// Artist filter
|
||||
if !rules.artists.is_empty() {
|
||||
for artist_name in &rules.artists {
|
||||
let tracks = queries::tracks::get_by_artist_name(conn, artist_name)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
all_tracks.extend(tracks);
|
||||
}
|
||||
}
|
||||
|
||||
// Recently added filter
|
||||
if let Some(days) = rules.added_within_days {
|
||||
let tracks = queries::tracks::get_recent(conn, days, 10000)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
if all_tracks.is_empty() {
|
||||
all_tracks = tracks;
|
||||
} else {
|
||||
let recent_ids: HashSet<i32> = tracks.iter().map(|t| t.id).collect();
|
||||
all_tracks.retain(|t| recent_ids.contains(&t.id));
|
||||
}
|
||||
}
|
||||
|
||||
// Year range filter
|
||||
if let Some((min_year, max_year)) = rules.year_range {
|
||||
all_tracks.retain(|t| {
|
||||
t.year
|
||||
.map(|y| y >= min_year && y <= max_year)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
}
|
||||
|
||||
// If no filters were specified, get random tracks
|
||||
if rules.genres.is_empty()
|
||||
&& rules.artists.is_empty()
|
||||
&& rules.added_within_days.is_none()
|
||||
&& rules.year_range.is_none()
|
||||
{
|
||||
all_tracks = queries::tracks::get_random(conn, count as u64 * 2)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
}
|
||||
|
||||
// Deduplicate
|
||||
let mut seen = HashSet::new();
|
||||
all_tracks.retain(|t| seen.insert(t.id));
|
||||
|
||||
// Convert to candidates
|
||||
let candidates: Vec<Candidate> = all_tracks
|
||||
.into_iter()
|
||||
.map(|t| Candidate {
|
||||
score: 1.0,
|
||||
artist: t.artist.clone().unwrap_or_default(),
|
||||
artist_mbid: t.musicbrainz_id.clone(),
|
||||
track_id: t.id,
|
||||
file_path: t.file_path.clone(),
|
||||
title: t.title.clone(),
|
||||
album: t.album.clone(),
|
||||
duration: t.duration,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let ordered = ordering::interleave_artists(selected);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
strategy: "smart".to_string(),
|
||||
resolved_seeds: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Generate an M3U playlist string from tracks.
|
||||
pub fn to_m3u(tracks: &[PlaylistTrack]) -> String {
|
||||
let mut out = String::from("#EXTM3U\n");
|
||||
for t in tracks {
|
||||
let duration = t.duration.unwrap_or(0.0) as i64;
|
||||
let artist = t.artist.as_deref().unwrap_or("Unknown");
|
||||
let title = t.title.as_deref().unwrap_or("Unknown");
|
||||
out.push_str(&format!("#EXTINF:{duration},{artist} - {title}\n"));
|
||||
out.push_str(&t.file_path);
|
||||
out.push('\n');
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Apply the chosen ordering mode to selected candidates.
|
||||
fn apply_ordering(mut candidates: Vec<Candidate>, mode: &str) -> Vec<Candidate> {
|
||||
match mode {
|
||||
"random" => ordering::shuffle(candidates),
|
||||
"score" => {
|
||||
candidates.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
candidates
|
||||
}
|
||||
_ => ordering::interleave_artists(candidates), // "interleave" is default
|
||||
}
|
||||
}
|
||||
|
||||
// --- Helper functions ---
|
||||
|
||||
/// Resolve an artist name or MBID to (name, optional_mbid).
|
||||
async fn resolve_artist(
|
||||
conn: &DatabaseConnection,
|
||||
query: &str,
|
||||
) -> Result<(String, Option<String>), PlaylistError> {
|
||||
// Try as MBID first (if it looks like a UUID)
|
||||
if query.len() == 36 && query.contains('-') {
|
||||
let tracks = queries::tracks::get_by_artist_mbid(conn, query)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::Db(e.to_string()))?;
|
||||
if let Some(t) = tracks.first() {
|
||||
let name = t.artist.clone().unwrap_or_else(|| query.to_string());
|
||||
return Ok((name, Some(query.to_string())));
|
||||
}
|
||||
}
|
||||
|
||||
// Try by name in DB
|
||||
if let Ok(Some(artist)) = queries::artists::find_by_name(conn, query).await {
|
||||
return Ok((artist.name, artist.musicbrainz_id));
|
||||
}
|
||||
|
||||
// Fall back to using the query as the name
|
||||
Ok((query.to_string(), None))
|
||||
}
|
||||
|
||||
/// Get local tracks for an artist by key (MBID) or name.
|
||||
async fn get_artist_tracks(
|
||||
conn: &DatabaseConnection,
|
||||
key: &str,
|
||||
name: &str,
|
||||
) -> Vec<shanty_db::entities::track::Model> {
|
||||
// Try by MBID first
|
||||
if let Ok(tracks) = queries::tracks::get_by_artist_mbid(conn, key).await
|
||||
&& !tracks.is_empty()
|
||||
{
|
||||
return tracks;
|
||||
}
|
||||
|
||||
// Try by name
|
||||
queries::tracks::get_by_artist_name(conn, name)
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Fetch similar artists from cache or Last.fm.
|
||||
async fn fetch_cached_similar(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> Result<Vec<SimilarArtist>, PlaylistError> {
|
||||
let cache_key = format!(
|
||||
"lastfm_similar:{}",
|
||||
mbid.unwrap_or(&artist_name.to_lowercase())
|
||||
);
|
||||
|
||||
// Check cache
|
||||
if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await
|
||||
&& let Ok(cached) = serde_json::from_str::<Vec<SimilarArtist>>(&json)
|
||||
{
|
||||
tracing::debug!(artist = artist_name, "using cached similar artists");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
// Fetch from provider
|
||||
let similar = fetcher
|
||||
.get_similar_artists(artist_name, mbid)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::FetchError(e.to_string()))?;
|
||||
|
||||
// Cache the result
|
||||
if let Ok(json) = serde_json::to_string(&similar) {
|
||||
let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await;
|
||||
}
|
||||
|
||||
Ok(similar)
|
||||
}
|
||||
|
||||
/// Fetch top tracks from cache or Last.fm.
|
||||
async fn fetch_cached_top_tracks(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
artist_name: &str,
|
||||
mbid: Option<&str>,
|
||||
) -> Result<Vec<PopularTrack>, PlaylistError> {
|
||||
let cache_key = format!(
|
||||
"lastfm_toptracks:{}",
|
||||
mbid.unwrap_or(&artist_name.to_lowercase())
|
||||
);
|
||||
|
||||
// Check cache
|
||||
if let Ok(Some(json)) = queries::cache::get(conn, &cache_key).await
|
||||
&& let Ok(cached) = serde_json::from_str::<Vec<PopularTrack>>(&json)
|
||||
{
|
||||
tracing::debug!(artist = artist_name, "using cached top tracks");
|
||||
return Ok(cached);
|
||||
}
|
||||
|
||||
// Fetch from provider
|
||||
let tracks = fetcher
|
||||
.get_top_tracks(artist_name, mbid)
|
||||
.await
|
||||
.map_err(|e| PlaylistError::FetchError(e.to_string()))?;
|
||||
|
||||
// Cache the result
|
||||
if let Ok(json) = serde_json::to_string(&tracks) {
|
||||
let _ = queries::cache::set(conn, &cache_key, "lastfm", &json, CACHE_TTL).await;
|
||||
}
|
||||
|
||||
Ok(tracks)
|
||||
}
|
||||
|
||||
/// Convert candidates to playlist tracks.
|
||||
fn candidates_to_tracks(candidates: Vec<Candidate>) -> Vec<PlaylistTrack> {
|
||||
candidates
|
||||
.into_iter()
|
||||
.map(|c| PlaylistTrack {
|
||||
track_id: c.track_id,
|
||||
file_path: c.file_path,
|
||||
title: c.title,
|
||||
artist: Some(c.artist),
|
||||
album: c.album,
|
||||
score: c.score,
|
||||
duration: c.duration,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Playlist generation error.
#[derive(Debug, thiserror::Error)]
pub enum PlaylistError {
    /// The request was invalid; the message describes the problem.
    #[error("invalid input: {0}")]
    InvalidInput(String),

    /// A database query failed; carries the stringified DB error.
    #[error("database error: {0}")]
    Db(String),

    /// A remote metadata fetch (similar artists / top tracks) failed.
    #[error("fetch error: {0}")]
    FetchError(String),
}
|
||||
93
shanty-playlist/src/types.rs
Normal file
93
shanty-playlist/src/types.rs
Normal file
@@ -0,0 +1,93 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to generate a playlist.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlaylistRequest {
    /// Generation strategy name (e.g. "smart", "random").
    pub strategy: String,
    /// Seed artist names or MBIDs for similarity-based strategies.
    #[serde(default)]
    pub seed_artists: Vec<String>,
    /// Genre names to filter by.
    #[serde(default)]
    pub genres: Vec<String>,
    /// Number of tracks to generate (default: 50).
    #[serde(default = "default_count")]
    pub count: usize,
    /// Popularity bias knob (default: 5; 0 disables the scoring cap).
    #[serde(default = "default_popularity_bias")]
    pub popularity_bias: u8,
    /// Ordering mode: "score" (by score), "interleave" (spread artists), "random" (full shuffle).
    #[serde(default = "default_ordering")]
    pub ordering: String,
    /// Extra rules for the "smart" strategy.
    #[serde(default)]
    pub rules: Option<SmartRules>,
}
|
||||
|
||||
/// Serde default for `PlaylistRequest::count`.
fn default_count() -> usize {
    50
}

/// Serde default for `PlaylistRequest::popularity_bias`
/// (0 disables the per-artist cap in scoring).
fn default_popularity_bias() -> u8 {
    5
}

/// Serde default for `PlaylistRequest::ordering`.
fn default_ordering() -> String {
    "interleave".to_string()
}
|
||||
|
||||
/// Result of generating a playlist.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlaylistResult {
    /// The generated tracks, in final playback order.
    pub tracks: Vec<PlaylistTrack>,
    /// Name of the strategy that produced this playlist (e.g. "smart", "random").
    pub strategy: String,
    /// Resolved seed artist names (for display — may differ from input query).
    #[serde(default)]
    pub resolved_seeds: Vec<String>,
}
|
||||
|
||||
/// A track in a generated playlist.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlaylistTrack {
    /// Database id of the track row.
    pub track_id: i32,
    /// Path to the audio file (written verbatim into M3U output).
    pub file_path: String,
    /// Track title, if tagged.
    pub title: Option<String>,
    /// Artist name, if known.
    pub artist: Option<String>,
    /// Album name, if tagged.
    pub album: Option<String>,
    /// Selection score (0.0 for strategies that don't score, e.g. random).
    pub score: f64,
    /// Duration in seconds — truncated to whole seconds for M3U `#EXTINF`.
    pub duration: Option<f64>,
}
|
||||
|
||||
/// A weighted candidate for playlist selection (internal).
#[derive(Debug, Clone)]
pub struct Candidate {
    /// Selection weight; the "score" ordering mode sorts descending on this.
    pub score: f64,
    /// Artist display name (empty string when the track has no artist tag).
    pub artist: String,
    /// MusicBrainz id carried over from the track row, when known.
    pub artist_mbid: Option<String>,
    /// Database id of the track row.
    pub track_id: i32,
    /// Path to the audio file.
    pub file_path: String,
    /// Track title, if tagged.
    pub title: Option<String>,
    /// Album name, if tagged.
    pub album: Option<String>,
    /// Duration in seconds, if known.
    pub duration: Option<f64>,
}
|
||||
|
||||
/// A scored track before candidate conversion (internal).
#[derive(Debug, Clone)]
pub struct ScoredTrack {
    /// Database id of the track row.
    pub track_id: i32,
    /// Path to the audio file.
    pub file_path: String,
    /// Track title, if tagged.
    pub title: Option<String>,
    /// Artist display name.
    pub artist: String,
    /// MusicBrainz id, when known.
    pub artist_mbid: Option<String>,
    /// Album name, if tagged.
    pub album: Option<String>,
    /// Duration in seconds, if known.
    pub duration: Option<f64>,
    /// Combined final score.
    pub score: f64,
    /// Popularity component of the score (derived from play counts).
    pub popularity: f64,
    /// Similarity component of the score (derived from the artist match score).
    pub similarity: f64,
}
|
||||
|
||||
/// Rules for the "smart" playlist strategy.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SmartRules {
    /// Genres to include (union across entries).
    #[serde(default)]
    pub genres: Vec<String>,
    /// Only keep tracks added within the last N days.
    pub added_within_days: Option<u32>,
    /// Inclusive (min, max) release-year range; tracks without a year are excluded.
    pub year_range: Option<(i32, i32)>,
    /// Artist names to include (union across entries).
    #[serde(default)]
    pub artists: Vec<String>,
}
|
||||
333
shanty-playlist/tests/unit.rs
Normal file
333
shanty-playlist/tests/unit.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use chrono::Utc;
|
||||
use shanty_data::PopularTrack;
|
||||
use shanty_db::entities::track::Model as Track;
|
||||
use shanty_playlist::ordering::{interleave_artists, shuffle};
|
||||
use shanty_playlist::scoring::score_tracks;
|
||||
use shanty_playlist::selection::generate_playlist;
|
||||
use shanty_playlist::types::Candidate;
|
||||
|
||||
fn make_track(id: i32, title: &str, artist_id: Option<i32>, mbid: Option<&str>) -> Track {
|
||||
let now = Utc::now().naive_utc();
|
||||
Track {
|
||||
id,
|
||||
file_path: format!("/music/{title}.opus"),
|
||||
title: Some(title.to_string()),
|
||||
artist: Some("Test Artist".to_string()),
|
||||
album: Some("Test Album".to_string()),
|
||||
album_artist: None,
|
||||
track_number: Some(id),
|
||||
disc_number: None,
|
||||
duration: Some(200.0),
|
||||
codec: None,
|
||||
bitrate: None,
|
||||
genre: None,
|
||||
year: None,
|
||||
musicbrainz_id: mbid.map(String::from),
|
||||
file_size: 1000,
|
||||
fingerprint: None,
|
||||
artist_id,
|
||||
album_id: None,
|
||||
added_at: now,
|
||||
updated_at: now,
|
||||
file_mtime: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn make_candidate(id: i32, artist: &str, score: f64) -> Candidate {
|
||||
Candidate {
|
||||
score,
|
||||
artist: artist.to_string(),
|
||||
artist_mbid: Some(format!("mbid-{artist}")),
|
||||
track_id: id,
|
||||
file_path: format!("/music/{id}.opus"),
|
||||
title: Some(format!("Track {id}")),
|
||||
album: None,
|
||||
duration: Some(200.0),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Scoring tests ---
|
||||
|
||||
// One artist with 5 local tracks, 3 of which appear in the top-tracks list:
// only the matching 3 should survive scoring, with playcount driving score.
#[test]
fn test_score_tracks_basic() {
    let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 1.0)];

    let tracks: Vec<Track> = (1..=5)
        .map(|i| make_track(i, &format!("Song {i}"), Some(1), None))
        .collect();

    let top_tracks = vec![
        PopularTrack {
            name: "Song 1".to_string(),
            mbid: None,
            playcount: 1000,
        },
        PopularTrack {
            name: "Song 2".to_string(),
            mbid: None,
            playcount: 500,
        },
        PopularTrack {
            name: "Song 3".to_string(),
            mbid: None,
            playcount: 100,
        },
    ];

    let mut tracks_map = HashMap::new();
    tracks_map.insert("artist-1".to_string(), tracks);

    let mut top_map = HashMap::new();
    top_map.insert("artist-1".to_string(), top_tracks);

    let scored = score_tracks(&artists, &tracks_map, &top_map, 5);

    // Should have 3 tracks (only ones matching top tracks)
    assert_eq!(scored.len(), 3);

    // Higher playcount should yield higher score
    let scores: Vec<f64> = scored.iter().map(|t| t.score).collect();
    let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    assert!(max_score > 0.0);
}
|
||||
|
||||
// Without top-track data, every local track is kept and all get an equal
// score (similarity component only).
#[test]
fn test_score_tracks_no_top_tracks_uses_uniform() {
    let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 0.8)];

    let tracks: Vec<Track> = (1..=3)
        .map(|i| make_track(i, &format!("Song {i}"), Some(1), None))
        .collect();

    let mut tracks_map = HashMap::new();
    tracks_map.insert("artist-1".to_string(), tracks);

    let top_map = HashMap::new(); // No top tracks

    let scored = score_tracks(&artists, &tracks_map, &top_map, 5);

    // All 3 tracks should be included with uniform scoring
    assert_eq!(scored.len(), 3);
    // All should have the same score (similarity only)
    let first_score = scored[0].score;
    for t in &scored {
        // Float equality within epsilon — scores come from the same inputs.
        assert!((t.score - first_score).abs() < 1e-10);
    }
}
|
||||
|
||||
// The popularity_bias argument caps how many tracks a single artist
// contributes: bias N keeps at most N tracks, bias 0 disables the cap.
#[test]
fn test_score_tracks_per_artist_cap() {
    let artists = vec![("artist-1".to_string(), "Artist One".to_string(), 1.0)];

    // 50 tracks
    let tracks: Vec<Track> = (1..=50)
        .map(|i| make_track(i, &format!("Song {i}"), Some(1), None))
        .collect();

    // Matching top tracks with strictly decreasing playcounts.
    let top_tracks: Vec<PopularTrack> = (1..=50)
        .map(|i| PopularTrack {
            name: format!("Song {i}"),
            mbid: None,
            playcount: (50 - i + 1) as u64 * 100,
        })
        .collect();

    let mut tracks_map = HashMap::new();
    tracks_map.insert("artist-1".to_string(), tracks);
    let mut top_map = HashMap::new();
    top_map.insert("artist-1".to_string(), top_tracks);

    // bias 10 → cap = 10
    let scored = score_tracks(&artists, &tracks_map, &top_map, 10);
    assert!(scored.len() <= 10);

    // bias 0 → no cap
    let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0);
    assert_eq!(scored_no_cap.len(), 50);
}
|
||||
|
||||
// The third tuple element is the artist match score; a larger match score
// must translate into a larger `similarity` on the scored tracks.
// NOTE(review): `ScoredTrack.artist` is expected to carry the seed artist's
// name ("High"/"Low"), not the track's own tag ("Test Artist").
#[test]
fn test_similarity_transform() {
    // Higher match_score should produce higher similarity
    let artists = vec![
        ("high".to_string(), "High".to_string(), 0.9),
        ("low".to_string(), "Low".to_string(), 0.1),
    ];

    let track_high = make_track(1, "Song", Some(1), None);
    let track_low = make_track(2, "Song", Some(2), None);

    let mut tracks_map = HashMap::new();
    tracks_map.insert("high".to_string(), vec![track_high]);
    tracks_map.insert("low".to_string(), vec![track_low]);

    let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5);
    assert_eq!(scored.len(), 2);

    let high_score = scored
        .iter()
        .find(|t| t.artist == "High")
        .unwrap()
        .similarity;
    let low_score = scored
        .iter()
        .find(|t| t.artist == "Low")
        .unwrap()
        .similarity;
    assert!(high_score > low_score);
}
|
||||
|
||||
// --- Selection tests ---
|
||||
|
||||
// 20 candidates spread over 4 artists; requesting 10 yields exactly 10.
#[test]
fn test_generate_playlist_basic() {
    let mut pool = Vec::new();
    for id in 1..=20 {
        pool.push(make_candidate(id, &format!("Artist{}", id % 4), 1.0));
    }

    let no_seeds = HashSet::new();
    let playlist = generate_playlist(&pool, 10, &no_seeds);

    assert_eq!(playlist.len(), 10);
}
|
||||
|
||||
// More candidates than requested: output is truncated to `count`.
#[test]
fn test_generate_playlist_respects_count() {
    let pool: Vec<Candidate> = (1..=5).map(|id| make_candidate(id, "Artist", 1.0)).collect();
    let no_seeds = HashSet::new();
    assert_eq!(generate_playlist(&pool, 3, &no_seeds).len(), 3);
}
|
||||
|
||||
// Fewer candidates than requested: every candidate is used, no padding.
#[test]
fn test_generate_playlist_not_more_than_available() {
    let pool: Vec<Candidate> = (1..=3).map(|id| make_candidate(id, "Artist", 1.0)).collect();
    let no_seeds = HashSet::new();
    assert_eq!(generate_playlist(&pool, 100, &no_seeds).len(), 3);
}
|
||||
|
||||
// An empty candidate pool yields an empty playlist, whatever the count.
#[test]
fn test_generate_playlist_empty_candidates() {
    let empty_pool: Vec<Candidate> = Vec::new();
    let no_seeds = HashSet::new();
    let playlist = generate_playlist(&empty_pool, 10, &no_seeds);
    assert!(playlist.is_empty());
}
|
||||
|
||||
// A prolific artist must not crowd out a minor one: both should appear in
// the output, and the playlist is still filled to the requested length.
#[test]
fn test_generate_playlist_per_artist_cap() {
    // 20 tracks from one artist, 5 from another
    let mut candidates: Vec<Candidate> = (1..=20)
        .map(|i| make_candidate(i, "Prolific", 1.0))
        .collect();
    candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0)));

    let seeds = HashSet::new();
    let result = generate_playlist(&candidates, 15, &seeds);

    let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count();
    let minor_count = result.iter().filter(|c| c.artist == "Minor").count();

    // With 2 artists and n=15, cap = ceil(15/2) = 8.
    // Minor only has 5 tracks, so Prolific fills the rest via fallback.
    // Key check: both artists are represented.
    assert!(
        minor_count >= 1,
        "Minor should get at least 1 track, got {minor_count}"
    );
    assert!(
        prolific_count >= 1,
        "Prolific should get at least 1 track, got {prolific_count}"
    );
    assert_eq!(result.len(), 15);
}
|
||||
|
||||
// Seed artists must be represented even when their candidates score far
// below everyone else's.
#[test]
fn test_generate_playlist_seed_enforcement() {
    // Many tracks from "Other" with high scores, few from "Seed" with low scores
    let mut candidates: Vec<Candidate> =
        (1..=50).map(|i| make_candidate(i, "Other", 10.0)).collect();
    candidates.push(make_candidate(51, "Seed", 0.01));
    candidates.push(make_candidate(52, "Seed", 0.01));

    let mut seeds = HashSet::new();
    seeds.insert("Seed".to_string());

    let result = generate_playlist(&candidates, 10, &seeds);
    let seed_count = result.iter().filter(|c| c.artist == "Seed").count();

    // seed_min = (10/10).max(1) = 1, so at least 1 seed track
    assert!(
        seed_count >= 1,
        "Expected at least 1 seed track, got {seed_count}"
    );
}
|
||||
|
||||
// --- Ordering tests ---
|
||||
|
||||
// Two artists with three tracks each: a perfect interleaving exists, so no
// adjacent pair may share an artist.
#[test]
fn test_interleave_no_back_to_back() {
    let tracks: Vec<Candidate> = vec![
        make_candidate(1, "A", 1.0),
        make_candidate(2, "A", 1.0),
        make_candidate(3, "A", 1.0),
        make_candidate(4, "B", 1.0),
        make_candidate(5, "B", 1.0),
        make_candidate(6, "B", 1.0),
    ];

    let result = interleave_artists(tracks);
    assert_eq!(result.len(), 6);

    // Check no back-to-back same artist. The failure message previously said
    // "at positions" without printing any position — include the indices.
    for (i, window) in result.windows(2).enumerate() {
        assert_ne!(
            window[0].artist,
            window[1].artist,
            "Back-to-back: {} at positions {} and {}",
            window[0].artist,
            i,
            i + 1
        );
    }
}
|
||||
|
||||
// A single artist can't be spread out; interleaving must still keep every
// track (back-to-back repeats are unavoidable here).
#[test]
fn test_interleave_single_artist() {
    let solo: Vec<Candidate> = (1..=5).map(|id| make_candidate(id, "Solo", 1.0)).collect();
    assert_eq!(interleave_artists(solo).len(), 5);
}
|
||||
|
||||
// Shuffling is a permutation: the track count must be unchanged.
#[test]
fn test_shuffle_preserves_count() {
    let pool: Vec<Candidate> = (1..=10).map(|id| make_candidate(id, "Artist", 1.0)).collect();
    let shuffled = shuffle(pool);
    assert_eq!(shuffled.len(), 10);
}
|
||||
|
||||
// Five artists with four tracks each: enough variety that interleaving
// should produce zero adjacent same-artist pairs.
#[test]
fn test_interleave_many_artists() {
    let mut tracks = Vec::new();
    for artist_idx in 0..5 {
        for track_idx in 0..4 {
            let id = artist_idx * 4 + track_idx + 1;
            tracks.push(make_candidate(id, &format!("Artist{artist_idx}"), 1.0));
        }
    }

    let result = interleave_artists(tracks);
    assert_eq!(result.len(), 20);

    // Count back-to-back violations (should be 0 with 5 artists)
    let violations = result
        .windows(2)
        .filter(|w| w[0].artist == w[1].artist)
        .count();
    assert_eq!(violations, 0, "Expected no back-to-back with 5 artists");
}
|
||||
Submodule shanty-search updated: cbd0243516...b39dd6cc8e
Submodule shanty-web updated: 9d6c0e31c1...ea6a6410f3
@@ -58,7 +58,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let db = Database::new(&config.database_url).await?;
|
||||
|
||||
let mb_client = MusicBrainzFetcher::new()?;
|
||||
let search = MusicBrainzSearch::new()?;
|
||||
let search = MusicBrainzSearch::with_limiter(mb_client.limiter())?;
|
||||
let wiki_fetcher = WikipediaFetcher::new()?;
|
||||
|
||||
let bind = format!("{}:{}", config.web.bind, config.web.port);
|
||||
@@ -77,6 +77,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
scheduler: tokio::sync::Mutex::new(shanty_web::state::SchedulerInfo {
|
||||
next_pipeline: None,
|
||||
next_monitor: None,
|
||||
skip_pipeline: false,
|
||||
skip_monitor: false,
|
||||
}),
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user