mirror of
https://github.com/lone-cloud/gerbil
synced 2026-06-03 19:54:44 -07:00
359 lines
10 KiB
TypeScript
359 lines
10 KiB
TypeScript
import { join, basename, dirname } from 'path';
|
|
import { mkdir, readdir, stat, rename, unlink } from 'fs/promises';
|
|
import { createWriteStream } from 'fs';
|
|
import { get as httpGet } from 'http';
|
|
import { get as httpsGet } from 'https';
|
|
import { getInstallDir } from '@/main/modules/config';
|
|
import { pathExists } from '@/utils/node/fs';
|
|
import { logError } from '@/utils/node/logging';
|
|
import { sendKoboldOutput } from '@/main/modules/window';
|
|
import type { ModelParamType, CachedModel } from '@/types';
|
|
import type { IncomingMessage } from 'http';
|
|
|
|
const activeDownloads = new Set<{
|
|
abort: () => void;
|
|
tempPath: string;
|
|
}>();
|
|
|
|
interface DownloadProgress {
|
|
type: 'progress' | 'complete' | 'error';
|
|
percent?: number;
|
|
downloaded?: string;
|
|
total?: string;
|
|
speed?: string;
|
|
eta?: string;
|
|
error?: string;
|
|
localPath?: string;
|
|
}
|
|
|
|
function parseHuggingFaceUrl(url: string) {
|
|
const hfMatch = url.match(
|
|
/huggingface\.co\/([^/]+)\/([^/]+)\/(?:resolve|blob)\/[^/]+\/(.+)/
|
|
);
|
|
|
|
if (hfMatch) {
|
|
const pathWithQuery = hfMatch[3];
|
|
const pathWithoutQuery = pathWithQuery.split('?')[0];
|
|
return {
|
|
author: hfMatch[1],
|
|
model: hfMatch[2],
|
|
filename: basename(pathWithoutQuery),
|
|
};
|
|
}
|
|
|
|
return {
|
|
author: 'external',
|
|
model: 'models',
|
|
filename: basename(url.split('?')[0]),
|
|
};
|
|
}
|
|
|
|
export function getModelLocalPath(url: string, paramType: ModelParamType) {
|
|
const installDir = getInstallDir();
|
|
const { author, model, filename } = parseHuggingFaceUrl(url);
|
|
|
|
return join(installDir, 'models', paramType, author, model, filename);
|
|
}
|
|
|
|
function normalizeUrl(url: string) {
|
|
if (url.includes('huggingface.co') && url.includes('/blob/')) {
|
|
return url.replace('/blob/', '/resolve/');
|
|
}
|
|
return url;
|
|
}
|
|
|
|
async function downloadFile(
|
|
url: string,
|
|
outputPath: string,
|
|
onProgress: (progress: DownloadProgress) => void
|
|
) {
|
|
const normalizedUrl = normalizeUrl(url);
|
|
const outputDir = dirname(outputPath);
|
|
const tempPath = `${outputPath}.tmp`;
|
|
|
|
await mkdir(outputDir, { recursive: true });
|
|
|
|
return new Promise<boolean>((resolve, reject) => {
|
|
const httpModule = normalizedUrl.startsWith('https') ? httpsGet : httpGet;
|
|
let currentRequest: IncomingMessage | null = null;
|
|
let fileStream: ReturnType<typeof createWriteStream>;
|
|
let isAborted = false;
|
|
|
|
const cleanup = async () => {
|
|
await new Promise<void>((resolve) => {
|
|
if (fileStream) {
|
|
fileStream.close(() => resolve());
|
|
} else {
|
|
resolve();
|
|
}
|
|
});
|
|
await unlink(tempPath).catch(() => void 0);
|
|
};
|
|
|
|
const abortController = {
|
|
abort: () => {
|
|
isAborted = true;
|
|
currentRequest?.destroy();
|
|
cleanup();
|
|
reject(new Error('Download aborted by user'));
|
|
},
|
|
tempPath,
|
|
};
|
|
|
|
activeDownloads.add(abortController);
|
|
|
|
const handleRedirect = (requestUrl: string, redirectCount = 0): void => {
|
|
if (isAborted) return;
|
|
|
|
if (redirectCount > 10) {
|
|
activeDownloads.delete(abortController);
|
|
reject(new Error('Too many redirects'));
|
|
return;
|
|
}
|
|
|
|
httpModule(requestUrl, (response) => {
|
|
if (isAborted) return;
|
|
currentRequest = response;
|
|
if (
|
|
response.statusCode === 301 ||
|
|
response.statusCode === 302 ||
|
|
response.statusCode === 307 ||
|
|
response.statusCode === 308
|
|
) {
|
|
const redirectUrl = response.headers.location;
|
|
if (redirectUrl) {
|
|
handleRedirect(redirectUrl, redirectCount + 1);
|
|
} else {
|
|
reject(new Error('Redirect without location header'));
|
|
}
|
|
return;
|
|
}
|
|
|
|
if (response.statusCode !== 200) {
|
|
activeDownloads.delete(abortController);
|
|
reject(
|
|
new Error(`HTTP ${response.statusCode}: ${response.statusMessage}`)
|
|
);
|
|
return;
|
|
}
|
|
|
|
const totalBytes = parseInt(response.headers['content-length'] || '0');
|
|
let downloadedBytes = 0;
|
|
let lastReportTime = Date.now();
|
|
let lastReportedBytes = 0;
|
|
|
|
fileStream = createWriteStream(tempPath);
|
|
|
|
response.on('data', (chunk: Buffer) => {
|
|
downloadedBytes += chunk.length;
|
|
fileStream.write(chunk);
|
|
|
|
const now = Date.now();
|
|
const timeDiff = (now - lastReportTime) / 1000;
|
|
|
|
if (timeDiff >= 0.5) {
|
|
const bytesDiff = downloadedBytes - lastReportedBytes;
|
|
const speedBytesPerSec = bytesDiff / timeDiff;
|
|
const percent = totalBytes
|
|
? Math.round((downloadedBytes / totalBytes) * 100)
|
|
: 0;
|
|
const downloadedMB = (downloadedBytes / 1024 / 1024).toFixed(2);
|
|
const totalMB = (totalBytes / 1024 / 1024).toFixed(2);
|
|
const speedMBPerSec = (speedBytesPerSec / 1024 / 1024).toFixed(2);
|
|
|
|
const remainingBytes = totalBytes - downloadedBytes;
|
|
const etaSeconds = speedBytesPerSec
|
|
? Math.round(remainingBytes / speedBytesPerSec)
|
|
: 0;
|
|
const etaMinutes = Math.floor(etaSeconds / 60);
|
|
const etaSecondsRemainder = etaSeconds % 60;
|
|
const etaStr =
|
|
etaMinutes > 0
|
|
? `${etaMinutes}m${etaSecondsRemainder}s`
|
|
: `${etaSeconds}s`;
|
|
|
|
const progressMsg = totalBytes
|
|
? `Downloaded ${downloadedMB}MB / ${totalMB}MB (${percent}%) - ${speedMBPerSec}MB/s - ETA: ${etaStr}`
|
|
: `Downloaded ${downloadedMB}MB - ${speedMBPerSec}MB/s`;
|
|
|
|
sendKoboldOutput(`\r${progressMsg}`);
|
|
|
|
onProgress({
|
|
type: 'progress',
|
|
percent,
|
|
downloaded: `${downloadedMB}MB`,
|
|
total: totalBytes ? `${totalMB}MB` : undefined,
|
|
speed: `${speedMBPerSec}MB/s`,
|
|
eta: totalBytes ? etaStr : undefined,
|
|
});
|
|
|
|
lastReportTime = now;
|
|
lastReportedBytes = downloadedBytes;
|
|
}
|
|
});
|
|
|
|
response.on('end', () => {
|
|
if (isAborted) return;
|
|
|
|
if (totalBytes > 0 && downloadedBytes !== totalBytes) {
|
|
activeDownloads.delete(abortController);
|
|
cleanup();
|
|
reject(
|
|
new Error(
|
|
`Incomplete download: received ${downloadedBytes} bytes, expected ${totalBytes} bytes`
|
|
)
|
|
);
|
|
return;
|
|
}
|
|
|
|
fileStream.end();
|
|
fileStream.on('finish', async () => {
|
|
try {
|
|
await rename(tempPath, outputPath);
|
|
sendKoboldOutput('\n');
|
|
activeDownloads.delete(abortController);
|
|
resolve(true);
|
|
} catch (err) {
|
|
activeDownloads.delete(abortController);
|
|
reject(err instanceof Error ? err : new Error(String(err)));
|
|
}
|
|
});
|
|
});
|
|
|
|
response.on('error', async (err) => {
|
|
if (isAborted) return;
|
|
activeDownloads.delete(abortController);
|
|
await cleanup();
|
|
reject(err);
|
|
});
|
|
|
|
fileStream.on('error', async (err) => {
|
|
if (isAborted) return;
|
|
activeDownloads.delete(abortController);
|
|
await cleanup();
|
|
reject(err);
|
|
});
|
|
}).on('error', (err) => {
|
|
if (isAborted) return;
|
|
activeDownloads.delete(abortController);
|
|
reject(err);
|
|
});
|
|
};
|
|
|
|
handleRedirect(normalizedUrl);
|
|
});
|
|
}
|
|
|
|
function isValidModelUrl(url: string) {
|
|
if (!url || url.trim() === '') {
|
|
return false;
|
|
}
|
|
|
|
if (!url.startsWith('http://') && !url.startsWith('https://')) {
|
|
return false;
|
|
}
|
|
|
|
const validExtensions = ['.gguf', '.safetensors', '.bin', '.ggml'];
|
|
const urlPath = url.split('?')[0];
|
|
return validExtensions.some((ext) => urlPath.toLowerCase().endsWith(ext));
|
|
}
|
|
|
|
export async function resolveModelPath(
|
|
urlOrPath: string,
|
|
paramType: ModelParamType,
|
|
onProgress?: (progress: DownloadProgress) => void
|
|
) {
|
|
if (!isValidModelUrl(urlOrPath)) {
|
|
return urlOrPath;
|
|
}
|
|
|
|
const localPath = getModelLocalPath(urlOrPath, paramType);
|
|
|
|
if (await pathExists(localPath)) {
|
|
sendKoboldOutput(`Using cached model at: ${localPath}\n`);
|
|
onProgress?.({
|
|
type: 'complete',
|
|
localPath,
|
|
});
|
|
return localPath;
|
|
}
|
|
|
|
sendKoboldOutput(`Downloading model from ${urlOrPath} to ${localPath}...\n`);
|
|
|
|
const progressCallback = onProgress || ((p: DownloadProgress) => p);
|
|
|
|
try {
|
|
await downloadFile(urlOrPath, localPath, progressCallback);
|
|
|
|
sendKoboldOutput(`Model downloaded successfully to: ${localPath}\n\n`);
|
|
progressCallback({
|
|
type: 'complete',
|
|
localPath,
|
|
});
|
|
return localPath;
|
|
} catch (error) {
|
|
progressCallback({
|
|
type: 'error',
|
|
error: `Download failed: ${error instanceof Error ? error.message : 'Unknown error'}`,
|
|
});
|
|
throw new Error(
|
|
`Failed to download model from ${urlOrPath}: ${error instanceof Error ? error.message : 'Unknown error'}`
|
|
);
|
|
}
|
|
}
|
|
|
|
export async function getLocalModelsForType(paramType: ModelParamType) {
|
|
const installDir = getInstallDir();
|
|
const modelsDir = join(installDir, 'models', paramType);
|
|
|
|
if (!(await pathExists(modelsDir))) {
|
|
return [];
|
|
}
|
|
|
|
const models: CachedModel[] = [];
|
|
|
|
try {
|
|
const authors = await readdir(modelsDir);
|
|
|
|
for (const author of authors) {
|
|
const authorPath = join(modelsDir, author);
|
|
const authorStat = await stat(authorPath);
|
|
|
|
if (authorStat.isDirectory()) {
|
|
const modelDirs = await readdir(authorPath);
|
|
|
|
for (const modelDir of modelDirs) {
|
|
const modelPath = join(authorPath, modelDir);
|
|
const modelStat = await stat(modelPath);
|
|
|
|
if (modelStat.isDirectory()) {
|
|
const files = await readdir(modelPath);
|
|
|
|
for (const file of files) {
|
|
const filePath = join(modelPath, file);
|
|
const fileStat = await stat(filePath);
|
|
|
|
if (fileStat.isFile()) {
|
|
models.push({
|
|
path: filePath,
|
|
author,
|
|
model: modelDir,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
} catch (error) {
|
|
logError('Error scanning local models:', error as Error);
|
|
}
|
|
|
|
return models;
|
|
}
|
|
|
|
export function abortActiveDownloads() {
|
|
const downloads = Array.from(activeDownloads);
|
|
downloads.forEach((controller) => controller.abort());
|
|
activeDownloads.clear();
|
|
}
|