#include <bits/stdc++.h>
using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1, so MOD - 1 is divisible by
// every power of two up to 2^23.
constexpr int MOD = 998244353;
// Base that ntt() raises to (MOD - 1) / len to obtain its len-th roots of unity.
constexpr int G = 3;
// Digit vector: one decimal digit per element.
using VI = std::vector<int>;
// Returns a^b mod m by square-and-multiply.
// Products are widened to long long before reduction so they cannot overflow int.
// Callers in this program always pass a non-negative exponent b.
int fastpow(int a, int b, int m = MOD) {
    int acc = 1;
    for (; b != 0; b >>= 1) {
        if ((b & 1) != 0) acc = static_cast<int>(1LL * acc * a % m);
        a = static_cast<int>(1LL * a * a % m);
    }
    return acc;
}
// In-place iterative number-theoretic transform over Z/MOD.
// a.size() must be a power of two (multiply() sizes its vectors that way).
// invert == false: forward transform.
// invert == true: inverse transform, including the final division by n.
void ntt(vector<int>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Bit-reversal permutation so the butterflies below can run in place.
    for (int i = 1, rev = 0; i < n; ++i) {
        int mask = n >> 1;
        while (rev & mask) {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        if (i < rev) swap(a[i], a[rev]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        // len-th root of unity (its modular inverse for the inverse transform).
        int root = fastpow(G, (MOD - 1) / len);
        if (invert) root = fastpow(root, MOD - 2);
        for (int start = 0; start < n; start += len) {
            int tw = 1;
            for (int k = 0; k < half; ++k) {
                int& lo = a[start + k];
                int& hi = a[start + k + half];
                const int x = lo;
                const int y = static_cast<int>(1LL * hi * tw % MOD);
                lo = (x + y) % MOD;
                hi = (x - y + MOD) % MOD;
                tw = static_cast<int>(1LL * tw * root % MOD);
            }
        }
    }
    if (invert) {
        const int inv_n = fastpow(n, MOD - 2);
        for (auto& val : a) val = static_cast<int>(1LL * val * inv_n % MOD);
    }
}
// Converts a decimal string to a digit vector, least-significant digit first.
// No validation: each character is mapped as c - '0'.
VI parse(const string& s) {
    VI digits;
    digits.reserve(s.size());
    for (auto it = s.rbegin(); it != s.rend(); ++it) digits.push_back(*it - '0');
    return digits;
}
// Adds two digit vectors given least-significant digit first.
// NOTE: the result is returned MOST-significant digit first (ready for print()),
// with redundant leading zeros removed.
VI add(VI a, VI b) {
    const size_t len = max(a.size(), b.size());
    VI digits;
    int carry = 0;
    size_t i = 0;
    while (i < len || carry != 0) {
        int sum = carry;
        if (i < a.size()) sum += a[i];
        if (i < b.size()) sum += b[i];
        digits.push_back(sum % 10);
        carry = sum / 10;
        ++i;
    }
    // Trim high-order zeros while still least-significant first, then flip.
    while (digits.size() > 1 && digits.back() == 0) digits.pop_back();
    reverse(digits.begin(), digits.end());
    return digits;
}
// Computes a - b for digit vectors given least-significant digit first.
// Precondition: a >= b (no sign handling; a smaller a leaves a dangling borrow).
// The result is returned most-significant digit first with leading zeros removed.
VI sub(VI a, VI b) {
    VI digits(a.size());
    int borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        int d = a[i] - borrow;
        if (i < b.size()) d -= b[i];
        borrow = d < 0 ? 1 : 0;
        digits[i] = d + 10 * borrow;
    }
    while (digits.size() > 1 && digits.back() == 0) digits.pop_back();
    reverse(digits.begin(), digits.end());
    return digits;
}
// Multiplies two digit vectors (least-significant first) via NTT convolution.
// Returns the product MOST-significant digit first, leading zeros removed.
// Exact while every convolution coefficient (at most 81 * min(|a|, |b|) for
// decimal digits) stays below MOD.
VI multiply(VI a, VI b) {
    const size_t need = a.size() + b.size();
    size_t n = 1;
    while (n < need) n *= 2;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (size_t i = 0; i < n; ++i) a[i] = static_cast<int>(1LL * a[i] * b[i] % MOD);
    ntt(a, true);
    // Carry propagation: turn convolution coefficients into base-10 digits.
    VI digits(n);
    long long carry = 0;
    for (size_t i = 0; i < n; ++i) {
        carry += a[i];
        digits[i] = static_cast<int>(carry % 10);
        carry /= 10;
    }
    for (; carry > 0; carry /= 10) digits.push_back(static_cast<int>(carry % 10));
    while (digits.size() > 1 && digits.back() == 0) digits.pop_back();
    reverse(digits.begin(), digits.end());
    return digits;
}
// Writes the digits in stored order, then a newline.
// add/sub/multiply in this program already return most-significant first.
void print(VI v) {
    copy(v.begin(), v.end(), ostream_iterator<int>(cout));
    cout << '\n';
}
// Reads "A op B" (op is +, - or *) from stdin and prints the result.
// Unknown operators produce no output.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    string A, B, op;
    cin >> A >> op >> B;
    auto a = parse(A);
    auto b = parse(B);
    if (op == "+") {
        print(add(a, b));
    } else if (op == "-") {
        // True when |x| < |y| as decimal strings, ignoring leading zeros.
        auto magnitude_less = [](const string& x, const string& y) {
            const size_t px = min(x.find_first_not_of('0'), x.size());
            const size_t py = min(y.find_first_not_of('0'), y.size());
            const size_t lx = x.size() - px;
            const size_t ly = y.size() - py;
            if (lx != ly) return lx < ly;
            return x.compare(px, lx, y, py, ly) < 0;
        };
        if (A == B) {
            cout << 0 << '\n';
        } else if (magnitude_less(A, B)) {
            // sub() requires minuend >= subtrahend: compute B - A and print the sign
            // instead of letting the borrow run off the top.
            cout << '-';
            print(sub(b, a));
        } else {
            print(sub(a, b));
        }
    } else if (op == "*") {
        print(multiply(a, b));
    }
    return 0;
}
#include <bits/stdc++.h>
using namespace std;
// NTT-friendly prime: 998244353 = 119 * 2^23 + 1.
constexpr int MOD = 998244353;
// Base that ntt() raises to (MOD - 1) / len to obtain len-th roots of unity.
constexpr int G = 3;
// Digit vector; parse() stores digits least-significant first.
using VI = std::vector<int>;
// Returns a^b mod MOD by binary exponentiation, with 64-bit intermediates.
// Callers pass a non-negative exponent b.
int fastpow(int a, int b) {
    int64_t result = 1;
    int64_t sq = a;
    while (b != 0) {
        if ((b & 1) != 0) result = result * sq % MOD;
        sq = sq * sq % MOD;
        b >>= 1;
    }
    return static_cast<int>(result);
}
// In-place iterative NTT over Z/MOD; a.size() must be a power of two.
// invert == true computes the inverse transform, including the 1/n scaling.
// Butterflies use a conditional add/subtract of MOD instead of a `%`.
void ntt(VI &a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Reorder elements into bit-reversed index order.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        while (j & bit) {
            j ^= bit;
            bit >>= 1;
        }
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    for (int len = 2; len <= n; len *= 2) {
        const int half = len / 2;
        int step = fastpow(G, (MOD - 1) / len);
        if (invert) step = fastpow(step, MOD - 2);
        for (int base = 0; base < n; base += len) {
            int64_t w = 1;
            for (int k = 0; k < half; ++k) {
                const int u = a[base + k];
                const int v = static_cast<int>(a[base + k + half] * w % MOD);
                const int s = u + v;
                const int d = u - v;
                a[base + k] = s >= MOD ? s - MOD : s;
                a[base + k + half] = d < 0 ? d + MOD : d;
                w = w * step % MOD;
            }
        }
    }
    if (invert) {
        const int64_t inv_n = fastpow(n, MOD - 2);
        for (int &x : a) x = static_cast<int>(x * inv_n % MOD);
    }
}
// Converts a decimal string into digits stored least-significant first,
// so the arithmetic routines never need to reverse anything later.
VI parse(const string &s) {
    VI v(s.size());
    transform(s.rbegin(), s.rend(), v.begin(), [](char c) { return c - '0'; });
    return v;
}
// Adds two digit vectors (least-significant digit first).
// The result uses the same order, with redundant high-order zeros removed.
VI add(const VI &a, const VI &b) {
    const size_t n = max(a.size(), b.size());
    VI res;
    res.reserve(n + 1);
    int carry = 0;
    for (size_t i = 0; i < n; ++i) {
        const int x = carry + (i < a.size() ? a[i] : 0) + (i < b.size() ? b[i] : 0);
        carry = x / 10;
        res.push_back(x % 10);
    }
    // Flush whatever carry remains, one digit at a time.
    for (; carry != 0; carry /= 10) res.push_back(carry % 10);
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    return res;
}
// Subtracts b from a (both least-significant digit first), assuming a >= b.
// The result keeps the same digit order, with high-order zeros trimmed.
VI sub(const VI &a, const VI &b) {
    VI res(a.size());
    int borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        const int rhs = i < b.size() ? b[i] : 0;
        int d = a[i] - borrow - rhs;
        borrow = 0;
        if (d < 0) {
            d += 10;
            borrow = 1;
        }
        res[i] = d;
    }
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    return res;
}
// Multiplies two digit vectors (least-significant first) with an NTT-based
// convolution. Returns the product in the same order, high zeros trimmed.
VI multiply(VI a, VI b) {
    const int need = static_cast<int>(a.size() + b.size());
    int n = 1;
    while (n < need) n *= 2;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (int i = 0; i < n; ++i)
        a[i] = static_cast<int>(static_cast<int64_t>(a[i]) * b[i] % MOD);
    ntt(a, true);
    // Propagate carries so every entry becomes a single decimal digit.
    VI res;
    res.reserve(n);
    int64_t carry = 0;
    for (int coef : a) {  // a.size() == n at this point
        carry += coef;
        res.push_back(static_cast<int>(carry % 10));
        carry /= 10;
    }
    for (; carry > 0; carry /= 10) res.push_back(static_cast<int>(carry % 10));
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    return res;
}
void print(const VI &v) {
for (int i = (int)v.size() - 1; i >= 0; i--)
cout << v[i];
cout << '\n';
}
// Reads "A op B" (op is +, - or *) from stdin and prints the result.
// Unknown operators produce no output.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    string A, B, op;
    cin >> A >> op >> B;
    VI a = parse(A), b = parse(B);
    if (op == "+") {
        print(add(a, b));
    } else if (op == "-") {
        // True when |x| < |y| as decimal strings, ignoring leading zeros.
        auto magnitude_less = [](const string &x, const string &y) {
            const size_t px = min(x.find_first_not_of('0'), x.size());
            const size_t py = min(y.find_first_not_of('0'), y.size());
            const size_t lx = x.size() - px;
            const size_t ly = y.size() - py;
            if (lx != ly) return lx < ly;
            return x.compare(px, lx, y, py, ly) < 0;
        };
        if (A == B) {
            cout << 0 << '\n';
        } else if (magnitude_less(A, B)) {
            // sub() assumes minuend >= subtrahend: compute B - A and print the sign.
            cout << '-';
            print(sub(b, a));
        } else {
            print(sub(a, b));
        }
    } else if (op == "*") {
        print(multiply(a, b));
    }
    return 0;
}
| AC (0.7 s, 43.4 MB) | CPP |
#include <cstdio> // For scanf, printf
#include <string> // For std::string, std::to_string
#include <vector> // For std::vector
#include <algorithm> // For std::reverse, std::swap
#include <cmath> // For std::acos, std::sin, std::cos, std::round
#include <complex> // For std::complex (use with long double)
#include <cstring> // For strlen
#include <iostream> // For std::ios_base, std::cin, std::tie
// Strips leading '0' characters; an empty or all-zero string becomes "0".
std::string removeLeadingZeros(std::string s) {
    const std::string::size_type first = s.find_first_not_of('0');
    if (first == std::string::npos) return "0";
    return s.substr(first);
}
// Adds two non-negative decimal strings (digits only).
// Returns the sum via removeLeadingZeros, so zero comes back as "0".
std::string Add(std::string a, std::string b) {
    if (a.length() < b.length()) std::swap(a, b);  // make 'a' the longer operand
    std::string digits;
    digits.reserve(a.length() + 1);
    int carry = 0;
    auto ib = b.rbegin();
    for (auto ia = a.rbegin(); ia != a.rend(); ++ia) {
        int sum = (*ia - '0') + carry;
        if (ib != b.rend()) sum += *ib++ - '0';
        digits.push_back(static_cast<char>('0' + sum % 10));
        carry = sum / 10;
    }
    if (carry != 0) digits.push_back(static_cast<char>('0' + carry));
    std::reverse(digits.begin(), digits.end());
    return removeLeadingZeros(digits);
}
// Returns a - b for non-negative decimal strings, prefixed with '-' when the
// result is negative. Inputs may carry leading zeros. The result never does,
// and zero is always reported as "0" (never "-0").
std::string Sub(std::string a, std::string b) {
    // Normalize first so the length-then-lexicographic comparison below is a true
    // numeric comparison. Without it, "0010" - "999" is misclassified as
    // non-negative, and "1" - "01" produces "-0".
    a = removeLeadingZeros(a);
    b = removeLeadingZeros(b);
    if (a == b) return "0";
    bool neg = false;
    if (a.length() < b.length() || (a.length() == b.length() && a < b)) {
        neg = true;
        std::swap(a, b);  // now a > b, so the borrow never runs off the top
    }
    std::string res;
    res.reserve(a.length());
    int borrow = 0;
    for (std::size_t i = 0; i < a.length(); ++i) {
        int diff = (a[a.length() - 1 - i] - '0') - borrow;
        if (i < b.length()) diff -= (b[b.length() - 1 - i] - '0');
        if (diff < 0) { diff += 10; borrow = 1; } else { borrow = 0; }
        res += static_cast<char>(diff + '0');
    }
    std::reverse(res.begin(), res.end());
    return (neg ? "-" : "") + removeLeadingZeros(res);
}
// Pi in long double precision: acos(-1) evaluated with a long double argument.
const long double PI{std::acos(-1.0L)};
// In-place iterative radix-2 FFT on long double complex values.
// a.size() must be a power of two. invert == true computes the inverse
// transform, including the division by n.
//
// Twiddle factors come from a table computed directly with cos/sin. The
// repeated `w *= wlen` recurrence accumulates rounding error that grows with
// the stage length. The butterfly multiplies real/imag parts by hand because
// std::complex operator* may use the slower NaN/inf-recovering (Annex G) path.
void fft(std::vector<std::complex<long double>>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    if (n <= 1) return;  // nothing to transform; scaling by 1 would be a no-op
    // Bit-reversal permutation.
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    // roots[k] = exp(sign * 2*pi*i * k / n) for 0 <= k < n/2.
    // Stage `len` uses every (n/len)-th entry.
    const long double step = (invert ? -2.0L : 2.0L) * PI / n;
    std::vector<std::complex<long double>> roots(n / 2);
    for (int k = 0; k < n / 2; ++k) {
        const long double ang = step * k;
        roots[k] = std::complex<long double>(std::cos(ang), std::sin(ang));
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len / 2;
        const int stride = n / len;
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < half; j++) {
                const std::complex<long double> w = roots[j * stride];
                const std::complex<long double> x = a[i + j + half];
                const long double vr = x.real() * w.real() - x.imag() * w.imag();
                const long double vi = x.real() * w.imag() + x.imag() * w.real();
                const std::complex<long double> u = a[i + j];
                a[i + j] = std::complex<long double>(u.real() + vr, u.imag() + vi);
                a[i + j + half] = std::complex<long double>(u.real() - vr, u.imag() - vi);
            }
        }
    }
    if (invert) for (auto& x : a) x /= n;
}
// Multiplies two non-negative decimal strings using FFT.
// Returns the product without leading zeros ("0" for a zero product).
// Digits are packed into base-10^4 chunks, so each FFT coefficient carries
// 4 decimal digits.
std::string multiply_fft(const char* a_str, const char* b_str) {
    const int BASE = 10000;  // chunk base
    const int POWER = 4;     // decimal digits per chunk (BASE == 10^POWER)

    // Splits a digit string into base-BASE chunks, least-significant chunk first.
    auto to_chunks = [&](const char* s) {
        const int len = static_cast<int>(std::strlen(s));
        std::vector<std::complex<long double>> chunks;
        chunks.reserve(len / POWER + 1);
        for (int i = len - 1; i >= 0;) {
            long long chunk = 0;
            long long p10 = 1;
            for (int j = 0; j < POWER && i >= 0; ++j) {
                chunk += (s[i--] - '0') * p10;
                p10 *= 10;
            }
            chunks.push_back(static_cast<long double>(chunk));
        }
        return chunks;
    };

    std::vector<std::complex<long double>> fa = to_chunks(a_str);
    std::vector<std::complex<long double>> fb = to_chunks(b_str);
    // An empty operand would make `fa.size() + fb.size() - 1` wrap around (size_t)
    // in the sizing loop below.
    if (fa.empty() || fb.empty()) return "0";

    // Smallest power of two that holds the full product polynomial.
    std::size_t fn = 1;
    while (fn < fa.size() + fb.size() - 1) fn <<= 1;
    fa.resize(fn);
    fb.resize(fn);

    fft(fa, false);
    fft(fb, false);
    for (std::size_t i = 0; i < fn; ++i) fa[i] *= fb[i];
    fft(fa, true);

    // Round coefficients back to integers and propagate carries in base BASE.
    std::vector<long long> res_digits;
    res_digits.reserve(fn + 1);
    long long carry = 0;
    for (std::size_t i = 0; i < fn || carry; ++i) {
        const long long current_val = (i < fn ? std::llround(fa[i].real()) : 0LL) + carry;
        res_digits.push_back(current_val % BASE);
        carry = current_val / BASE;
    }
    // Drop zero chunks from the most significant end, keeping at least one.
    while (res_digits.size() > 1 && res_digits.back() == 0) res_digits.pop_back();

    // The top chunk is printed as-is. It has no leading zeros unless the whole
    // product is 0, so no final removeLeadingZeros pass is needed. Every lower
    // chunk is left-padded with zeros to exactly POWER digits.
    std::string result_str = std::to_string(res_digits.back());
    result_str.reserve(res_digits.size() * POWER);
    for (std::size_t i = res_digits.size() - 1; i-- > 0;) {
        const std::string s_chunk = std::to_string(res_digits[i]);
        result_str.append(POWER - s_chunk.length(), '0');
        result_str += s_chunk;
    }
    return result_str;
}
// Operand buffers: room for up to 1,000,001 characters plus the terminating NUL.
// Reads into them must be bounded to that many characters.
char a_str_input[1000000 + 2];
char b_str_input[1000000 + 2];
// Reads "A op B" from stdin (op is '+', '-' or '*') and prints the result.
// Returns 1 on malformed input or an unknown operator.
// Only C stdio is used, so no iostream configuration is needed.
int main() {
    char op = '\0';
    // Field widths keep scanf inside the 1000002-byte buffers
    // (1000001 characters + NUL). An unbounded %s can overflow them.
    if (scanf("%1000001s %c %1000001s", a_str_input, &op, b_str_input) != 3) {
        fprintf(stderr, "Error: expected input of the form \"A op B\"\n");
        return 1;
    }
    std::string answer;
    switch (op) {
        case '+': answer = Add(std::string(a_str_input), std::string(b_str_input)); break;
        case '-': answer = Sub(std::string(a_str_input), std::string(b_str_input)); break;
        case '*': answer = multiply_fft(a_str_input, b_str_input); break;
        default:
            fprintf(stderr, "Error: Unknown operator '%c'\n", op);
            return 1;
    }
    printf("%s\n", answer.c_str());
    return 0;
}
請問大佬，我的程式還有哪些地方可以再優化？