Compare commits
28 Commits
144393756e
..
main
| Author | SHA1 | Date | |
|---|---|---|---|
| e1beb5f12b | |||
| 219404263c | |||
| ada336d945 | |||
| 494c8ddecb | |||
| cab680ad5d | |||
| 31c9785ed2 | |||
| e198643c57 | |||
| beb80b8770 | |||
| 3aecde5d0b | |||
| 01365dbb80 | |||
| b2f030b52d | |||
| f77cea47b1 | |||
| cb4105564c | |||
| 68c0f477dd | |||
| 3153518c57 | |||
| b3fd844c15 | |||
| f14ddea805 | |||
| 0d22da6aaa | |||
| a63d72ba48 | |||
| 8dcf40fe7c | |||
| 5e44275bff | |||
| f5fd450aaf | |||
| e3fc3789ce | |||
| 42ab414f83 | |||
| 79440617ba | |||
| 295380d5ad | |||
| 1f68208547 | |||
| 9df13a1c3a |
Generated
+1
@@ -3465,6 +3465,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"clap",
|
||||
"dirs",
|
||||
"futures-util",
|
||||
"hex",
|
||||
"md-5",
|
||||
"quick-xml",
|
||||
|
||||
@@ -59,6 +59,13 @@ impl HybridMusicBrainzFetcher {
|
||||
self.remote.get_artist_by_mbid(mbid).await
|
||||
}
|
||||
|
||||
/// Get artist info from local DB only (no remote API fallback).
|
||||
/// Returns `None` if the local DB is unavailable or doesn't have this artist.
|
||||
pub fn get_artist_info_local(&self, mbid: &str) -> Option<ArtistInfo> {
|
||||
self.local_if_available()
|
||||
.and_then(|l| l.get_artist_info_sync(mbid).ok())
|
||||
}
|
||||
|
||||
/// Get detailed artist info by MBID. Tries local first, then remote.
|
||||
pub async fn get_artist_info(&self, mbid: &str) -> DataResult<ArtistInfo> {
|
||||
if let Some(local) = self.local_if_available()
|
||||
|
||||
+1
-1
Submodule shanty-db updated: b4e0756a90...181f736f25
+1
-1
Submodule shanty-index updated: 4f4e6e794a...11a8d3a88e
@@ -9,5 +9,9 @@ pub mod selection;
|
||||
pub mod strategies;
|
||||
pub mod types;
|
||||
|
||||
pub use strategies::{PlaylistError, genre_based, random, similar_artists, smart, to_m3u};
|
||||
pub use types::{Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SmartRules};
|
||||
pub use strategies::{
|
||||
CountryLookup, PlaylistError, genre_based, random, similar_artists, smart, to_m3u,
|
||||
};
|
||||
pub use types::{
|
||||
Candidate, PlaylistRequest, PlaylistResult, PlaylistTrack, SimilarConfig, SmartRules,
|
||||
};
|
||||
|
||||
+102
-15
@@ -21,6 +21,8 @@ pub fn score_tracks(
|
||||
tracks_by_artist: &HashMap<String, Vec<Track>>,
|
||||
top_tracks_by_artist: &HashMap<String, Vec<PopularTrack>>,
|
||||
popularity_bias: u8,
|
||||
_global_popularity: u8,
|
||||
max_tracks_per_artist: Option<u8>,
|
||||
) -> Vec<ScoredTrack> {
|
||||
let bias = popularity_bias.min(10) as usize;
|
||||
let mut scored = Vec::new();
|
||||
@@ -36,12 +38,17 @@ pub fn score_tracks(
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
// Build playcount lookup by lowercase name
|
||||
// Build playcount lookups by lowercase name and by MBID
|
||||
let playcount_by_name: HashMap<String, u64> = top_tracks
|
||||
.iter()
|
||||
.map(|t| (t.name.to_lowercase(), t.playcount))
|
||||
.collect();
|
||||
|
||||
let playcount_by_mbid: HashMap<String, u64> = top_tracks
|
||||
.iter()
|
||||
.filter_map(|t| t.mbid.as_ref().map(|m| (m.clone(), t.playcount)))
|
||||
.collect();
|
||||
|
||||
let max_playcount = playcount_by_name
|
||||
.values()
|
||||
.copied()
|
||||
@@ -52,26 +59,54 @@ pub fn score_tracks(
|
||||
for track in local_tracks {
|
||||
let title_lower = track.title.as_ref().map(|t| t.to_lowercase());
|
||||
|
||||
let playcount = title_lower
|
||||
// Match by: exact title, MBID, and prefix — take the MAXIMUM playcount
|
||||
// across all methods so a popular base track isn't hidden by a less
|
||||
// popular variant that happens to match exactly.
|
||||
let mut best_playcount: Option<u64> = None;
|
||||
let mut consider = |pc: u64| {
|
||||
best_playcount = Some(best_playcount.map_or(pc, |cur: u64| cur.max(pc)));
|
||||
};
|
||||
|
||||
// Exact title match
|
||||
if let Some(pc) = title_lower
|
||||
.as_ref()
|
||||
.and_then(|t| playcount_by_name.get(t).copied())
|
||||
.or_else(|| {
|
||||
track
|
||||
.musicbrainz_id
|
||||
.as_ref()
|
||||
.and_then(|id| playcount_by_name.get(id).copied())
|
||||
});
|
||||
{
|
||||
consider(pc);
|
||||
}
|
||||
|
||||
// If we have popularity data, require a match; otherwise assign uniform score
|
||||
// MBID match
|
||||
if let Some(pc) = track
|
||||
.musicbrainz_id
|
||||
.as_ref()
|
||||
.and_then(|id| playcount_by_mbid.get(id).copied())
|
||||
{
|
||||
consider(pc);
|
||||
}
|
||||
|
||||
// Prefix match: local title starts with a top track name, or vice versa
|
||||
if let Some(local) = title_lower.as_ref()
|
||||
&& let Some((_, &pc)) = playcount_by_name
|
||||
.iter()
|
||||
.filter(|(top_name, _)| {
|
||||
local.starts_with(top_name.as_str()) || top_name.starts_with(local.as_str())
|
||||
})
|
||||
.max_by_key(|&(_, &pc)| pc)
|
||||
{
|
||||
consider(pc);
|
||||
}
|
||||
|
||||
let playcount = best_playcount;
|
||||
|
||||
// If we have popularity data, use it; unmatched tracks get a low base score
|
||||
let (popularity, similarity, score) = if !playcount_by_name.is_empty() {
|
||||
let Some(playcount) = playcount else {
|
||||
continue;
|
||||
};
|
||||
let playcount = playcount.unwrap_or(0);
|
||||
|
||||
let popularity = if playcount > 0 {
|
||||
(playcount as f64 / max_playcount as f64).powf(POPULARITY_EXPONENTS[bias])
|
||||
} else {
|
||||
0.0
|
||||
// Unmatched track: small base score so it can still appear
|
||||
0.01
|
||||
};
|
||||
|
||||
let similarity = (match_score.exp()) / std::f64::consts::E;
|
||||
@@ -108,7 +143,9 @@ pub fn score_tracks(
|
||||
by_artist.entry(key).or_default().push(t);
|
||||
}
|
||||
|
||||
let cap = if popularity_bias == 0 {
|
||||
let cap = if let Some(explicit) = max_tracks_per_artist {
|
||||
Some((explicit as usize).max(1))
|
||||
} else if popularity_bias == 0 {
|
||||
None
|
||||
} else {
|
||||
let b = popularity_bias as f64;
|
||||
@@ -147,5 +184,55 @@ pub fn score_tracks(
|
||||
}
|
||||
}
|
||||
|
||||
by_artist.into_values().flatten().collect()
|
||||
let mut result: Vec<ScoredTrack> = by_artist.into_values().flatten().collect();
|
||||
|
||||
// Step 3: Apply global popularity weighting
|
||||
if _global_popularity > 0 {
|
||||
let gp = _global_popularity.min(10) as usize;
|
||||
let gp_exponent = POPULARITY_EXPONENTS[gp];
|
||||
let gp_strength = _global_popularity as f64 / 10.0;
|
||||
|
||||
// Find max playcount across ALL artists
|
||||
let global_max: u64 = top_tracks_by_artist
|
||||
.values()
|
||||
.flat_map(|tracks| tracks.iter().map(|t| t.playcount))
|
||||
.max()
|
||||
.unwrap_or(1)
|
||||
.max(1);
|
||||
|
||||
// Build a global playcount lookup (lowercase name -> max playcount)
|
||||
let mut global_playcounts: HashMap<String, u64> = HashMap::new();
|
||||
for tracks in top_tracks_by_artist.values() {
|
||||
for t in tracks {
|
||||
let key = t.name.to_lowercase();
|
||||
global_playcounts
|
||||
.entry(key)
|
||||
.and_modify(|c| *c = (*c).max(t.playcount))
|
||||
.or_insert(t.playcount);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply to ALL tracks: popular ones get boosted, unknown ones get reduced.
|
||||
// Factor range: unknown tracks get `1 - gp_strength` (minimum 0.01),
|
||||
// top global track gets 1.0 + gp_strength (up to 2.0 at max setting).
|
||||
for t in &mut result {
|
||||
let playcount = t
|
||||
.title
|
||||
.as_ref()
|
||||
.and_then(|title| global_playcounts.get(&title.to_lowercase()).copied())
|
||||
.unwrap_or(0);
|
||||
|
||||
let global_pop = if playcount > 0 {
|
||||
(playcount as f64 / global_max as f64).powf(gp_exponent)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
// Map global_pop [0, 1] to a factor centered around 1.0:
|
||||
// global_pop=0 → 1.0 - gp_strength, global_pop=1 → 1.0 + gp_strength
|
||||
let factor = (1.0 + gp_strength * (2.0 * global_pop - 1.0)).max(0.01);
|
||||
t.score *= factor;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ pub fn generate_playlist(
|
||||
candidates: &[Candidate],
|
||||
n: usize,
|
||||
seed_names: &HashSet<String>,
|
||||
max_artists: Option<u8>,
|
||||
skip_seed_enforcement: bool,
|
||||
) -> Vec<Candidate> {
|
||||
if candidates.is_empty() {
|
||||
return Vec::new();
|
||||
@@ -20,8 +22,14 @@ pub fn generate_playlist(
|
||||
let mut pool: Vec<&Candidate> = candidates.iter().collect();
|
||||
let mut result: Vec<Candidate> = Vec::new();
|
||||
let mut artist_counts: HashMap<String, usize> = HashMap::new();
|
||||
let mut distinct_artists_set: HashSet<String> = HashSet::new();
|
||||
let max_distinct = max_artists.map(|m| (m as usize).max(1));
|
||||
|
||||
let seed_min = (n / 10).max(1);
|
||||
let seed_min = if skip_seed_enforcement {
|
||||
0
|
||||
} else {
|
||||
(n / 10).max(1)
|
||||
};
|
||||
|
||||
let distinct_artists: usize = {
|
||||
let mut seen = HashSet::new();
|
||||
@@ -54,6 +62,13 @@ pub fn generate_playlist(
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, c)| {
|
||||
// Max distinct artists: reject new artists once we hit the cap
|
||||
if let Some(max) = max_distinct
|
||||
&& distinct_artists_set.len() >= max
|
||||
&& !distinct_artists_set.contains(&c.artist)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if force_seed {
|
||||
seed_names.contains(&c.artist)
|
||||
} else {
|
||||
@@ -79,6 +94,7 @@ pub fn generate_playlist(
|
||||
let picked = indices[dist.sample(&mut rng)];
|
||||
let track = pool.remove(picked);
|
||||
*artist_counts.entry(track.artist.clone()).or_insert(0) += 1;
|
||||
distinct_artists_set.insert(track.artist.clone());
|
||||
result.push(Candidate {
|
||||
score: track.score,
|
||||
artist: track.artist.clone(),
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
use sea_orm::DatabaseConnection;
|
||||
use shanty_data::{PopularTrack, SimilarArtist, SimilarArtistFetcher};
|
||||
@@ -12,6 +14,16 @@ use crate::types::*;
|
||||
/// Cache TTL: 7 days in seconds.
|
||||
const CACHE_TTL: i64 = 7 * 24 * 3600;
|
||||
|
||||
/// Trait for looking up an artist's country by MBID.
|
||||
/// Implementations should return quickly (local DB or cache), never blocking
|
||||
/// on rate-limited remote APIs during playlist generation.
|
||||
pub trait CountryLookup: Send + Sync {
|
||||
fn get_country<'a>(
|
||||
&'a self,
|
||||
mbid: &'a str,
|
||||
) -> Pin<Box<dyn Future<Output = Option<String>> + Send + 'a>>;
|
||||
}
|
||||
|
||||
/// Generate a playlist based on similar artists (the primary strategy).
|
||||
///
|
||||
/// Flow:
|
||||
@@ -26,9 +38,8 @@ pub async fn similar_artists(
|
||||
conn: &DatabaseConnection,
|
||||
fetcher: &impl SimilarArtistFetcher,
|
||||
seed_artists: Vec<String>,
|
||||
count: usize,
|
||||
popularity_bias: u8,
|
||||
ordering: &str,
|
||||
config: &SimilarConfig,
|
||||
_country_fetcher: Option<&dyn CountryLookup>,
|
||||
) -> Result<PlaylistResult, PlaylistError> {
|
||||
if seed_artists.is_empty() {
|
||||
return Err(PlaylistError::InvalidInput(
|
||||
@@ -37,25 +48,32 @@ pub async fn similar_artists(
|
||||
}
|
||||
|
||||
let num_seeds = seed_artists.len() as f64;
|
||||
let seed_similarity = config.seed_weight as f64 * 0.2;
|
||||
|
||||
// Merge similar artists from all seeds: key -> (name, total_score)
|
||||
let mut merged: HashMap<String, (String, f64)> = HashMap::new();
|
||||
// Track resolved seed names for enforcement (use DB names, not raw input)
|
||||
let mut resolved_seed_names: HashSet<String> = HashSet::new();
|
||||
// Track which keys are seeds (for country filter)
|
||||
let mut seed_keys: HashSet<String> = HashSet::new();
|
||||
|
||||
for seed in &seed_artists {
|
||||
// Resolve the seed artist: try name lookup in DB
|
||||
let (artist_name, artist_mbid) = resolve_artist(conn, seed).await?;
|
||||
resolved_seed_names.insert(artist_name.clone());
|
||||
|
||||
// Insert the seed itself with score 1.0
|
||||
let key = artist_mbid
|
||||
.clone()
|
||||
.unwrap_or_else(|| artist_name.to_lowercase());
|
||||
let entry = merged
|
||||
.entry(key)
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += 1.0;
|
||||
seed_keys.insert(key.clone());
|
||||
|
||||
// Insert the seed itself with configured weight
|
||||
if seed_similarity > 0.0 {
|
||||
let entry = merged
|
||||
.entry(key.clone())
|
||||
.or_insert_with(|| (artist_name.clone(), 0.0));
|
||||
entry.1 += seed_similarity;
|
||||
}
|
||||
|
||||
// Fetch similar artists (cached or fresh)
|
||||
let similar = fetch_cached_similar(conn, fetcher, &artist_name, artist_mbid.as_deref())
|
||||
@@ -71,11 +89,41 @@ pub async fn similar_artists(
|
||||
}
|
||||
}
|
||||
|
||||
// Normalize scores by seed count
|
||||
let artists: Vec<(String, String, f64)> = merged
|
||||
// Normalize scores by seed count, sort by similarity descending
|
||||
let mut artists: Vec<(String, String, f64)> = merged
|
||||
.into_iter()
|
||||
.map(|(key, (name, total))| (key, name, total / num_seeds))
|
||||
.collect();
|
||||
artists.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
// Apply discovery range: truncate to pool size
|
||||
let pool_size = discovery_pool_size(config.discovery_range);
|
||||
artists.truncate(pool_size);
|
||||
|
||||
// Country filter: only keep artists from the same countries as seeds
|
||||
if config.country_filter
|
||||
&& let Some(cf) = _country_fetcher
|
||||
{
|
||||
let mut seed_countries: HashSet<String> = HashSet::new();
|
||||
for key in &seed_keys {
|
||||
if let Some(country) = cf.get_country(key).await {
|
||||
seed_countries.insert(country);
|
||||
}
|
||||
}
|
||||
|
||||
if !seed_countries.is_empty() {
|
||||
let mut filtered = Vec::new();
|
||||
for entry in artists {
|
||||
let country = cf.get_country(&entry.0).await;
|
||||
match country {
|
||||
Some(c) if seed_countries.contains(&c) => filtered.push(entry),
|
||||
None => filtered.push(entry), // unknown = pass through
|
||||
_ => {} // known but different = exclude
|
||||
}
|
||||
}
|
||||
artists = filtered;
|
||||
}
|
||||
}
|
||||
|
||||
// Build track and top-track maps for scoring
|
||||
let mut tracks_by_artist: HashMap<String, Vec<shanty_db::entities::track::Model>> =
|
||||
@@ -104,7 +152,9 @@ pub async fn similar_artists(
|
||||
&artists,
|
||||
&tracks_by_artist,
|
||||
&top_tracks_by_artist,
|
||||
popularity_bias,
|
||||
config.popularity_bias,
|
||||
config.global_popularity,
|
||||
config.max_tracks_per_artist,
|
||||
);
|
||||
|
||||
// Convert to candidates
|
||||
@@ -123,10 +173,17 @@ pub async fn similar_artists(
|
||||
.collect();
|
||||
|
||||
// Select (use resolved DB names for seed enforcement, not raw input)
|
||||
let selected = selection::generate_playlist(&candidates, count, &resolved_seed_names);
|
||||
let skip_seed_enforcement = config.seed_weight == 0;
|
||||
let selected = selection::generate_playlist(
|
||||
&candidates,
|
||||
config.count,
|
||||
&resolved_seed_names,
|
||||
config.max_artists,
|
||||
skip_seed_enforcement,
|
||||
);
|
||||
|
||||
// Order
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
let ordered = apply_ordering(selected, &config.ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
tracks: candidates_to_tracks(ordered),
|
||||
@@ -135,6 +192,14 @@ pub async fn similar_artists(
|
||||
})
|
||||
}
|
||||
|
||||
/// Map discovery_range (0-10) to artist pool size.
|
||||
/// 0 -> 15, 5 -> ~100, 10 -> 500 (exponential curve).
|
||||
fn discovery_pool_size(range: u8) -> usize {
|
||||
let r = range.min(10) as f64;
|
||||
let size = 15.0 * (500.0_f64 / 15.0).powf(r / 10.0);
|
||||
size.round() as usize
|
||||
}
|
||||
|
||||
/// Generate a genre-based playlist.
|
||||
pub async fn genre_based(
|
||||
conn: &DatabaseConnection,
|
||||
@@ -176,7 +241,7 @@ pub async fn genre_based(
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
|
||||
let ordered = apply_ordering(selected, ordering);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
@@ -306,7 +371,7 @@ pub async fn smart(
|
||||
.collect();
|
||||
|
||||
let seed_names = HashSet::new();
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names);
|
||||
let selected = selection::generate_playlist(&candidates, count, &seed_names, None, true);
|
||||
let ordered = ordering::interleave_artists(selected);
|
||||
|
||||
Ok(PlaylistResult {
|
||||
|
||||
@@ -17,10 +17,62 @@ pub struct PlaylistRequest {
|
||||
pub ordering: String,
|
||||
#[serde(default)]
|
||||
pub rules: Option<SmartRules>,
|
||||
|
||||
/// Discovery range: how many similar artists to consider (0-10).
|
||||
/// 0 = focused (~15), 10 = wide open (~500). Default: 5.
|
||||
#[serde(default)]
|
||||
pub discovery_range: Option<u8>,
|
||||
/// Global popularity weighting (0-10). 0 = off, 10 = strong bias toward
|
||||
/// globally popular tracks across all artists. Default: 0.
|
||||
#[serde(default)]
|
||||
pub global_popularity: Option<u8>,
|
||||
/// Filter to same countries as seed artists. Default: false.
|
||||
#[serde(default)]
|
||||
pub country_filter: Option<bool>,
|
||||
/// Seed artist weight (0-10). 0 = exclude seeds, 5 = normal (similarity 1.0),
|
||||
/// 10 = double weight (similarity 2.0). Default: 5.
|
||||
#[serde(default)]
|
||||
pub seed_weight: Option<u8>,
|
||||
/// Explicit per-artist track cap. None or 0 = auto (derived from popularity_bias).
|
||||
#[serde(default)]
|
||||
pub max_tracks_per_artist: Option<u8>,
|
||||
/// Maximum distinct artists in the result. None or 0 = unlimited.
|
||||
#[serde(default)]
|
||||
pub max_artists: Option<u8>,
|
||||
}
|
||||
|
||||
/// Resolved configuration for the similar-artists strategy.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimilarConfig {
|
||||
pub count: usize,
|
||||
pub popularity_bias: u8,
|
||||
pub ordering: String,
|
||||
pub discovery_range: u8,
|
||||
pub global_popularity: u8,
|
||||
pub country_filter: bool,
|
||||
pub seed_weight: u8,
|
||||
pub max_tracks_per_artist: Option<u8>,
|
||||
pub max_artists: Option<u8>,
|
||||
}
|
||||
|
||||
impl SimilarConfig {
|
||||
pub fn from_request(req: &PlaylistRequest) -> Self {
|
||||
Self {
|
||||
count: req.count,
|
||||
popularity_bias: req.popularity_bias,
|
||||
ordering: req.ordering.clone(),
|
||||
discovery_range: req.discovery_range.unwrap_or(5),
|
||||
global_popularity: req.global_popularity.unwrap_or(0),
|
||||
country_filter: req.country_filter.unwrap_or(false),
|
||||
seed_weight: req.seed_weight.unwrap_or(5),
|
||||
max_tracks_per_artist: req.max_tracks_per_artist.filter(|&v| v > 0),
|
||||
max_artists: req.max_artists.filter(|&v| v > 0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn default_count() -> usize {
|
||||
50
|
||||
30
|
||||
}
|
||||
|
||||
fn default_popularity_bias() -> u8 {
|
||||
|
||||
@@ -83,15 +83,36 @@ fn test_score_tracks_basic() {
|
||||
let mut top_map = HashMap::new();
|
||||
top_map.insert("artist-1".to_string(), top_tracks);
|
||||
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 5);
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None);
|
||||
|
||||
// Should have 3 tracks (only ones matching top tracks)
|
||||
assert_eq!(scored.len(), 3);
|
||||
// All 5 tracks should be included (unmatched get a small base score)
|
||||
assert_eq!(scored.len(), 5);
|
||||
|
||||
// Higher playcount should yield higher score
|
||||
let scores: Vec<f64> = scored.iter().map(|t| t.score).collect();
|
||||
let max_score = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
|
||||
assert!(max_score > 0.0);
|
||||
// Matched tracks should score higher than unmatched ones
|
||||
let mut matched: Vec<f64> = scored
|
||||
.iter()
|
||||
.filter(|t| {
|
||||
t.title.as_deref().is_some_and(|n| {
|
||||
n.starts_with("Song 1") || n.starts_with("Song 2") || n.starts_with("Song 3")
|
||||
})
|
||||
})
|
||||
.map(|t| t.score)
|
||||
.collect();
|
||||
let mut unmatched: Vec<f64> = scored
|
||||
.iter()
|
||||
.filter(|t| {
|
||||
t.title
|
||||
.as_deref()
|
||||
.is_some_and(|n| n.starts_with("Song 4") || n.starts_with("Song 5"))
|
||||
})
|
||||
.map(|t| t.score)
|
||||
.collect();
|
||||
matched.sort_by(|a, b| b.partial_cmp(a).unwrap());
|
||||
unmatched.sort_by(|a, b| b.partial_cmp(a).unwrap());
|
||||
assert!(
|
||||
matched[0] > unmatched[0],
|
||||
"matched tracks should score higher than unmatched"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -107,7 +128,7 @@ fn test_score_tracks_no_top_tracks_uses_uniform() {
|
||||
|
||||
let top_map = HashMap::new(); // No top tracks
|
||||
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 5);
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 5, 0, None);
|
||||
|
||||
// All 3 tracks should be included with uniform scoring
|
||||
assert_eq!(scored.len(), 3);
|
||||
@@ -141,11 +162,11 @@ fn test_score_tracks_per_artist_cap() {
|
||||
top_map.insert("artist-1".to_string(), top_tracks);
|
||||
|
||||
// bias 10 → cap = 10
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 10);
|
||||
let scored = score_tracks(&artists, &tracks_map, &top_map, 10, 0, None);
|
||||
assert!(scored.len() <= 10);
|
||||
|
||||
// bias 0 → no cap
|
||||
let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0);
|
||||
let scored_no_cap = score_tracks(&artists, &tracks_map, &top_map, 0, 0, None);
|
||||
assert_eq!(scored_no_cap.len(), 50);
|
||||
}
|
||||
|
||||
@@ -164,7 +185,7 @@ fn test_similarity_transform() {
|
||||
tracks_map.insert("high".to_string(), vec![track_high]);
|
||||
tracks_map.insert("low".to_string(), vec![track_low]);
|
||||
|
||||
let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5);
|
||||
let scored = score_tracks(&artists, &tracks_map, &HashMap::new(), 5, 0, None);
|
||||
assert_eq!(scored.len(), 2);
|
||||
|
||||
let high_score = scored
|
||||
@@ -189,7 +210,7 @@ fn test_generate_playlist_basic() {
|
||||
.collect();
|
||||
|
||||
let seeds = HashSet::new();
|
||||
let result = generate_playlist(&candidates, 10, &seeds);
|
||||
let result = generate_playlist(&candidates, 10, &seeds, None, false);
|
||||
|
||||
assert_eq!(result.len(), 10);
|
||||
}
|
||||
@@ -199,7 +220,7 @@ fn test_generate_playlist_respects_count() {
|
||||
let candidates: Vec<Candidate> = (1..=5).map(|i| make_candidate(i, "Artist", 1.0)).collect();
|
||||
|
||||
let seeds = HashSet::new();
|
||||
let result = generate_playlist(&candidates, 3, &seeds);
|
||||
let result = generate_playlist(&candidates, 3, &seeds, None, false);
|
||||
assert_eq!(result.len(), 3);
|
||||
}
|
||||
|
||||
@@ -208,7 +229,7 @@ fn test_generate_playlist_not_more_than_available() {
|
||||
let candidates: Vec<Candidate> = (1..=3).map(|i| make_candidate(i, "Artist", 1.0)).collect();
|
||||
|
||||
let seeds = HashSet::new();
|
||||
let result = generate_playlist(&candidates, 100, &seeds);
|
||||
let result = generate_playlist(&candidates, 100, &seeds, None, false);
|
||||
assert_eq!(result.len(), 3);
|
||||
}
|
||||
|
||||
@@ -216,7 +237,7 @@ fn test_generate_playlist_not_more_than_available() {
|
||||
fn test_generate_playlist_empty_candidates() {
|
||||
let candidates: Vec<Candidate> = vec![];
|
||||
let seeds = HashSet::new();
|
||||
let result = generate_playlist(&candidates, 10, &seeds);
|
||||
let result = generate_playlist(&candidates, 10, &seeds, None, false);
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
|
||||
@@ -229,7 +250,7 @@ fn test_generate_playlist_per_artist_cap() {
|
||||
candidates.extend((21..=25).map(|i| make_candidate(i, "Minor", 1.0)));
|
||||
|
||||
let seeds = HashSet::new();
|
||||
let result = generate_playlist(&candidates, 15, &seeds);
|
||||
let result = generate_playlist(&candidates, 15, &seeds, None, false);
|
||||
|
||||
let prolific_count = result.iter().filter(|c| c.artist == "Prolific").count();
|
||||
let minor_count = result.iter().filter(|c| c.artist == "Minor").count();
|
||||
@@ -259,7 +280,7 @@ fn test_generate_playlist_seed_enforcement() {
|
||||
let mut seeds = HashSet::new();
|
||||
seeds.insert("Seed".to_string());
|
||||
|
||||
let result = generate_playlist(&candidates, 10, &seeds);
|
||||
let result = generate_playlist(&candidates, 10, &seeds, None, false);
|
||||
let seed_count = result.iter().filter(|c| c.artist == "Seed").count();
|
||||
|
||||
// seed_min = (10/10).max(1) = 1, so at least 1 seed track
|
||||
|
||||
+1
-1
Submodule shanty-tag updated: 042a137121...0f5d3f597a
+1
-1
Submodule shanty-watch updated: 827944170a...3593698854
+1
-1
Submodule shanty-web updated: 00d4e8d3e0...d17049d92a
Reference in New Issue
Block a user