feat: per-session model persistence (model no longer bleeds across sessions)
This commit is contained in:
parent
4b08855e13
commit
23e70b6f0b
1 changed files with 67 additions and 19 deletions
|
|
@ -16,6 +16,12 @@ const getPermissionModesForProvider = (provider: LLMProvider): PermissionMode[]
|
||||||
return ['default', 'acceptEdits', 'bypassPermissions', 'plan'];
|
return ['default', 'acceptEdits', 'bypassPermissions', 'plan'];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Per-session model persistence keys.
|
||||||
|
// Global key (`${provider}-model`) holds the "last used" default for NEW sessions.
|
||||||
|
// Session key (`${provider}-model-${sessionId}`) pins the model for that session.
|
||||||
|
const globalModelKey = (provider: string) => `${provider}-model`;
|
||||||
|
const sessionModelKey = (provider: string, sessionId: string) => `${provider}-model-${sessionId}`;
|
||||||
|
|
||||||
interface UseChatProviderStateArgs {
|
interface UseChatProviderStateArgs {
|
||||||
selectedSession: ProjectSession | null;
|
selectedSession: ProjectSession | null;
|
||||||
}
|
}
|
||||||
|
|
@ -26,18 +32,21 @@ export function useChatProviderState({ selectedSession }: UseChatProviderStateAr
|
||||||
const [provider, setProvider] = useState<LLMProvider>(() => {
|
const [provider, setProvider] = useState<LLMProvider>(() => {
|
||||||
return (localStorage.getItem('selected-provider') as LLMProvider) || 'claude';
|
return (localStorage.getItem('selected-provider') as LLMProvider) || 'claude';
|
||||||
});
|
});
|
||||||
const [cursorModel, setCursorModel] = useState<string>(() => {
|
|
||||||
return localStorage.getItem('cursor-model') || CURSOR_MODELS.DEFAULT;
|
// Initial model values: prefer the per-session pin if a session is already
|
||||||
});
|
// selected on mount, otherwise the global last-used default.
|
||||||
const [claudeModel, setClaudeModel] = useState<string>(() => {
|
const initialModel = (p: string, fallback: string) => {
|
||||||
return localStorage.getItem('claude-model') || CLAUDE_MODELS.DEFAULT;
|
if (selectedSession?.id) {
|
||||||
});
|
const pinned = localStorage.getItem(sessionModelKey(p, selectedSession.id));
|
||||||
const [codexModel, setCodexModel] = useState<string>(() => {
|
if (pinned) return pinned;
|
||||||
return localStorage.getItem('codex-model') || CODEX_MODELS.DEFAULT;
|
}
|
||||||
});
|
return localStorage.getItem(globalModelKey(p)) || fallback;
|
||||||
const [geminiModel, setGeminiModel] = useState<string>(() => {
|
};
|
||||||
return localStorage.getItem('gemini-model') || GEMINI_MODELS.DEFAULT;
|
|
||||||
});
|
const [cursorModel, setCursorModelState] = useState<string>(() => initialModel('cursor', CURSOR_MODELS.DEFAULT));
|
||||||
|
const [claudeModel, setClaudeModelState] = useState<string>(() => initialModel('claude', CLAUDE_MODELS.DEFAULT));
|
||||||
|
const [codexModel, setCodexModelState] = useState<string>(() => initialModel('codex', CODEX_MODELS.DEFAULT));
|
||||||
|
const [geminiModel, setGeminiModelState] = useState<string>(() => initialModel('gemini', GEMINI_MODELS.DEFAULT));
|
||||||
|
|
||||||
// Live model lists — fall back to static constants until API responds
|
// Live model lists — fall back to static constants until API responds
|
||||||
const [claudeModelOptions, setClaudeModelOptions] = useState<ModelOption[]>(CLAUDE_MODELS.OPTIONS);
|
const [claudeModelOptions, setClaudeModelOptions] = useState<ModelOption[]>(CLAUDE_MODELS.OPTIONS);
|
||||||
|
|
@ -45,7 +54,43 @@ export function useChatProviderState({ selectedSession }: UseChatProviderStateAr
|
||||||
const [geminiModelOptions] = useState<ModelOption[]>(GEMINI_MODELS.OPTIONS);
|
const [geminiModelOptions] = useState<ModelOption[]>(GEMINI_MODELS.OPTIONS);
|
||||||
const [cursorModelOptions] = useState<ModelOption[]>(CURSOR_MODELS.OPTIONS);
|
const [cursorModelOptions] = useState<ModelOption[]>(CURSOR_MODELS.OPTIONS);
|
||||||
|
|
||||||
// Fetch live model list and validate the saved claude model
|
// Persisted setters: write BOTH the per-session pin (if a session is active)
|
||||||
|
// and the global last-used default, then update state.
|
||||||
|
const makePersistedSetter = useCallback(
|
||||||
|
(p: string, setState: (v: string) => void) => (value: string) => {
|
||||||
|
setState(value);
|
||||||
|
localStorage.setItem(globalModelKey(p), value);
|
||||||
|
if (selectedSession?.id) {
|
||||||
|
localStorage.setItem(sessionModelKey(p, selectedSession.id), value);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[selectedSession?.id],
|
||||||
|
);
|
||||||
|
|
||||||
|
const setCursorModel = useCallback(makePersistedSetter('cursor', setCursorModelState), [makePersistedSetter]);
|
||||||
|
const setClaudeModel = useCallback(makePersistedSetter('claude', setClaudeModelState), [makePersistedSetter]);
|
||||||
|
const setCodexModel = useCallback(makePersistedSetter('codex', setCodexModelState), [makePersistedSetter]);
|
||||||
|
const setGeminiModel = useCallback(makePersistedSetter('gemini', setGeminiModelState), [makePersistedSetter]);
|
||||||
|
|
||||||
|
// When the selected session changes, load each provider's pinned model for
|
||||||
|
// that session (falling back to the global default). This is what stops a
|
||||||
|
// model picked in session A from leaking into session B.
|
||||||
|
useEffect(() => {
|
||||||
|
if (!selectedSession?.id) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const load = (p: string, setState: (v: string) => void, fallback: string) => {
|
||||||
|
const pinned = localStorage.getItem(sessionModelKey(p, selectedSession.id));
|
||||||
|
const next = pinned || localStorage.getItem(globalModelKey(p)) || fallback;
|
||||||
|
setState(next);
|
||||||
|
};
|
||||||
|
load('cursor', setCursorModelState, CURSOR_MODELS.DEFAULT);
|
||||||
|
load('claude', setClaudeModelState, CLAUDE_MODELS.DEFAULT);
|
||||||
|
load('codex', setCodexModelState, CODEX_MODELS.DEFAULT);
|
||||||
|
load('gemini', setGeminiModelState, GEMINI_MODELS.DEFAULT);
|
||||||
|
}, [selectedSession?.id]);
|
||||||
|
|
||||||
|
// Fetch live Claude model list and validate the current claude model
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
authenticatedFetch('/api/models')
|
authenticatedFetch('/api/models')
|
||||||
.then((res) => {
|
.then((res) => {
|
||||||
|
|
@ -57,18 +102,21 @@ export function useChatProviderState({ selectedSession }: UseChatProviderStateAr
|
||||||
const options: ModelOption[] = data.claude;
|
const options: ModelOption[] = data.claude;
|
||||||
setClaudeModelOptions(options);
|
setClaudeModelOptions(options);
|
||||||
|
|
||||||
// Validate saved model — if it's no longer in the list, reset to default
|
setClaudeModelState((current) => {
|
||||||
setClaudeModel((current) => {
|
|
||||||
const valid = options.some((o) => o.value === current);
|
const valid = options.some((o) => o.value === current);
|
||||||
if (valid) return current;
|
if (valid) return current;
|
||||||
const fallback = options[0]?.value ?? CLAUDE_MODELS.DEFAULT;
|
const fallback = options[0]?.value ?? CLAUDE_MODELS.DEFAULT;
|
||||||
localStorage.setItem('claude-model', fallback);
|
localStorage.setItem(globalModelKey('claude'), fallback);
|
||||||
|
if (selectedSession?.id) {
|
||||||
|
localStorage.setItem(sessionModelKey('claude', selectedSession.id), fallback);
|
||||||
|
}
|
||||||
return fallback;
|
return fallback;
|
||||||
});
|
});
|
||||||
})
|
})
|
||||||
.catch(() => {
|
.catch(() => {
|
||||||
// Static fallback already in place — nothing to do
|
// Static fallback already in place
|
||||||
});
|
});
|
||||||
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const lastProviderRef = useRef(provider);
|
const lastProviderRef = useRef(provider);
|
||||||
|
|
@ -119,8 +167,8 @@ export function useChatProviderState({ selectedSession }: UseChatProviderStateAr
|
||||||
}
|
}
|
||||||
|
|
||||||
const modelId = data.config.model.modelId as string;
|
const modelId = data.config.model.modelId as string;
|
||||||
if (!localStorage.getItem('cursor-model')) {
|
if (!localStorage.getItem(globalModelKey('cursor'))) {
|
||||||
setCursorModel(modelId);
|
setCursorModelState(modelId);
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue